"tests/vscode:/vscode.git/clone" did not exist on "2d875521b1a4ed5e88b40c1450812737f890e5aa"
Unverified commit 50b22da8, authored by Tim Moon, committed by GitHub

[PyTorch] Debug CUDA graph support with operation-based API (#1117)



* Debug CUDA graph support with operation-based API
Signed-off-by: Tim Moon <tmoon@nvidia.com>
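
The operation-based API referred to here builds models from fusible ops in `transformer_engine.pytorch.ops`. A minimal sketch of the pattern exercised by the new `linear_op` test case further down; sizes are illustrative and not taken from the commit:

```python
import torch
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch import make_graphed_callables

# Illustrative sizes only.
seq_len, batch_size, hidden_size = 32, 2, 64
dtype = torch.float16

# Operation-based model: fusible ops composed with te_ops.Sequential.
model = te_ops.Sequential(
    te_ops.Linear(hidden_size, hidden_size, dtype=dtype),
)

# Capture the module into a CUDA graph using a warmup sample input.
sample_input = torch.ones(
    seq_len, batch_size, hidden_size, device="cuda", dtype=dtype, requires_grad=True
)
model = make_graphed_callables(model, (sample_input,), num_warmup_iters=10)

# Graph replays behave like ordinary forward/backward calls.
x = torch.randn(seq_len, batch_size, hidden_size, device="cuda", dtype=dtype, requires_grad=True)
y = model(x)
y.backward(torch.randn_like(y))
```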

* Refactoring CUDA graph tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Review suggestions from @ptrendx

Return default recipe from FP8GlobalStateManager.get_fp8_recipe if needed. Expand error message when failing to load FP8 state after capturing CUDA graph.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
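
A condensed view of the recipe fallback (the `fp8.py` hunk below shows the actual change); this usage sketch only illustrates the new behavior:

```python
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

# Outside any fp8_autocast region no recipe has been set explicitly;
# get_fp8_recipe() now falls back to the default DelayedScaling recipe
# instead of returning None.
recipe = FP8GlobalStateManager.get_fp8_recipe()
print(recipe.amax_history_len)  # attributes of the default recipe are available
```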

* Avoid unnecessary recursion when saving/loading FP8 state
Signed-off-by: Tim Moon <tmoon@nvidia.com>
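
For context, `make_graphed_callables` snapshots FP8 state for every submodule before capture and restores it afterwards; this commit makes that walk handle operation-based modules without redundant recursion. A simplified sketch of the ordering, where the helper names match `graph.py` below but the wrapper function itself is hypothetical:

```python
def _capture_with_fp8_state(modules, fp8_recipe, capture_fn):
    # Snapshot FP8 metadata (one entry per TE submodule or basic op).
    saved = save_fp8_tensors(modules, fp8_recipe=fp8_recipe)
    # Capture CUDA graphs; capture_fn stands in for the internal logic.
    graphed = capture_fn(modules)
    # Restore the snapshot; raises RuntimeError if the entry count mismatches.
    restore_fp8_tensors(modules, saved)
    return graphed
```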

* Fix circular import
Signed-off-by: Tim Moon <tmoon@nvidia.com>
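
The circular import between the ops fuser and `transformer_engine.pytorch.graph` is broken by resolving the import lazily at first call; the pattern in isolation (it mirrors the fuser.py hunk below):

```python
from typing import Callable, Optional

# Resolved on first use so that importing this module does not import
# transformer_engine.pytorch.graph at module load time.
_is_graph_capturing_function: Optional[Callable[[], bool]] = None


def _is_graph_capturing() -> bool:
    """Whether we are inside a `make_graphed_callables` capture."""
    global _is_graph_capturing_function
    if _is_graph_capturing_function is None:
        from transformer_engine.pytorch.graph import is_graph_capturing

        _is_graph_capturing_function = is_graph_capturing
    return _is_graph_capturing_function()
```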

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent df949037
@@ -13,24 +13,25 @@ from transformer_engine.pytorch import (
     LayerNormLinear,
     LayerNormMLP,
     Linear,
-    make_graphed_callables,
     MultiheadAttention,
     TransformerLayer,
     fp8_autocast,
     fp8_model_init,
+    make_graphed_callables,
 )
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.utils import is_bf16_compatible
+import transformer_engine.pytorch.ops as te_ops
 
-# Only run FP8 tests on H100.
+# Check if FP8 is supported.
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
 
+# Record initial RNG state.
 seed = 1234
 torch.manual_seed(seed)
 torch.cuda.manual_seed(seed)
-# Record initial RNG state from script run.
 _cpu_rng_state = torch.get_rng_state()
 _cuda_rng_state = torch.cuda.get_rng_state()
@@ -48,17 +49,14 @@ class ModelConfig:
 model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}
 
-modules = ["transformer", "layernorm_mlp", "layernorm_linear", "linear", "mha", "dpa"]
-all_boolean = [True, False]
-
-dtypes = [torch.float32, torch.float16]
+# Supported data types
+dtypes: List[torch.dtype] = [torch.float32, torch.float16]
 if is_bf16_compatible():  # bf16 requires sm_80 or higher
     dtypes.append(torch.bfloat16)
 
 
 def reset_rng_states() -> None:
-    """revert back to initial RNG state."""
+    """Revert to initial RNG state."""
     torch.set_rng_state(_cpu_rng_state)
     torch.cuda.set_rng_state(_cuda_rng_state)
@@ -70,64 +68,40 @@ def reset_global_fp8_state():
 def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> bool:
-    """Ensures two lists are equal."""
+    """Check that two lists of tensors match exactly."""
     assert len(l1) == len(l2), "Unequal number of outputs."
-    failed = False
-    failed_tensors = ""
+    failure_message = "Output mismatches in:"
+    failed_tensors = []
     for i, (t1, t2) in enumerate(zip(l1, l2)):
         if not torch.equal(t1, t2):
-            failed = True
-            failed_tensors += (
-                f"    {names[i]}\n" if names is not None else f"    tensor at idx={i}\n"
-            )
-    assert not failed, "Output mismatches in:\n" + failed_tensors
+            failure_message += "\n    "
+            if names is None:
+                failure_message += f"tensor at idx={i}"
+            else:
+                failure_message += names[i]
+            failed_tensors.append((t1, t2))
+    if failed_tensors:
+        print(failure_message)
+        t1, t2 = failed_tensors[0]
+        torch.testing.assert_close(t1, t2, rtol=0, atol=0)
 
 
 def generate_data(
-    config: ModelConfig,
+    model_config: ModelConfig,
     dtype: torch.dtype,
-    dpa: bool = False,
     warmup: bool = False,
-    return_grad_output: bool = False,
-) -> Tuple[List[torch.Tensor], torch.Tensor]:
+    requires_grad: bool = True,
+) -> torch.Tensor:
     """Generate synthetic data."""
     gen_func = torch.ones if warmup else torch.randn
-    if dpa:
-        inputs = [
-            gen_func(
-                config.sequence_length,
-                config.batch_size,
-                config.num_heads,
-                config.kv_channels,
-                device="cuda",
-                requires_grad=True,
-                dtype=dtype,
-            )
-            for _ in range(3)
-        ]
-    else:
-        inputs = [
-            gen_func(
-                config.sequence_length,
-                config.batch_size,
-                config.hidden_size,
-                device="cuda",
-                requires_grad=True,
-                dtype=dtype,
-            )
-        ]
-    if not return_grad_output:
-        return inputs
-    grad_output = torch.randn(
-        config.sequence_length,
-        config.batch_size,
-        config.hidden_size,
+    return gen_func(
+        model_config.sequence_length,
+        model_config.batch_size,
+        model_config.hidden_size,
         device="cuda",
+        requires_grad=requires_grad,
         dtype=dtype,
     )
-    return inputs, grad_output
def get_outputs( def get_outputs(
@@ -157,30 +131,44 @@ class _Sequential(torch.nn.Sequential):
         return x
 
 
+# Supported modules
+_test_cuda_graphs_modules: List[str] = [
+    "transformer",
+    "layernorm_mlp",
+    "layernorm_linear",
+    "linear",
+    "mha",
+    "linear_op",
+]
+
+
 def _test_cuda_graphs(
     *,
-    config: ModelConfig,
+    graph_mode: str,
+    module: str,
+    model_config: ModelConfig,
     num_layers: int,
     dtype: torch.dtype,
     fp8: bool,
     fp8_params: bool,
     fp8_weight_caching: bool,
-    module: str,
-    graph_mode: str,
 ) -> List[torch.Tensor]:
     """Helper function for CUDA graph test."""
     reset_rng_states()
     FP8GlobalStateManager.reset()
-    dpa = module == "dpa"
 
+    # Operation-based API does not support FP8 weight caching.
+    if module == "linear_op":
+        fp8_weight_caching = False
+
+    # Create modules.
     with fp8_model_init(enabled=fp8_params):
-        # Create modules.
         if module == "transformer":
             modules = [
                 TransformerLayer(
-                    config.hidden_size,
-                    config.hidden_size,
-                    config.num_heads,
+                    model_config.hidden_size,
+                    model_config.hidden_size,
+                    model_config.num_heads,
                     hidden_dropout=0.0,
                     attention_dropout=0.0,
                     fuse_qkv_params=True,
@@ -190,37 +178,56 @@ def _test_cuda_graphs(
             ]
         elif module == "layernorm_mlp":
             modules = [
-                LayerNormMLP(config.hidden_size, config.hidden_size, params_dtype=dtype)
+                LayerNormMLP(
+                    model_config.hidden_size,
+                    model_config.hidden_size,
+                    params_dtype=dtype,
+                )
                 for _ in range(num_layers)
             ]
         elif module == "layernorm_linear":
             modules = [
-                LayerNormLinear(config.hidden_size, config.hidden_size, params_dtype=dtype)
+                LayerNormLinear(
+                    model_config.hidden_size,
+                    model_config.hidden_size,
+                    params_dtype=dtype,
+                )
                 for _ in range(num_layers)
             ]
         elif module == "mha":
             modules = [
                 MultiheadAttention(
-                    config.hidden_size,
-                    config.num_heads,
+                    model_config.hidden_size,
+                    model_config.num_heads,
                     attention_dropout=0.0,
                     params_dtype=dtype,
                     fuse_qkv_params=True,
                 )
                 for _ in range(num_layers)
            ]
-        elif dpa:
-            assert config.hidden_size % config.num_heads == 0, "Err."
-            assert num_layers == 1, "Err."
+        elif module == "linear":
             modules = [
-                DotProductAttention(config.num_heads, config.kv_channels, attention_dropout=0.0)
+                Linear(
+                    model_config.hidden_size,
+                    model_config.hidden_size,
+                    device="cuda",
+                    params_dtype=dtype,
+                )
                 for _ in range(num_layers)
             ]
-        else:
+        elif module == "linear_op":
             modules = [
-                Linear(config.hidden_size, config.hidden_size, device="cuda", params_dtype=dtype)
+                te_ops.Sequential(
+                    te_ops.Linear(
+                        model_config.hidden_size,
+                        model_config.hidden_size,
+                        dtype=dtype,
+                    ),
+                )
                 for _ in range(num_layers)
             ]
+        else:
+            raise ValueError(f"Unknown module type ({module})")
 
     # Initialize gradient buffers.
     for module in modules:
@@ -230,111 +237,208 @@ def _test_cuda_graphs(
     # Generate model and wrap API to return graphed version.
     if graph_mode == "full":
         # Graph entire model at once.
-        model = modules[0] if dpa else torch.nn.Sequential(*modules)
+        model = torch.nn.Sequential(*modules)
         model = make_graphed_callables(
             model,
-            generate_data(config, dtype, dpa=dpa, warmup=True),
+            (generate_data(model_config, dtype, warmup=True),),
             num_warmup_iters=10,
             fp8_enabled=fp8,
             fp8_weight_caching=fp8_weight_caching,
         )
     elif graph_mode == "individual":
-        # Graph individual modules
+        # Graph individual modules.
        modules = [
             make_graphed_callables(
                 module,
-                generate_data(config, dtype, dpa=dpa, warmup=True),
+                (generate_data(model_config, dtype, warmup=True),),
                 num_warmup_iters=10,
                 fp8_enabled=fp8,
                 fp8_weight_caching=fp8_weight_caching,
             )
             for module in modules
         ]
-        model = modules[0] if dpa else _Sequential(*modules)
+        model = _Sequential(*modules)
     else:
-        model = modules[0] if dpa else _Sequential(*modules)
+        model = _Sequential(*modules)
 
     # Optimizer.
-    if not dpa:
-        optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
 
-    # Launch.
+    # Training steps.
     for _ in range(3):
-        if not dpa:
-            optimizer.zero_grad(set_to_none=False)
+        optimizer.zero_grad(set_to_none=False)
         for grad_accumulation_step in range(2):
-            inputs, grad_output = generate_data(config, dtype, dpa=dpa, return_grad_output=True)
+            input_ = generate_data(model_config, dtype)
+            grad_output = generate_data(model_config, dtype, requires_grad=False)
             with fp8_autocast(enabled=fp8):
                 kwargs = {}
                 if fp8_weight_caching:
                     kwargs["is_first_microbatch"] = grad_accumulation_step == 0
-                output = model(*inputs, **kwargs)
+                output = model(input_, **kwargs)
             output.backward(grad_output)
-        if not dpa:
-            optimizer.step()
+        optimizer.step()
     return get_outputs(model, output)
 
 
+@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
 @pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("num_layers", [1, 3])
-@pytest.mark.parametrize("fp8", all_boolean)
-@pytest.mark.parametrize("fp8_params", all_boolean)
-@pytest.mark.parametrize("fp8_weight_caching", all_boolean)
-@pytest.mark.parametrize("module", modules)
-def test_gpt_make_graphed_callables(
-    dtype: torch.dtype,
-    model: str,
-    num_layers: int,
+@pytest.mark.parametrize("fp8", (False, True))
+@pytest.mark.parametrize("fp8_params", (False, True))
+def test_make_graphed_callables(
+    *,
+    module: str,
+    model_config: str = "small",
+    num_layers: int = 3,
+    dtype: torch.dtype,
     fp8: bool,
     fp8_params: bool,
-    fp8_weight_caching: bool,
-    module: str,
+    fp8_weight_caching: bool = False,
 ) -> None:
+    # Skip invalid configurations.
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
     if fp8_params and not fp8:
         pytest.skip("FP8 needed for FP8 parameters.")
     if fp8_weight_caching and not fp8:
         pytest.skip("FP8 needed for FP8 parameters.")
-    if module == "dpa" and num_layers > 1:
-        pytest.skip("Max 1 layer for DPA.")
 
-    config = model_configs[model]
-
+    # Run model with different CUDA graph settings.
+    model_config = model_configs[model_config]
     kwargs = dict(
-        config=config,
+        module=module,
+        model_config=model_config,
         num_layers=num_layers,
         dtype=dtype,
         fp8=fp8,
         fp8_params=fp8_params,
         fp8_weight_caching=fp8_weight_caching,
-        module=module,
     )
     outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
     graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
     graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)
 
-    # Check that results match
+    # Check that results match.
     assert_all_equal(outputs, graph_outputs_mode1)
     assert_all_equal(outputs, graph_outputs_mode2)
-def _test_cuda_graphs_with_kwargs(
+_test_make_graphed_callables_with_fp8_weight_caching_modules = [
+    "transformer",
+    "layernorm_mlp",
+    "layernorm_linear",
+    "linear",
+    "mha",
+]
+
+
+@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
+@pytest.mark.parametrize(
+    "module",
+    _test_make_graphed_callables_with_fp8_weight_caching_modules,
+)
+@pytest.mark.parametrize("fp8_params", (False, True))
+def test_make_graphed_callables_with_fp8_weight_caching(
     *,
-    config: ModelConfig,
+    module: str,
+    fp8_params: bool,
+) -> None:
+    test_make_graphed_callables(
+        module=module,
+        dtype=torch.float32,
+        fp8=True,
+        fp8_params=fp8_params,
+        fp8_weight_caching=True,
+    )
+
+
+def generate_data_for_dot_product_attention(
+    model_config: ModelConfig,
     dtype: torch.dtype,
+    warmup: bool = False,
+) -> List[torch.Tensor]:
+    """Generate synthetic data for dot product attention."""
+    gen_func = torch.ones if warmup else torch.randn
+    return [
+        gen_func(
+            model_config.sequence_length,
+            model_config.batch_size,
+            model_config.num_heads,
+            model_config.kv_channels,
+            device="cuda",
+            requires_grad=True,
+            dtype=dtype,
+        )
+        for _ in range(3)
+    ]
+
+
+def _test_cuda_graphs_with_dot_product_attention(
+    *,
     with_graph: bool,
+    model_config: ModelConfig,
+    dtype: torch.dtype,
 ) -> List[torch.Tensor]:
-    """Simulate Megatron-LM interleaved pipeline parallelism."""
+    """Helper function for CUDA graph test."""
+    reset_rng_states()
+    FP8GlobalStateManager.reset()
+
+    # Create dot product attention module.
+    assert model_config.hidden_size % model_config.num_heads == 0
+    model = DotProductAttention(
+        model_config.num_heads,
+        model_config.kv_channels,
+        attention_dropout=0.0,
+    )
+
+    # Graph model if needed.
+    if with_graph:
+        model = make_graphed_callables(
+            model,
+            generate_data_for_dot_product_attention(model_config, dtype, warmup=True),
+            num_warmup_iters=10,
+            fp8_enabled=False,
+        )
+
+    # Forward and backward passes.
+    for _ in range(3):
+        inputs = generate_data_for_dot_product_attention(model_config, dtype)
+        grad_output = generate_data(model_config, dtype, requires_grad=False)
+        output = model(*inputs)
+        output.backward(grad_output)
+    return get_outputs(model, output)
+
+
+@pytest.mark.parametrize("dtype", dtypes)
+def test_make_graphed_callables_with_dot_product_attention(
+    *,
+    model_config: str = "small",
+    dtype: torch.dtype,
+) -> None:
+    """Test CUDA graphs with dot product attention."""
+    model_config = model_configs[model_config]
+    kwargs = dict(model_config=model_config, dtype=dtype)
+    outputs = _test_cuda_graphs_with_dot_product_attention(with_graph=False, **kwargs)
+    graph_outputs = _test_cuda_graphs_with_dot_product_attention(with_graph=True, **kwargs)
+    assert_all_equal(outputs, graph_outputs)
+
+
+def _test_cuda_graphs_with_kwargs(
+    *,
+    with_graph: bool,
+    model_config: ModelConfig,
+    dtype: torch.dtype,
+) -> List[torch.Tensor]:
+    """Helper function for CUDA graph test with keyword arguments."""
     reset_rng_states()
 
     # Initialize model.
     model = TransformerLayer(
-        config.hidden_size,
-        config.hidden_size,
-        config.num_heads,
+        model_config.hidden_size,
+        model_config.hidden_size,
+        model_config.num_heads,
         hidden_dropout=0.0,
         attention_dropout=0.0,
         self_attn_mask_type="arbitrary",
@@ -349,13 +453,18 @@ def _test_cuda_graphs_with_kwargs(
     # Make graphed version of model if needed.
     if with_graph:
         attn_mask = torch.zeros(
-            (config.batch_size, 1, config.sequence_length, config.sequence_length),
+            (
+                model_config.batch_size,
+                1,
+                model_config.sequence_length,
+                model_config.sequence_length,
+            ),
             dtype=torch.bool,
             device="cuda",
         )
         model = make_graphed_callables(
             model,
-            generate_data(config, dtype, warmup=True),
+            (generate_data(model_config, dtype, warmup=True),),
             sample_kwargs=dict(attention_mask=attn_mask),
             allow_unused_input=True,
         )
@@ -367,14 +476,20 @@ def _test_cuda_graphs_with_kwargs(
     for _ in range(3):
         optimizer.zero_grad(set_to_none=False)
         for grad_accumulation_step in range(2):
-            inputs, grad_output = generate_data(config, dtype, return_grad_output=True)
+            input_ = generate_data(model_config, dtype)
+            grad_output = generate_data(model_config, dtype, requires_grad=False)
             attn_mask = torch.randint(
                 2,
-                (config.batch_size, 1, config.sequence_length, config.sequence_length),
+                (
+                    model_config.batch_size,
+                    1,
+                    model_config.sequence_length,
+                    model_config.sequence_length,
+                ),
                 dtype=torch.bool,
                 device="cuda",
             )
-            output = model(*inputs, attention_mask=attn_mask)
+            output = model(input_, attention_mask=attn_mask)
             output.backward(grad_output)
         optimizer.step()
@@ -382,12 +497,13 @@ def _test_cuda_graphs_with_kwargs(
 def test_make_graphed_callables_with_kwargs(
+    *,
+    model_config: str = "small",
     dtype: torch.dtype = torch.float32,
-    model: str = "small",
 ) -> None:
     """Test CUDA graphs with keyword arguments."""
-    config = model_configs[model]
-    kwargs = dict(config=config, dtype=dtype)
+    model_config = model_configs[model_config]
+    kwargs = dict(model_config=model_config, dtype=dtype)
     outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs)
     graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs)
     assert_all_equal(outputs, graph_outputs)
@@ -395,9 +511,9 @@ def test_make_graphed_callables_with_kwargs(
 def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
     *,
-    config: ModelConfig,
-    dtype: torch.dtype,
     with_graph: bool,
+    model_config: ModelConfig,
+    dtype: torch.dtype,
 ) -> List[torch.Tensor]:
     """Simulate Megatron-LM interleaved pipeline parallelism."""
     reset_rng_states()
@@ -411,8 +527,8 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
     model = torch.nn.ModuleList(
         [
             Linear(
-                config.hidden_size,
-                config.hidden_size,
+                model_config.hidden_size,
+                model_config.hidden_size,
                 params_dtype=dtype,
             )
             for _ in range(num_layers)
@@ -430,7 +546,8 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
     }
     if with_graph:
         sample_args = tuple(
-            generate_data(config, dtype, warmup=True) for _ in range(num_layers * num_microbatches)
+            (generate_data(model_config, dtype, warmup=True),)
+            for _ in range(num_layers * num_microbatches)
         )
         layer_forwards = make_graphed_callables(
             tuple(model),
@@ -455,9 +572,10 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
     grad_outputs = {}
     for layer_idx in range(num_layers):
         for microbatch_idx in range(num_microbatches):
-            x, dy = generate_data(config, dtype, return_grad_output=True)
+            x = generate_data(model_config, dtype)
+            dy = generate_data(model_config, dtype, requires_grad=False)
             idxs = (layer_idx, microbatch_idx)
-            inputs[idxs] = x[0]
+            inputs[idxs] = x
             grad_outputs[idxs] = dy
 
     # Cache for layer outputs.
@@ -494,12 +612,13 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
 def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
+    *,
+    model_config: str = "small",
     dtype: torch.dtype = torch.float16,
-    model: str = "small",
 ) -> None:
     """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism."""
-    config = model_configs[model]
-    kwargs = dict(config=config, dtype=dtype)
+    model_config = model_configs[model_config]
+    kwargs = dict(model_config=model_config, dtype=dtype)
     outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
         with_graph=False,
         **kwargs,
...
@@ -277,7 +277,9 @@ class FP8GlobalStateManager:
     @classmethod
     def get_fp8_recipe(cls) -> DelayedScaling:
         """Return the fp8 recipe"""
-        return cls.FP8_RECIPE
+        if cls.FP8_RECIPE is not None:
+            return cls.FP8_RECIPE
+        return get_default_fp8_recipe()
 
     @classmethod
     def get_fp8_group(cls) -> Union[dist_group_type, None]:
...
@@ -3,6 +3,7 @@
 # See LICENSE for license information.
 
 """Functions for CUDA Graphs support in FP8"""
+from collections.abc import Iterable
 from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
 
 import torch
@@ -18,7 +19,7 @@ from .fp8 import (
 )
 from .distributed import get_all_rng_states, graph_safe_rng_available
 from .module.base import TransformerEngineBaseModule
+from .ops.op import BasicOperation
 
 __all__ = ["make_graphed_callables"]
@@ -486,28 +487,46 @@ def _make_graphed_callables(
     return tuple(ret)
 
 
-def save_fp8_tensors(modules, amax_history_len):
+def save_fp8_tensors(
+    modules: Iterable[torch.nn.Module],
+    fp8_recipe: DelayedScaling,
+) -> List[Any]:
     """
     Returns the FP8 tensors for all modules
     with adjusted amax history sizes.
     """
-    saved_fp8_meta_tensors = []
+    fp8_tensors = []
     for module in modules:
         for m in module.modules():
+            module_tensors = None
             if isinstance(m, TransformerEngineBaseModule):
                 if m.primary_weights_in_fp8:
-                    m.adjust_amax_history_length(amax_history_len)
-                saved_fp8_meta_tensors.append(m.get_fp8_meta_tensors())
-    return saved_fp8_meta_tensors
+                    m.adjust_amax_history_length(fp8_recipe.amax_history_len)
+                module_tensors = m.get_fp8_meta_tensors()
+            elif isinstance(m, BasicOperation):
+                m.pre_forward(fp8_enabled=True, fp8_recipe=fp8_recipe)
+                module_tensors = m._save_fp8_metas()
+            fp8_tensors.append(module_tensors)
+    return fp8_tensors
 
 
-def restore_fp8_tensors(modules, fp8_tensors):
+def restore_fp8_tensors(
+    modules: Iterable[torch.nn.Module],
+    fp8_tensors: List[Any],
+) -> None:
     """Restore FP8 tensors."""
     for module in modules:
         for m in module.modules():
+            module_tensors = fp8_tensors.pop(0)
             if isinstance(m, TransformerEngineBaseModule):
-                m.reset_fp8_meta_tensors(fp8_tensors.pop(0))
-    assert len(fp8_tensors) == 0, "TE internal error."
+                m.reset_fp8_meta_tensors(module_tensors)
+            elif isinstance(m, BasicOperation):
+                m._load_fp8_metas(module_tensors)
+    if len(fp8_tensors) != 0:
+        raise RuntimeError(
+            f"Got FP8 state for {len(fp8_tensors)} more modules than expected. "
+            "There is probably a discrepancy with `save_fp8_tensors`."
+        )
 
 
 def make_graphed_callables(
@@ -580,7 +599,7 @@ def make_graphed_callables(
         modules = (modules,)
 
     # Store FP8 tensors to reset later.
-    saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe.amax_history_len)
+    saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe)
 
     # FP8 wrapper.
     def wrap_autocast(block):
...
@@ -308,8 +308,8 @@ class BasicLinear(BasicOperation):
             weight = torch.nn.Parameter(weight)
         self.weight = weight
 
-    def pre_forward(self) -> None:
-        super().pre_forward()
+    def pre_forward(self, *args, **kwargs) -> None:
+        super().pre_forward(*args, **kwargs)
         if self.weight.device.type == "meta":
             self.reset_parameters()
...
@@ -111,8 +111,8 @@ class Bias(BasicOperation):
             bias = torch.nn.Parameter(bias)
         self.bias = bias
 
-    def pre_forward(self) -> None:
-        super().pre_forward()
+    def pre_forward(self, *args, **kwargs) -> None:
+        super().pre_forward(*args, **kwargs)
         if self.bias.device.type == "meta":
             self.reset_parameters()
...
@@ -5,12 +5,12 @@
 """Manager class for a pipeline of fusible operations."""
 
 from __future__ import annotations
+from collections.abc import Callable
 from typing import Any, Optional
 
 import torch
 
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch.graph import is_graph_capturing
 from transformer_engine.pytorch.ops.op import (
     BasicOperation,
     FusibleOperation,
@@ -28,6 +28,24 @@ def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]:
     return t[:idx], t[idx:]
 
 
+# Lazily imported function used in _is_graph_capturing
+_is_graph_capturing_function: Optional[Callable[[], bool]] = None
+
+
+def _is_graph_capturing() -> bool:
+    """Whether function is called within `make_graphed_callables`
+
+    Avoid circular import with lazy import.
+
+    """
+    global _is_graph_capturing_function
+    if _is_graph_capturing_function is None:
+        from ..graph import is_graph_capturing
+
+        _is_graph_capturing_function = is_graph_capturing
+    return _is_graph_capturing_function()
+
+
 class _OperationFuserAutogradFunction(torch.autograd.Function):
     """Autograd function for a pipeline of operations
 
@@ -255,7 +273,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
             grad_extra_inputs_flat.extend(dxs)
 
         # Update FP8 scaling factors
-        if func_ctx.is_first_module and not is_graph_capturing():
+        if func_ctx.is_first_module and not _is_graph_capturing():
             FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
 
         return (
...
@@ -14,6 +14,7 @@ import torch
 import transformer_engine_torch as tex
 
 from transformer_engine.pytorch.fp8 import (
+    DelayedScaling,
     FP8GlobalStateManager,
     get_default_fp8_recipe,
 )
@@ -231,25 +232,37 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         }
 
     @classmethod
-    def _maybe_update_fp8_meta(cls, fp8_meta: Optional[dict[str, Any]]) -> None:
+    def _maybe_update_fp8_meta(
+        cls,
+        fp8_meta: Optional[dict[str, Any]],
+        *,
+        fp8_recipe: Optional[DelayedScaling] = None,
+    ) -> None:
         if fp8_meta is None:
             return
 
-        # Update FP8 recipe and communication group
-        recipe = FP8GlobalStateManager.get_fp8_recipe()
-        fp8_meta["recipe"] = recipe
+        # Update FP8 recipe
+        if fp8_recipe is None:
+            fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
+        fp8_meta["recipe"] = fp8_recipe
+
+        # Update FP8 communication group
         fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
 
         # Adjust amax history length if needed
-        amax_history_len = recipe.amax_history_len
+        amax_history_len = fp8_recipe.amax_history_len
         for is_forward in (True, False):
-            key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward)
-            if key not in fp8_meta:
+            fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward)
+            if fp8_meta_key not in fp8_meta:
                 continue
-            meta = fp8_meta[key]
+            meta = fp8_meta[fp8_meta_key]
             curr_len = meta.amax_history.size(0)
+
+            # Nothing to be done if amax history is already correct
             if curr_len == amax_history_len:
                 continue
+
+            # Reallocate amax history
             with torch.no_grad():
                 if curr_len > amax_history_len:
                     meta.amax_history = meta.amax_history[:amax_history_len].clone()
@@ -259,6 +272,21 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
                         pad=(0, 0, 0, amax_history_len - curr_len),
                     )
 
+            # Update global buffers for amax reductions
+            buffer_info_key = FP8GlobalStateManager.get_buffer_info()
+            if buffer_info_key in fp8_meta:
+                fwd_pos, fwd_key, bwd_pos, bwd_key = fp8_meta[buffer_info_key]
+                for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)):
+                    assert (
+                        buffer_key in FP8GlobalStateManager.global_amax_history_buffer
+                    ), "TE internal error during amax history change."
+                    FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = fp8_meta[
+                        fp8_meta_key
+                    ].amax_history[0]
+                    FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = fp8_meta[
+                        fp8_meta_key
+                    ].amax_history
+
     def get_fp8_meta(self, mode: str) -> Optional[dict[str, Any]]:
         """FP8 metadata
@@ -272,11 +300,67 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             self._fp8_metas = self._make_fp8_metas()
         return self._fp8_metas[mode]
 
-    def pre_forward(self) -> None:
+    @torch.no_grad()
+    def _save_fp8_metas(self) -> Optional[dict[str, Any]]:
+        """Create copies of tensors in FP8 metadata
+
+        Tensor copies can be loaded with _load_fp8_metas.
+
+        """
+        if self._fp8_metas is None:
+            return None
+        out = {}
+        for mode, fp8_meta in self._fp8_metas.items():
+            if fp8_meta is None:
+                continue
+            out[mode] = {}
+            for is_forward in (True, False):
+                fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward)
+                if fp8_meta_key not in fp8_meta:
+                    continue
+                out[mode][fp8_meta_key] = (
+                    fp8_meta[fp8_meta_key].scale.clone(),
+                    fp8_meta[fp8_meta_key].scale_inv.clone(),
+                    fp8_meta[fp8_meta_key].amax_history.clone(),
+                )
+        return out
+
+    @torch.no_grad()
+    def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None:
+        """Update FP8 metadata with saved tensor copies
+
+        Tensor copies should be generated with _save_fp8_metas.
+
+        """
+        assert (self._fp8_metas is None) == (
+            fp8_metas is None
+        ), "Saved FP8 metadata does not match operation's FP8 metadata"
+        if fp8_metas is None:
+            return
+        for mode, fp8_meta in fp8_metas.items():
+            assert (
+                mode in self._fp8_metas
+            ), f"Found an unexpected key ({mode=}) in saved FP8 metadata"
+            for fp8_meta_key, tensors in fp8_meta.items():
+                assert (
+                    fp8_meta_key in self._fp8_metas[mode]
+                ), f"Found an unexpected key ({mode=}, {fp8_meta_key=}) in saved FP8 metadata"
+                scale, scale_inv, amax_history = tensors
+                self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale)
+                self._fp8_metas[mode][fp8_meta_key].scale_inv.copy_(scale_inv)
+                self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history)
+
+    def pre_forward(
+        self,
+        *,
+        fp8_enabled: Optional[bool] = None,
+        fp8_recipe: Optional[DelayedScaling] = None,
+    ) -> None:
         """Preprocessing before forward pass"""
 
         # Initialize FP8 metadata if needed
-        fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
+        if fp8_enabled is None:
+            fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
         if fp8_enabled:
 
             # Construct FP8 metadata if needed
@@ -285,7 +369,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             # Make sure FP8 metadata matches FP8 autocast context
             for fp8_meta in self._fp8_metas.values():
-                self._maybe_update_fp8_meta(fp8_meta)
+                self._maybe_update_fp8_meta(fp8_meta, fp8_recipe=fp8_recipe)
 
             # Register FP8 metadata for amax and scale update
             if not FP8GlobalStateManager.fp8_graph_capturing():
...