Unverified commit 50b22da8 authored by Tim Moon, committed by GitHub

[PyTorch] Debug CUDA graph support with operation-based API (#1117)



* Debug CUDA graph support with operation-based API
Signed-off-by: Tim Moon <tmoon@nvidia.com>
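
As context, a minimal usage sketch of the scenario this fixes: graph-capturing an operation-based (te.ops) model under FP8. The te_ops.Sequential/te_ops.Linear names and the fp8_enabled/fp8_recipe arguments are assumptions based on the public Transformer Engine API, not code from this commit.

    # Hedged sketch (not from this commit): capture a te.ops model with CUDA graphs in FP8.
    import torch
    import transformer_engine.pytorch as te
    import transformer_engine.pytorch.ops as te_ops
    from transformer_engine.common.recipe import DelayedScaling

    # Operation-based API module (names assumed from transformer_engine.pytorch.ops)
    model = te_ops.Sequential(te_ops.Linear(1024, 1024)).cuda()

    recipe = DelayedScaling(amax_history_len=16)
    sample_input = torch.randn(32, 1024, device="cuda", requires_grad=True)

    # make_graphed_callables warms up, captures forward/backward CUDA graphs, and
    # (with this fix) also saves/restores FP8 state for te.ops operations.
    graphed_model = te.make_graphed_callables(
        model,
        (sample_input,),
        num_warmup_iters=3,
        fp8_enabled=True,
        fp8_recipe=recipe,
    )

    out = graphed_model(sample_input)
    out.sum().backward()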

* Refactoring CUDA graph tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Review suggestions from @ptrendx

Return default recipe from FP8GlobalStateManager.get_fp8_recipe if needed. Expand error message when failing to load FP8 state after capturing CUDA graph.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
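
A small hedged sketch of the first point (not part of the diff): with no fp8_autocast active, no recipe has been set, so get_fp8_recipe now falls back to the default recipe rather than returning None; inside an autocast region it still returns the active recipe.

    from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

    # Outside any fp8_autocast region FP8_RECIPE is unset; after this change
    # get_fp8_recipe() returns the default DelayedScaling recipe instead of None.
    recipe = FP8GlobalStateManager.get_fp8_recipe()
    print(type(recipe).__name__, recipe.amax_history_len)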

* Avoid unnecessary recursion when saving/loading FP8 state
Signed-off-by: Tim Moon <tmoon@nvidia.com>
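
The save_fp8_tensors/restore_fp8_tensors helpers below record one entry per traversed submodule (None for modules without FP8 state) and consume them in the same traversal order, erroring on leftovers. A toy sketch of that contract using only plain PyTorch modules (illustrative names, not Transformer Engine code):

    import torch

    def save_state(modules):
        saved = []
        for module in modules:
            for m in module.modules():
                # Record an entry for every submodule, even stateless ones,
                # so that save and restore stay aligned by position.
                saved.append(getattr(m, "extra_state", None))
        return saved

    def restore_state(modules, saved):
        for module in modules:
            for m in module.modules():
                state = saved.pop(0)
                if state is not None:
                    m.extra_state = state
        if saved:
            raise RuntimeError(f"Got state for {len(saved)} more modules than expected")

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    snapshot = save_state((model,))
    restore_state((model,), snapshot)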

* Fix circular import
Signed-off-by: Tim Moon <tmoon@nvidia.com>
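
The fix below defers the graph import in the operation fuser until first use. A generic sketch of that lazy-import pattern; the cached name and the torch stand-in target are illustrative (and assume a CUDA build of PyTorch), not this commit's code:

    from typing import Callable, Optional

    # Cache for the lazily resolved function (hypothetical name).
    _capture_check_fn: Optional[Callable[[], bool]] = None

    def is_capturing() -> bool:
        """Resolve the real implementation on first call to break an import cycle."""
        global _capture_check_fn
        if _capture_check_fn is None:
            # Importing inside the function means module import time creates no cycle;
            # torch.cuda.is_current_stream_capturing stands in for the real target.
            from torch.cuda import is_current_stream_capturing

            _capture_check_fn = is_current_stream_capturing
        return _capture_check_fn()

    print(is_capturing())  # False outside of CUDA graph capture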

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent df949037
@@ -277,7 +277,9 @@ class FP8GlobalStateManager:
     @classmethod
     def get_fp8_recipe(cls) -> DelayedScaling:
         """Return the fp8 recipe"""
-        return cls.FP8_RECIPE
+        if cls.FP8_RECIPE is not None:
+            return cls.FP8_RECIPE
+        return get_default_fp8_recipe()

     @classmethod
     def get_fp8_group(cls) -> Union[dist_group_type, None]:
@@ -3,6 +3,7 @@
 # See LICENSE for license information.

 """Functions for CUDA Graphs support in FP8"""
+from collections.abc import Iterable
 from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union

 import torch
@@ -18,7 +19,7 @@ from .fp8 import (
 )
 from .distributed import get_all_rng_states, graph_safe_rng_available
 from .module.base import TransformerEngineBaseModule
+from .ops.op import BasicOperation

 __all__ = ["make_graphed_callables"]
@@ -486,28 +487,46 @@ def _make_graphed_callables(
     return tuple(ret)


-def save_fp8_tensors(modules, amax_history_len):
+def save_fp8_tensors(
+    modules: Iterable[torch.nn.Module],
+    fp8_recipe: DelayedScaling,
+) -> List[Any]:
     """
     Returns the FP8 tensors for all modules
     with adjusted amax history sizes.
     """
-    saved_fp8_meta_tensors = []
+    fp8_tensors = []
     for module in modules:
         for m in module.modules():
+            module_tensors = None
             if isinstance(m, TransformerEngineBaseModule):
                 if m.primary_weights_in_fp8:
-                    m.adjust_amax_history_length(amax_history_len)
-                saved_fp8_meta_tensors.append(m.get_fp8_meta_tensors())
-    return saved_fp8_meta_tensors
-
-
-def restore_fp8_tensors(modules, fp8_tensors):
+                    m.adjust_amax_history_length(fp8_recipe.amax_history_len)
+                module_tensors = m.get_fp8_meta_tensors()
+            elif isinstance(m, BasicOperation):
+                m.pre_forward(fp8_enabled=True, fp8_recipe=fp8_recipe)
+                module_tensors = m._save_fp8_metas()
+            fp8_tensors.append(module_tensors)
+    return fp8_tensors
+
+
+def restore_fp8_tensors(
+    modules: Iterable[torch.nn.Module],
+    fp8_tensors: List[Any],
+) -> None:
     """Restore FP8 tensors."""
     for module in modules:
         for m in module.modules():
+            module_tensors = fp8_tensors.pop(0)
             if isinstance(m, TransformerEngineBaseModule):
-                m.reset_fp8_meta_tensors(fp8_tensors.pop(0))
-    assert len(fp8_tensors) == 0, "TE internal error."
+                m.reset_fp8_meta_tensors(module_tensors)
+            elif isinstance(m, BasicOperation):
+                m._load_fp8_metas(module_tensors)
+    if len(fp8_tensors) != 0:
+        raise RuntimeError(
+            f"Got FP8 state for {len(fp8_tensors)} more modules than expected. "
+            "There is probably a discrepancy with `save_fp8_tensors`."
+        )


 def make_graphed_callables(
@@ -580,7 +599,7 @@ def make_graphed_callables(
         modules = (modules,)

     # Store FP8 tensors to reset later.
-    saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe.amax_history_len)
+    saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe)

     # FP8 wrapper.
     def wrap_autocast(block):
@@ -308,8 +308,8 @@ class BasicLinear(BasicOperation):
             weight = torch.nn.Parameter(weight)
         self.weight = weight

-    def pre_forward(self) -> None:
-        super().pre_forward()
+    def pre_forward(self, *args, **kwargs) -> None:
+        super().pre_forward(*args, **kwargs)
         if self.weight.device.type == "meta":
             self.reset_parameters()
@@ -111,8 +111,8 @@ class Bias(BasicOperation):
             bias = torch.nn.Parameter(bias)
         self.bias = bias

-    def pre_forward(self) -> None:
-        super().pre_forward()
+    def pre_forward(self, *args, **kwargs) -> None:
+        super().pre_forward(*args, **kwargs)
         if self.bias.device.type == "meta":
             self.reset_parameters()
@@ -5,12 +5,12 @@
 """Manager class for a pipeline of fusible operations."""

 from __future__ import annotations
+from collections.abc import Callable
 from typing import Any, Optional

 import torch

 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch.graph import is_graph_capturing
 from transformer_engine.pytorch.ops.op import (
     BasicOperation,
     FusibleOperation,
@@ -28,6 +28,24 @@ def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]:
     return t[:idx], t[idx:]


+# Lazily imported function used in _is_graph_capturing
+_is_graph_capturing_function: Optional[Callable[[], bool]] = None
+
+
+def _is_graph_capturing() -> bool:
+    """Whether function is called within `make_graphed_callables`
+
+    Avoid circular import with lazy import.
+
+    """
+    global _is_graph_capturing_function
+    if _is_graph_capturing_function is None:
+        from ..graph import is_graph_capturing
+
+        _is_graph_capturing_function = is_graph_capturing
+    return _is_graph_capturing_function()
+
+
 class _OperationFuserAutogradFunction(torch.autograd.Function):
     """Autograd function for a pipeline of operations
@@ -255,7 +273,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
             grad_extra_inputs_flat.extend(dxs)

         # Update FP8 scaling factors
-        if func_ctx.is_first_module and not is_graph_capturing():
+        if func_ctx.is_first_module and not _is_graph_capturing():
             FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

         return (
@@ -14,6 +14,7 @@ import torch

 import transformer_engine_torch as tex
 from transformer_engine.pytorch.fp8 import (
+    DelayedScaling,
     FP8GlobalStateManager,
     get_default_fp8_recipe,
 )
@@ -231,25 +232,37 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         }

     @classmethod
-    def _maybe_update_fp8_meta(cls, fp8_meta: Optional[dict[str, Any]]) -> None:
+    def _maybe_update_fp8_meta(
+        cls,
+        fp8_meta: Optional[dict[str, Any]],
+        *,
+        fp8_recipe: Optional[DelayedScaling] = None,
+    ) -> None:
         if fp8_meta is None:
             return

-        # Update FP8 recipe and communication group
-        recipe = FP8GlobalStateManager.get_fp8_recipe()
-        fp8_meta["recipe"] = recipe
+        # Update FP8 recipe
+        if fp8_recipe is None:
+            fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
+        fp8_meta["recipe"] = fp8_recipe
+
+        # Update FP8 communication group
         fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()

         # Adjust amax history length if needed
-        amax_history_len = recipe.amax_history_len
+        amax_history_len = fp8_recipe.amax_history_len
         for is_forward in (True, False):
-            key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward)
-            if key not in fp8_meta:
+            fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward)
+            if fp8_meta_key not in fp8_meta:
                 continue
-            meta = fp8_meta[key]
+            meta = fp8_meta[fp8_meta_key]
             curr_len = meta.amax_history.size(0)
+
+            # Nothing to be done if amax history is already correct
             if curr_len == amax_history_len:
                 continue
+
+            # Reallocate amax history
             with torch.no_grad():
                 if curr_len > amax_history_len:
                     meta.amax_history = meta.amax_history[:amax_history_len].clone()
@@ -259,6 +272,21 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
                         pad=(0, 0, 0, amax_history_len - curr_len),
                     )

+            # Update global buffers for amax reductions
+            buffer_info_key = FP8GlobalStateManager.get_buffer_info()
+            if buffer_info_key in fp8_meta:
+                fwd_pos, fwd_key, bwd_pos, bwd_key = fp8_meta[buffer_info_key]
+                for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)):
+                    assert (
+                        buffer_key in FP8GlobalStateManager.global_amax_history_buffer
+                    ), "TE internal error during amax history change."
+                    FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = fp8_meta[
+                        fp8_meta_key
+                    ].amax_history[0]
+                    FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = fp8_meta[
+                        fp8_meta_key
+                    ].amax_history
+
     def get_fp8_meta(self, mode: str) -> Optional[dict[str, Any]]:
         """FP8 metadata
@@ -272,10 +300,66 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             self._fp8_metas = self._make_fp8_metas()
         return self._fp8_metas[mode]

-    def pre_forward(self) -> None:
+    @torch.no_grad()
+    def _save_fp8_metas(self) -> Optional[dict[str, Any]]:
+        """Create copies of tensors in FP8 metadata
+
+        Tensor copies can be loaded with _load_fp8_metas.
+
+        """
+        if self._fp8_metas is None:
+            return None
+        out = {}
+        for mode, fp8_meta in self._fp8_metas.items():
+            if fp8_meta is None:
+                continue
+            out[mode] = {}
+            for is_forward in (True, False):
+                fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward)
+                if fp8_meta_key not in fp8_meta:
+                    continue
+                out[mode][fp8_meta_key] = (
+                    fp8_meta[fp8_meta_key].scale.clone(),
+                    fp8_meta[fp8_meta_key].scale_inv.clone(),
+                    fp8_meta[fp8_meta_key].amax_history.clone(),
+                )
+        return out
+
+    @torch.no_grad()
+    def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None:
+        """Update FP8 metadata with saved tensor copies
+
+        Tensor copies should be generated with _save_fp8_metas.
+
+        """
+        assert (self._fp8_metas is None) == (
+            fp8_metas is None
+        ), "Saved FP8 metadata does not match operation's FP8 metadata"
+        if fp8_metas is None:
+            return
+        for mode, fp8_meta in fp8_metas.items():
+            assert (
+                mode in self._fp8_metas
+            ), f"Found an unexpected key ({mode=}) in saved FP8 metadata"
+            for fp8_meta_key, tensors in fp8_meta.items():
+                assert (
+                    fp8_meta_key in self._fp8_metas[mode]
+                ), f"Found an unexpected key ({mode=}, {fp8_meta_key=}) in saved FP8 metadata"
+                scale, scale_inv, amax_history = tensors
+                self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale)
+                self._fp8_metas[mode][fp8_meta_key].scale_inv.copy_(scale_inv)
+                self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history)
+
+    def pre_forward(
+        self,
+        *,
+        fp8_enabled: Optional[bool] = None,
+        fp8_recipe: Optional[DelayedScaling] = None,
+    ) -> None:
         """Preprocessing before forward pass"""

         # Initialize FP8 metadata if needed
-        fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
+        if fp8_enabled is None:
+            fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
         if fp8_enabled:
@@ -285,7 +369,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):

             # Make sure FP8 metadata matches FP8 autocast context
             for fp8_meta in self._fp8_metas.values():
-                self._maybe_update_fp8_meta(fp8_meta)
+                self._maybe_update_fp8_meta(fp8_meta, fp8_recipe=fp8_recipe)

             # Register FP8 metadata for amax and scale update
             if not FP8GlobalStateManager.fp8_graph_capturing():