"vscode:/vscode.git/clone" did not exist on "a75da0ca15000f786d908a4d285f9e67edc3555c"
Unverified Commit 50b22da8 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Debug CUDA graph support with operation-based API (#1117)



* Debug CUDA graph support with operation-based API
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Refactoring CUDA graph tests
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Review suggestions from @ptrendx

Return default recipe from FP8GlobalStateManager.get_fp8_recipe if needed. Expand error message when failing to load FP8 state after capturing CUDA graph.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Avoid unnecessary recursion when saving/loading FP8 state
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Fix circular import
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent df949037
This diff is collapsed.
...@@ -277,7 +277,9 @@ class FP8GlobalStateManager: ...@@ -277,7 +277,9 @@ class FP8GlobalStateManager:
@classmethod
def get_fp8_recipe(cls) -> DelayedScaling:
    """Return the fp8 recipe

    Falls back to the default recipe when no recipe has been set
    (e.g. outside an fp8_autocast context).
    """
    if cls.FP8_RECIPE is not None:
        return cls.FP8_RECIPE
    return get_default_fp8_recipe()
@classmethod @classmethod
def get_fp8_group(cls) -> Union[dist_group_type, None]: def get_fp8_group(cls) -> Union[dist_group_type, None]:
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Functions for CUDA Graphs support in FP8""" """Functions for CUDA Graphs support in FP8"""
from collections.abc import Iterable
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
import torch import torch
...@@ -18,7 +19,7 @@ from .fp8 import ( ...@@ -18,7 +19,7 @@ from .fp8 import (
) )
from .distributed import get_all_rng_states, graph_safe_rng_available from .distributed import get_all_rng_states, graph_safe_rng_available
from .module.base import TransformerEngineBaseModule from .module.base import TransformerEngineBaseModule
from .ops.op import BasicOperation
__all__ = ["make_graphed_callables"] __all__ = ["make_graphed_callables"]
...@@ -486,28 +487,46 @@ def _make_graphed_callables( ...@@ -486,28 +487,46 @@ def _make_graphed_callables(
return tuple(ret) return tuple(ret)
def save_fp8_tensors(
    modules: Iterable[torch.nn.Module],
    fp8_recipe: DelayedScaling,
) -> List[Any]:
    """
    Returns the FP8 tensors for all modules
    with adjusted amax history sizes.

    One entry is appended per submodule; entries are `None` for
    submodules that carry no FP8 state, so that `restore_fp8_tensors`
    can walk the same module traversal order.
    """
    fp8_tensors = []
    for module in modules:
        for m in module.modules():
            module_tensors = None
            if isinstance(m, TransformerEngineBaseModule):
                if m.primary_weights_in_fp8:
                    m.adjust_amax_history_length(fp8_recipe.amax_history_len)
                module_tensors = m.get_fp8_meta_tensors()
            elif isinstance(m, BasicOperation):
                # Initialize FP8 metadata before saving so the saved state
                # matches the recipe used during graph capture
                m.pre_forward(fp8_enabled=True, fp8_recipe=fp8_recipe)
                module_tensors = m._save_fp8_metas()
            fp8_tensors.append(module_tensors)
    return fp8_tensors
def restore_fp8_tensors(
    modules: Iterable[torch.nn.Module],
    fp8_tensors: List[Any],
) -> None:
    """Restore FP8 tensors.

    Consumes `fp8_tensors` in the same module traversal order used by
    `save_fp8_tensors`. Raises RuntimeError if entries are left over,
    which indicates the module set changed between save and restore.
    """
    for module in modules:
        for m in module.modules():
            module_tensors = fp8_tensors.pop(0)
            if isinstance(m, TransformerEngineBaseModule):
                m.reset_fp8_meta_tensors(module_tensors)
            elif isinstance(m, BasicOperation):
                m._load_fp8_metas(module_tensors)
    if len(fp8_tensors) != 0:
        raise RuntimeError(
            f"Got FP8 state for {len(fp8_tensors)} more modules than expected. "
            "There is probably a discrepancy with `save_fp8_tensors`."
        )
def make_graphed_callables( def make_graphed_callables(
...@@ -580,7 +599,7 @@ def make_graphed_callables( ...@@ -580,7 +599,7 @@ def make_graphed_callables(
modules = (modules,) modules = (modules,)
# Store FP8 tensors to reset later. # Store FP8 tensors to reset later.
saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe.amax_history_len) saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe)
# FP8 wrapper. # FP8 wrapper.
def wrap_autocast(block): def wrap_autocast(block):
......
...@@ -308,8 +308,8 @@ class BasicLinear(BasicOperation): ...@@ -308,8 +308,8 @@ class BasicLinear(BasicOperation):
weight = torch.nn.Parameter(weight) weight = torch.nn.Parameter(weight)
self.weight = weight self.weight = weight
def pre_forward(self, *args, **kwargs) -> None:
    """Preprocessing before forward pass

    Forwards all arguments to the base-class hook, then materializes
    the weight if it is still on the meta device.
    """
    super().pre_forward(*args, **kwargs)
    if self.weight.device.type == "meta":
        self.reset_parameters()
......
...@@ -111,8 +111,8 @@ class Bias(BasicOperation): ...@@ -111,8 +111,8 @@ class Bias(BasicOperation):
bias = torch.nn.Parameter(bias) bias = torch.nn.Parameter(bias)
self.bias = bias self.bias = bias
def pre_forward(self, *args, **kwargs) -> None:
    """Preprocessing before forward pass

    Forwards all arguments to the base-class hook, then materializes
    the bias if it is still on the meta device.
    """
    super().pre_forward(*args, **kwargs)
    if self.bias.device.type == "meta":
        self.reset_parameters()
......
...@@ -5,12 +5,12 @@ ...@@ -5,12 +5,12 @@
"""Manager class for a pipeline of fusible operations.""" """Manager class for a pipeline of fusible operations."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
from typing import Any, Optional from typing import Any, Optional
import torch import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.graph import is_graph_capturing
from transformer_engine.pytorch.ops.op import ( from transformer_engine.pytorch.ops.op import (
BasicOperation, BasicOperation,
FusibleOperation, FusibleOperation,
...@@ -28,6 +28,24 @@ def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]: ...@@ -28,6 +28,24 @@ def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]:
return t[:idx], t[idx:] return t[:idx], t[idx:]
# Lazily imported function used in _is_graph_capturing
_is_graph_capturing_function: Optional[Callable[[], bool]] = None


def _is_graph_capturing() -> bool:
    """Whether function is called within `make_graphed_callables`

    Avoid circular import with lazy import.
    """
    global _is_graph_capturing_function
    if _is_graph_capturing_function is None:
        from ..graph import is_graph_capturing

        _is_graph_capturing_function = is_graph_capturing
    return _is_graph_capturing_function()
class _OperationFuserAutogradFunction(torch.autograd.Function): class _OperationFuserAutogradFunction(torch.autograd.Function):
"""Autograd function for a pipeline of operations """Autograd function for a pipeline of operations
...@@ -255,7 +273,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -255,7 +273,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
grad_extra_inputs_flat.extend(dxs) grad_extra_inputs_flat.extend(dxs)
# Update FP8 scaling factors # Update FP8 scaling factors
if func_ctx.is_first_module and not is_graph_capturing(): if func_ctx.is_first_module and not _is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
return ( return (
......
...@@ -14,6 +14,7 @@ import torch ...@@ -14,6 +14,7 @@ import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import ( from transformer_engine.pytorch.fp8 import (
DelayedScaling,
FP8GlobalStateManager, FP8GlobalStateManager,
get_default_fp8_recipe, get_default_fp8_recipe,
) )
...@@ -231,25 +232,37 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -231,25 +232,37 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
} }
@classmethod @classmethod
def _maybe_update_fp8_meta(cls, fp8_meta: Optional[dict[str, Any]]) -> None: def _maybe_update_fp8_meta(
cls,
fp8_meta: Optional[dict[str, Any]],
*,
fp8_recipe: Optional[DelayedScaling] = None,
) -> None:
if fp8_meta is None: if fp8_meta is None:
return return
# Update FP8 recipe and communication group # Update FP8 recipe
recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_recipe is None:
fp8_meta["recipe"] = recipe fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
fp8_meta["recipe"] = fp8_recipe
# Update FP8 communication group
fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
# Adjust amax history length if needed # Adjust amax history length if needed
amax_history_len = recipe.amax_history_len amax_history_len = fp8_recipe.amax_history_len
for is_forward in (True, False): for is_forward in (True, False):
key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward) fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward)
if key not in fp8_meta: if fp8_meta_key not in fp8_meta:
continue continue
meta = fp8_meta[key] meta = fp8_meta[fp8_meta_key]
curr_len = meta.amax_history.size(0) curr_len = meta.amax_history.size(0)
# Nothing to be done if amax history is already correct
if curr_len == amax_history_len: if curr_len == amax_history_len:
continue continue
# Reallocate amax history
with torch.no_grad(): with torch.no_grad():
if curr_len > amax_history_len: if curr_len > amax_history_len:
meta.amax_history = meta.amax_history[:amax_history_len].clone() meta.amax_history = meta.amax_history[:amax_history_len].clone()
...@@ -259,6 +272,21 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -259,6 +272,21 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
pad=(0, 0, 0, amax_history_len - curr_len), pad=(0, 0, 0, amax_history_len - curr_len),
) )
# Update global buffers for amax reductions
buffer_info_key = FP8GlobalStateManager.get_buffer_info()
if buffer_info_key in fp8_meta:
fwd_pos, fwd_key, bwd_pos, bwd_key = fp8_meta[buffer_info_key]
for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)):
assert (
buffer_key in FP8GlobalStateManager.global_amax_history_buffer
), "TE internal error during amax history change."
FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = fp8_meta[
fp8_meta_key
].amax_history[0]
FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = fp8_meta[
fp8_meta_key
].amax_history
def get_fp8_meta(self, mode: str) -> Optional[dict[str, Any]]: def get_fp8_meta(self, mode: str) -> Optional[dict[str, Any]]:
"""FP8 metadata """FP8 metadata
...@@ -272,10 +300,66 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -272,10 +300,66 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
self._fp8_metas = self._make_fp8_metas() self._fp8_metas = self._make_fp8_metas()
return self._fp8_metas[mode] return self._fp8_metas[mode]
def pre_forward(self) -> None: @torch.no_grad()
def _save_fp8_metas(self) -> Optional[dict[str, Any]]:
"""Create copies of tensors in FP8 metadata
Tensor copies can be loaded with _load_fp8_metas.
"""
if self._fp8_metas is None:
return None
out = {}
for mode, fp8_meta in self._fp8_metas.items():
if fp8_meta is None:
continue
out[mode] = {}
for is_forward in (True, False):
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward)
if fp8_meta_key not in fp8_meta:
continue
out[mode][fp8_meta_key] = (
fp8_meta[fp8_meta_key].scale.clone(),
fp8_meta[fp8_meta_key].scale_inv.clone(),
fp8_meta[fp8_meta_key].amax_history.clone(),
)
return out
@torch.no_grad()
def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None:
"""Update FP8 metadata with saved tensor copies
Tensor copies should be generated with _save_fp8_metas.
"""
assert (self._fp8_metas is None) == (
fp8_metas is None
), "Saved FP8 metadata does not match operation's FP8 metadata"
if fp8_metas is None:
return
for mode, fp8_meta in fp8_metas.items():
assert (
mode in self._fp8_metas
), f"Found an unexpected key ({mode=}) in saved FP8 metadata"
for fp8_meta_key, tensors in fp8_meta.items():
assert (
fp8_meta_key in self._fp8_metas[mode]
), f"Found an unexpected key ({mode=}, {fp8_meta_key=}) in saved FP8 metadata"
scale, scale_inv, amax_history = tensors
self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale)
self._fp8_metas[mode][fp8_meta_key].scale_inv.copy_(scale_inv)
self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history)
def pre_forward(
self,
*,
fp8_enabled: Optional[bool] = None,
fp8_recipe: Optional[DelayedScaling] = None,
) -> None:
"""Preprocessing before forward pass""" """Preprocessing before forward pass"""
# Initialize FP8 metadata if needed # Initialize FP8 metadata if needed
if fp8_enabled is None:
fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
if fp8_enabled: if fp8_enabled:
...@@ -285,7 +369,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -285,7 +369,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
# Make sure FP8 metadata matches FP8 autocast context # Make sure FP8 metadata matches FP8 autocast context
for fp8_meta in self._fp8_metas.values(): for fp8_meta in self._fp8_metas.values():
self._maybe_update_fp8_meta(fp8_meta) self._maybe_update_fp8_meta(fp8_meta, fp8_recipe=fp8_recipe)
# Register FP8 metadata for amax and scale update # Register FP8 metadata for amax and scale update
if not FP8GlobalStateManager.fp8_graph_capturing(): if not FP8GlobalStateManager.fp8_graph_capturing():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment