"vscode:/vscode.git/clone" did not exist on "32d1eb11854c73b984fb3fcd176adfd28a78a60f"
Unverified Commit 50b22da8 authored by Tim Moon, committed by GitHub

[PyTorch] Debug CUDA graph support with operation-based API (#1117)



* Debug CUDA graph support with operation-based API
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Refactoring CUDA graph tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Review suggestions from @ptrendx

Return default recipe from FP8GlobalStateManager.get_fp8_recipe if needed. Expand error message when failing to load FP8 state after capturing CUDA graph.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid unnecessary recursion when saving/loading FP8 state
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix circular import
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent df949037
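
For context before the diff: a minimal sketch of the workflow this change targets, capturing CUDA graphs around a model built with the operation-based API (te_ops). This is not part of the commit; the sizes, dtype, and FP8 settings are illustrative assumptions, and FP8 execution also requires FP8-capable hardware.

import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

# Operation-based model: fusible ops wrapped in te_ops.Sequential.
model = te_ops.Sequential(
    te_ops.Linear(768, 768, dtype=torch.float16),
)

# Sample input used for warmup and graph capture (sequence, batch, hidden).
sample_input = torch.randn(
    512, 4, 768, device="cuda", dtype=torch.float16, requires_grad=True
)

# Capture the CUDA graph; fp8_enabled exercises the FP8 state
# save/restore paths touched by this commit.
graphed_model = te.make_graphed_callables(
    model,
    (sample_input,),
    num_warmup_iters=10,
    fp8_enabled=True,
)

# Replay the graphed model inside an FP8 autocast region.
with te.fp8_autocast(enabled=True):
    output = graphed_model(sample_input)
output.backward(torch.randn_like(output))

The new "linear_op" case added to the tests in the diff below exercises this pattern alongside the module-based APIs.
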
@@ -13,24 +13,25 @@ from transformer_engine.pytorch import (
LayerNormLinear,
LayerNormMLP,
Linear,
make_graphed_callables,
MultiheadAttention,
TransformerLayer,
fp8_autocast,
fp8_model_init,
make_graphed_callables,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine.pytorch.ops as te_ops
# Only run FP8 tests on H100.
# Check if FP8 is supported.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
# Record initial RNG state.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record initial RNG state from script run.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
@@ -48,17 +49,14 @@ class ModelConfig:
model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}
modules = ["transformer", "layernorm_mlp", "layernorm_linear", "linear", "mha", "dpa"]
all_boolean = [True, False]
dtypes = [torch.float32, torch.float16]
# Supported data types
dtypes: List[torch.dtype] = [torch.float32, torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
dtypes.append(torch.bfloat16)
def reset_rng_states() -> None:
"""revert back to initial RNG state."""
"""Revert to initial RNG state."""
torch.set_rng_state(_cpu_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
@@ -70,64 +68,40 @@ def reset_global_fp8_state():
def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> bool:
"""Ensures two lists are equal."""
"""Check that two lists of tensors match exactly."""
assert len(l1) == len(l2), "Unequal number of outputs."
failed = False
failed_tensors = ""
failure_message = "Output mismatches in:"
failed_tensors = []
for i, (t1, t2) in enumerate(zip(l1, l2)):
if not torch.equal(t1, t2):
failed = True
failed_tensors += (
f" {names[i]}\n" if names is not None else f" tensor at idx={i}\n"
)
assert not failed, "Output mismatches in:\n" + failed_tensors
failure_message += "\n "
if names is None:
failure_message += f"tensor at idx={i}"
else:
failure_message += names[i]
failed_tensors.append((t1, t2))
if failed_tensors:
print(failure_message)
t1, t2 = failed_tensors[0]
torch.testing.assert_close(t1, t2, rtol=0, atol=0)
def generate_data(
config: ModelConfig,
model_config: ModelConfig,
dtype: torch.dtype,
dpa: bool = False,
warmup: bool = False,
return_grad_output: bool = False,
) -> Tuple[List[torch.Tensor], torch.Tensor]:
requires_grad: bool = True,
) -> torch.Tensor:
"""Generate synthetic data."""
gen_func = torch.ones if warmup else torch.randn
if dpa:
inputs = [
gen_func(
config.sequence_length,
config.batch_size,
config.num_heads,
config.kv_channels,
device="cuda",
requires_grad=True,
dtype=dtype,
)
for _ in range(3)
]
else:
inputs = [
gen_func(
config.sequence_length,
config.batch_size,
config.hidden_size,
device="cuda",
requires_grad=True,
dtype=dtype,
)
]
if not return_grad_output:
return inputs
grad_output = torch.randn(
config.sequence_length,
config.batch_size,
config.hidden_size,
return gen_func(
model_config.sequence_length,
model_config.batch_size,
model_config.hidden_size,
device="cuda",
requires_grad=requires_grad,
dtype=dtype,
)
return inputs, grad_output
def get_outputs(
@@ -157,30 +131,44 @@ class _Sequential(torch.nn.Sequential):
return x
# Supported modules
_test_cuda_graphs_modules: List[str] = [
"transformer",
"layernorm_mlp",
"layernorm_linear",
"linear",
"mha",
"linear_op",
]
def _test_cuda_graphs(
*,
config: ModelConfig,
graph_mode: str,
module: str,
model_config: ModelConfig,
num_layers: int,
dtype: torch.dtype,
fp8: bool,
fp8_params: bool,
fp8_weight_caching: bool,
module: str,
graph_mode: str,
) -> List[torch.Tensor]:
"""Helper function for CUDA graph test."""
reset_rng_states()
FP8GlobalStateManager.reset()
dpa = module == "dpa"
# Operation-based API does not support FP8 weight caching.
if module == "linear_op":
fp8_weight_caching = False
# Create modules.
with fp8_model_init(enabled=fp8_params):
# Create modules.
if module == "transformer":
modules = [
TransformerLayer(
config.hidden_size,
config.hidden_size,
config.num_heads,
model_config.hidden_size,
model_config.hidden_size,
model_config.num_heads,
hidden_dropout=0.0,
attention_dropout=0.0,
fuse_qkv_params=True,
@@ -190,37 +178,56 @@ def _test_cuda_graphs(
]
elif module == "layernorm_mlp":
modules = [
LayerNormMLP(config.hidden_size, config.hidden_size, params_dtype=dtype)
LayerNormMLP(
model_config.hidden_size,
model_config.hidden_size,
params_dtype=dtype,
)
for _ in range(num_layers)
]
elif module == "layernorm_linear":
modules = [
LayerNormLinear(config.hidden_size, config.hidden_size, params_dtype=dtype)
LayerNormLinear(
model_config.hidden_size,
model_config.hidden_size,
params_dtype=dtype,
)
for _ in range(num_layers)
]
elif module == "mha":
modules = [
MultiheadAttention(
config.hidden_size,
config.num_heads,
model_config.hidden_size,
model_config.num_heads,
attention_dropout=0.0,
params_dtype=dtype,
fuse_qkv_params=True,
)
for _ in range(num_layers)
]
elif dpa:
assert config.hidden_size % config.num_heads == 0, "Err."
assert num_layers == 1, "Err."
elif module == "linear":
modules = [
DotProductAttention(config.num_heads, config.kv_channels, attention_dropout=0.0)
Linear(
model_config.hidden_size,
model_config.hidden_size,
device="cuda",
params_dtype=dtype,
)
for _ in range(num_layers)
]
else:
elif module == "linear_op":
modules = [
Linear(config.hidden_size, config.hidden_size, device="cuda", params_dtype=dtype)
te_ops.Sequential(
te_ops.Linear(
model_config.hidden_size,
model_config.hidden_size,
dtype=dtype,
),
)
for _ in range(num_layers)
]
else:
raise ValueError(f"Unknown module type ({module})")
# Initialize gradient buffers.
for module in modules:
@@ -230,111 +237,208 @@ def _test_cuda_graphs(
# Generate model and wrap API to return graphed version.
if graph_mode == "full":
# Graph entire model at once.
model = modules[0] if dpa else torch.nn.Sequential(*modules)
model = torch.nn.Sequential(*modules)
model = make_graphed_callables(
model,
generate_data(config, dtype, dpa=dpa, warmup=True),
(generate_data(model_config, dtype, warmup=True),),
num_warmup_iters=10,
fp8_enabled=fp8,
fp8_weight_caching=fp8_weight_caching,
)
elif graph_mode == "individual":
# Graph individual modules
# Graph individual modules.
modules = [
make_graphed_callables(
module,
generate_data(config, dtype, dpa=dpa, warmup=True),
(generate_data(model_config, dtype, warmup=True),),
num_warmup_iters=10,
fp8_enabled=fp8,
fp8_weight_caching=fp8_weight_caching,
)
for module in modules
]
model = modules[0] if dpa else _Sequential(*modules)
model = _Sequential(*modules)
else:
model = modules[0] if dpa else _Sequential(*modules)
model = _Sequential(*modules)
# Optimizer.
if not dpa:
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
# Launch.
# Training steps.
for _ in range(3):
if not dpa:
optimizer.zero_grad(set_to_none=False)
optimizer.zero_grad(set_to_none=False)
for grad_accumulation_step in range(2):
inputs, grad_output = generate_data(config, dtype, dpa=dpa, return_grad_output=True)
input_ = generate_data(model_config, dtype)
grad_output = generate_data(model_config, dtype, requires_grad=False)
with fp8_autocast(enabled=fp8):
kwargs = {}
if fp8_weight_caching:
kwargs["is_first_microbatch"] = grad_accumulation_step == 0
output = model(*inputs, **kwargs)
output = model(input_, **kwargs)
output.backward(grad_output)
if not dpa:
optimizer.step()
optimizer.step()
return get_outputs(model, output)
@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("num_layers", [1, 3])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("fp8_params", all_boolean)
@pytest.mark.parametrize("fp8_weight_caching", all_boolean)
@pytest.mark.parametrize("module", modules)
def test_gpt_make_graphed_callables(
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("fp8_params", (False, True))
def test_make_graphed_callables(
*,
module: str,
model_config: str = "small",
num_layers: int = 3,
dtype: torch.dtype,
model: str,
num_layers: int,
fp8: bool,
fp8_params: bool,
fp8_weight_caching: bool,
module: str,
fp8_weight_caching: bool = False,
) -> None:
# Skip invalid configurations.
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_params and not fp8:
pytest.skip("FP8 needed for FP8 parameters.")
if fp8_weight_caching and not fp8:
pytest.skip("FP8 needed for FP8 parameters.")
if module == "dpa" and num_layers > 1:
pytest.skip("Max 1 layer for DPA.")
config = model_configs[model]
# Run model with different CUDA graph settings.
model_config = model_configs[model_config]
kwargs = dict(
config=config,
module=module,
model_config=model_config,
num_layers=num_layers,
dtype=dtype,
fp8=fp8,
fp8_params=fp8_params,
fp8_weight_caching=fp8_weight_caching,
module=module,
)
outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)
# Check that results match
# Check that results match.
assert_all_equal(outputs, graph_outputs_mode1)
assert_all_equal(outputs, graph_outputs_mode2)
def _test_cuda_graphs_with_kwargs(
_test_make_graphed_callables_with_fp8_weight_caching_modules = [
"transformer",
"layernorm_mlp",
"layernorm_linear",
"linear",
"mha",
]
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize(
"module",
_test_make_graphed_callables_with_fp8_weight_caching_modules,
)
@pytest.mark.parametrize("fp8_params", (False, True))
def test_make_graphed_callables_with_fp8_weight_caching(
*,
config: ModelConfig,
module: str,
fp8_params: bool,
) -> None:
test_make_graphed_callables(
module=module,
dtype=torch.float32,
fp8=True,
fp8_params=fp8_params,
fp8_weight_caching=True,
)
def generate_data_for_dot_product_attention(
model_config: ModelConfig,
dtype: torch.dtype,
warmup: bool = False,
) -> List[torch.Tensor]:
"""Generate synthetic data for dot product attention."""
gen_func = torch.ones if warmup else torch.randn
return [
gen_func(
model_config.sequence_length,
model_config.batch_size,
model_config.num_heads,
model_config.kv_channels,
device="cuda",
requires_grad=True,
dtype=dtype,
)
for _ in range(3)
]
def _test_cuda_graphs_with_dot_product_attention(
*,
with_graph: bool,
model_config: ModelConfig,
dtype: torch.dtype,
) -> List[torch.Tensor]:
"""Simulate Megatron-LM interleaved pipeline parallelism."""
"""Helper function for CUDA graph test."""
reset_rng_states()
FP8GlobalStateManager.reset()
# Create dot product attention module.
assert model_config.hidden_size % model_config.num_heads == 0
model = DotProductAttention(
model_config.num_heads,
model_config.kv_channels,
attention_dropout=0.0,
)
# Graph model if needed.
if with_graph:
model = make_graphed_callables(
model,
generate_data_for_dot_product_attention(model_config, dtype, warmup=True),
num_warmup_iters=10,
fp8_enabled=False,
)
# Forward and backward passes.
for _ in range(3):
inputs = generate_data_for_dot_product_attention(model_config, dtype)
grad_output = generate_data(model_config, dtype, requires_grad=False)
output = model(*inputs)
output.backward(grad_output)
return get_outputs(model, output)
@pytest.mark.parametrize("dtype", dtypes)
def test_make_graphed_callables_with_dot_product_attention(
*,
model_config: str = "small",
dtype: torch.dtype,
) -> None:
"""Test CUDA graphs with dot product attention."""
model_config = model_configs[model_config]
kwargs = dict(model_config=model_config, dtype=dtype)
outputs = _test_cuda_graphs_with_dot_product_attention(with_graph=False, **kwargs)
graph_outputs = _test_cuda_graphs_with_dot_product_attention(with_graph=True, **kwargs)
assert_all_equal(outputs, graph_outputs)
def _test_cuda_graphs_with_kwargs(
*,
with_graph: bool,
model_config: ModelConfig,
dtype: torch.dtype,
) -> List[torch.Tensor]:
"""Helper function for CUDA graph test with keyword arguments."""
reset_rng_states()
# Initialize model.
model = TransformerLayer(
config.hidden_size,
config.hidden_size,
config.num_heads,
model_config.hidden_size,
model_config.hidden_size,
model_config.num_heads,
hidden_dropout=0.0,
attention_dropout=0.0,
self_attn_mask_type="arbitrary",
@@ -349,13 +453,18 @@ def _test_cuda_graphs_with_kwargs(
# Make graphed version of model if needed.
if with_graph:
attn_mask = torch.zeros(
(config.batch_size, 1, config.sequence_length, config.sequence_length),
(
model_config.batch_size,
1,
model_config.sequence_length,
model_config.sequence_length,
),
dtype=torch.bool,
device="cuda",
)
model = make_graphed_callables(
model,
generate_data(config, dtype, warmup=True),
(generate_data(model_config, dtype, warmup=True),),
sample_kwargs=dict(attention_mask=attn_mask),
allow_unused_input=True,
)
@@ -367,14 +476,20 @@ def _test_cuda_graphs_with_kwargs(
for _ in range(3):
optimizer.zero_grad(set_to_none=False)
for grad_accumulation_step in range(2):
inputs, grad_output = generate_data(config, dtype, return_grad_output=True)
input_ = generate_data(model_config, dtype)
grad_output = generate_data(model_config, dtype, requires_grad=False)
attn_mask = torch.randint(
2,
(config.batch_size, 1, config.sequence_length, config.sequence_length),
(
model_config.batch_size,
1,
model_config.sequence_length,
model_config.sequence_length,
),
dtype=torch.bool,
device="cuda",
)
output = model(*inputs, attention_mask=attn_mask)
output = model(input_, attention_mask=attn_mask)
output.backward(grad_output)
optimizer.step()
@@ -382,12 +497,13 @@ def _test_cuda_graphs_with_kwargs(
def test_make_graphed_callables_with_kwargs(
*,
model_config: str = "small",
dtype: torch.dtype = torch.float32,
model: str = "small",
) -> None:
"""Test CUDA graphs with keyword arguments."""
config = model_configs[model]
kwargs = dict(config=config, dtype=dtype)
model_config = model_configs[model_config]
kwargs = dict(model_config=model_config, dtype=dtype)
outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs)
graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs)
assert_all_equal(outputs, graph_outputs)
@@ -395,9 +511,9 @@ def test_make_graphed_callables_with_kwargs(
def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
*,
config: ModelConfig,
dtype: torch.dtype,
with_graph: bool,
model_config: ModelConfig,
dtype: torch.dtype,
) -> List[torch.Tensor]:
"""Simulate Megatron-LM interleaved pipeline parallelism."""
reset_rng_states()
@@ -411,8 +527,8 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
model = torch.nn.ModuleList(
[
Linear(
config.hidden_size,
config.hidden_size,
model_config.hidden_size,
model_config.hidden_size,
params_dtype=dtype,
)
for _ in range(num_layers)
@@ -430,7 +546,8 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
}
if with_graph:
sample_args = tuple(
generate_data(config, dtype, warmup=True) for _ in range(num_layers * num_microbatches)
(generate_data(model_config, dtype, warmup=True),)
for _ in range(num_layers * num_microbatches)
)
layer_forwards = make_graphed_callables(
tuple(model),
@@ -455,9 +572,10 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
grad_outputs = {}
for layer_idx in range(num_layers):
for microbatch_idx in range(num_microbatches):
x, dy = generate_data(config, dtype, return_grad_output=True)
x = generate_data(model_config, dtype)
dy = generate_data(model_config, dtype, requires_grad=False)
idxs = (layer_idx, microbatch_idx)
inputs[idxs] = x[0]
inputs[idxs] = x
grad_outputs[idxs] = dy
# Cache for layer outputs.
@@ -494,12 +612,13 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
*,
model_config: str = "small",
dtype: torch.dtype = torch.float16,
model: str = "small",
) -> None:
"""Test CUDA graphs with Megatron-LM interleaved pipeline parallelism."""
config = model_configs[model]
kwargs = dict(config=config, dtype=dtype)
model_config = model_configs[model_config]
kwargs = dict(model_config=model_config, dtype=dtype)
outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
with_graph=False,
**kwargs,
......
@@ -277,7 +277,9 @@ class FP8GlobalStateManager:
@classmethod
def get_fp8_recipe(cls) -> DelayedScaling:
"""Return the fp8 recipe"""
return cls.FP8_RECIPE
if cls.FP8_RECIPE is not None:
return cls.FP8_RECIPE
return get_default_fp8_recipe()
@classmethod
def get_fp8_group(cls) -> Union[dist_group_type, None]:
......
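
As requested in review, get_fp8_recipe above now falls back to the default recipe instead of returning None when no recipe has been set. A hedged illustration of the intended behavior, assuming no fp8_autocast region is active:

from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

# With no fp8_autocast active there is no user-provided recipe, so the
# manager now returns the default DelayedScaling recipe rather than None.
recipe = FP8GlobalStateManager.get_fp8_recipe()
print(type(recipe).__name__)  # expected: DelayedScaling
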
@@ -3,6 +3,7 @@
# See LICENSE for license information.
"""Functions for CUDA Graphs support in FP8"""
from collections.abc import Iterable
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
import torch
@@ -18,7 +19,7 @@ from .fp8 import (
)
from .distributed import get_all_rng_states, graph_safe_rng_available
from .module.base import TransformerEngineBaseModule
from .ops.op import BasicOperation
__all__ = ["make_graphed_callables"]
@@ -486,28 +487,46 @@ def _make_graphed_callables(
return tuple(ret)
def save_fp8_tensors(modules, amax_history_len):
def save_fp8_tensors(
modules: Iterable[torch.nn.Module],
fp8_recipe: DelayedScaling,
) -> List[Any]:
"""
Returns the FP8 tensors for all modules
with adjusted amax history sizes.
"""
saved_fp8_meta_tensors = []
fp8_tensors = []
for module in modules:
for m in module.modules():
module_tensors = None
if isinstance(m, TransformerEngineBaseModule):
if m.primary_weights_in_fp8:
m.adjust_amax_history_length(amax_history_len)
saved_fp8_meta_tensors.append(m.get_fp8_meta_tensors())
return saved_fp8_meta_tensors
def restore_fp8_tensors(modules, fp8_tensors):
m.adjust_amax_history_length(fp8_recipe.amax_history_len)
module_tensors = m.get_fp8_meta_tensors()
elif isinstance(m, BasicOperation):
m.pre_forward(fp8_enabled=True, fp8_recipe=fp8_recipe)
module_tensors = m._save_fp8_metas()
fp8_tensors.append(module_tensors)
return fp8_tensors
def restore_fp8_tensors(
modules: Iterable[torch.nn.Module],
fp8_tensors: List[Any],
) -> None:
"""Restore FP8 tensors."""
for module in modules:
for m in module.modules():
module_tensors = fp8_tensors.pop(0)
if isinstance(m, TransformerEngineBaseModule):
m.reset_fp8_meta_tensors(fp8_tensors.pop(0))
assert len(fp8_tensors) == 0, "TE internal error."
m.reset_fp8_meta_tensors(module_tensors)
elif isinstance(m, BasicOperation):
m._load_fp8_metas(module_tensors)
if len(fp8_tensors) != 0:
raise RuntimeError(
f"Got FP8 state for {len(fp8_tensors)} more modules than expected. "
"There is probably a discrepancy with `save_fp8_tensors`."
)
def make_graphed_callables(
@@ -580,7 +599,7 @@ def make_graphed_callables(
modules = (modules,)
# Store FP8 tensors to reset later.
saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe.amax_history_len)
saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe)
# FP8 wrapper.
def wrap_autocast(block):
......
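
The helpers above now snapshot and restore FP8 state for both TransformerEngineBaseModule and BasicOperation instances. A simplified sketch of how make_graphed_callables uses them; warmup and capture are elided, and the module list and recipe below are placeholders:

import torch
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.fp8 import get_default_fp8_recipe
from transformer_engine.pytorch.graph import save_fp8_tensors, restore_fp8_tensors

modules = (te_ops.Sequential(te_ops.Linear(256, 256, dtype=torch.float16)),)
fp8_recipe = get_default_fp8_recipe()

# Snapshot per-module FP8 state, adjusting amax history lengths to the recipe.
saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe)

# ... warmup iterations and CUDA graph capture would run here ...

# Restore the pre-capture FP8 state so capture side effects do not leak.
restore_fp8_tensors(modules, saved_fp8_tensors)
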
@@ -308,8 +308,8 @@ class BasicLinear(BasicOperation):
weight = torch.nn.Parameter(weight)
self.weight = weight
def pre_forward(self) -> None:
super().pre_forward()
def pre_forward(self, *args, **kwargs) -> None:
super().pre_forward(*args, **kwargs)
if self.weight.device.type == "meta":
self.reset_parameters()
......
@@ -111,8 +111,8 @@ class Bias(BasicOperation):
bias = torch.nn.Parameter(bias)
self.bias = bias
def pre_forward(self) -> None:
super().pre_forward()
def pre_forward(self, *args, **kwargs) -> None:
super().pre_forward(*args, **kwargs)
if self.bias.device.type == "meta":
self.reset_parameters()
......
@@ -5,12 +5,12 @@
"""Manager class for a pipeline of fusible operations."""
from __future__ import annotations
from collections.abc import Callable
from typing import Any, Optional
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.graph import is_graph_capturing
from transformer_engine.pytorch.ops.op import (
BasicOperation,
FusibleOperation,
@@ -28,6 +28,24 @@ def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]:
return t[:idx], t[idx:]
# Lazily imported function used in _is_graph_capturing
_is_graph_capturing_function: Optional[Callable[[], bool]] = None
def _is_graph_capturing() -> bool:
"""Whether function is called within `make_graphed_callables`
Avoid circular import with lazy import.
"""
global _is_graph_capturing_function
if _is_graph_capturing_function is None:
from ..graph import is_graph_capturing
_is_graph_capturing_function = is_graph_capturing
return _is_graph_capturing_function()
class _OperationFuserAutogradFunction(torch.autograd.Function):
"""Autograd function for a pipeline of operations
@@ -255,7 +273,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
grad_extra_inputs_flat.extend(dxs)
# Update FP8 scaling factors
if func_ctx.is_first_module and not is_graph_capturing():
if func_ctx.is_first_module and not _is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
return (
......
@@ -14,6 +14,7 @@ import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import (
DelayedScaling,
FP8GlobalStateManager,
get_default_fp8_recipe,
)
@@ -231,25 +232,37 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
}
@classmethod
def _maybe_update_fp8_meta(cls, fp8_meta: Optional[dict[str, Any]]) -> None:
def _maybe_update_fp8_meta(
cls,
fp8_meta: Optional[dict[str, Any]],
*,
fp8_recipe: Optional[DelayedScaling] = None,
) -> None:
if fp8_meta is None:
return
# Update FP8 recipe and communication group
recipe = FP8GlobalStateManager.get_fp8_recipe()
fp8_meta["recipe"] = recipe
# Update FP8 recipe
if fp8_recipe is None:
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
fp8_meta["recipe"] = fp8_recipe
# Update FP8 communication group
fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
# Adjust amax history length if needed
amax_history_len = recipe.amax_history_len
amax_history_len = fp8_recipe.amax_history_len
for is_forward in (True, False):
key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward)
if key not in fp8_meta:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward)
if fp8_meta_key not in fp8_meta:
continue
meta = fp8_meta[key]
meta = fp8_meta[fp8_meta_key]
curr_len = meta.amax_history.size(0)
# Nothing to be done if amax history is already correct
if curr_len == amax_history_len:
continue
# Reallocate amax history
with torch.no_grad():
if curr_len > amax_history_len:
meta.amax_history = meta.amax_history[:amax_history_len].clone()
......@@ -259,6 +272,21 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
pad=(0, 0, 0, amax_history_len - curr_len),
)
# Update global buffers for amax reductions
buffer_info_key = FP8GlobalStateManager.get_buffer_info()
if buffer_info_key in fp8_meta:
fwd_pos, fwd_key, bwd_pos, bwd_key = fp8_meta[buffer_info_key]
for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)):
assert (
buffer_key in FP8GlobalStateManager.global_amax_history_buffer
), "TE internal error during amax history change."
FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = fp8_meta[
fp8_meta_key
].amax_history[0]
FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = fp8_meta[
fp8_meta_key
].amax_history
def get_fp8_meta(self, mode: str) -> Optional[dict[str, Any]]:
"""FP8 metadata
@@ -272,11 +300,67 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
self._fp8_metas = self._make_fp8_metas()
return self._fp8_metas[mode]
def pre_forward(self) -> None:
@torch.no_grad()
def _save_fp8_metas(self) -> Optional[dict[str, Any]]:
"""Create copies of tensors in FP8 metadata
Tensor copies can be loaded with _load_fp8_metas.
"""
if self._fp8_metas is None:
return None
out = {}
for mode, fp8_meta in self._fp8_metas.items():
if fp8_meta is None:
continue
out[mode] = {}
for is_forward in (True, False):
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward)
if fp8_meta_key not in fp8_meta:
continue
out[mode][fp8_meta_key] = (
fp8_meta[fp8_meta_key].scale.clone(),
fp8_meta[fp8_meta_key].scale_inv.clone(),
fp8_meta[fp8_meta_key].amax_history.clone(),
)
return out
@torch.no_grad()
def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None:
"""Update FP8 metadata with saved tensor copies
Tensor copies should be generated with _save_fp8_metas.
"""
assert (self._fp8_metas is None) == (
fp8_metas is None
), "Saved FP8 metadata does not match operation's FP8 metadata"
if fp8_metas is None:
return
for mode, fp8_meta in fp8_metas.items():
assert (
mode in self._fp8_metas
), f"Found an unexpected key ({mode=}) in saved FP8 metadata"
for fp8_meta_key, tensors in fp8_meta.items():
assert (
fp8_meta_key in self._fp8_metas[mode]
), f"Found an unexpected key ({mode=}, {fp8_meta_key=}) in saved FP8 metadata"
scale, scale_inv, amax_history = tensors
self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale)
self._fp8_metas[mode][fp8_meta_key].scale_inv.copy_(scale_inv)
self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history)
def pre_forward(
self,
*,
fp8_enabled: Optional[bool] = None,
fp8_recipe: Optional[DelayedScaling] = None,
) -> None:
"""Preprocessing before forward pass"""
# Initialize FP8 metadata if needed
fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
if fp8_enabled is None:
fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
if fp8_enabled:
# Construct FP8 metadata if needed
@@ -285,7 +369,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
# Make sure FP8 metadata matches FP8 autocast context
for fp8_meta in self._fp8_metas.values():
self._maybe_update_fp8_meta(fp8_meta)
self._maybe_update_fp8_meta(fp8_meta, fp8_recipe=fp8_recipe)
# Register FP8 metadata for amax and scale update
if not FP8GlobalStateManager.fp8_graph_capturing():
......
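
The BasicOperation changes above let the graph-capture path initialize FP8 metadata outside an fp8_autocast region (pre_forward with explicit fp8_enabled/fp8_recipe) and snapshot or restore scales, scale_invs, and amax histories. Below is a hedged sketch of that round trip at the operation level; _save_fp8_metas and _load_fp8_metas are private helpers, shown only to illustrate what save_fp8_tensors/restore_fp8_tensors do internally, and a CUDA device is assumed:

import torch
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.fp8 import get_default_fp8_recipe
from transformer_engine.pytorch.ops.op import BasicOperation

model = te_ops.Sequential(te_ops.Linear(128, 128, dtype=torch.float16))
fp8_recipe = get_default_fp8_recipe()

# Build FP8 metadata without entering fp8_autocast, then snapshot it.
snapshots = []
for op in model.modules():
    if isinstance(op, BasicOperation):
        op.pre_forward(fp8_enabled=True, fp8_recipe=fp8_recipe)
        snapshots.append(op._save_fp8_metas())

# ... graph warmup/capture would mutate scales and amax histories here ...

# Restore the snapshots in the same traversal order.
basic_ops = (op for op in model.modules() if isinstance(op, BasicOperation))
for op, saved in zip(basic_ops, snapshots):
    op._load_fp8_metas(saved)
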