Unverified commit 868c7d30 authored by Tim Moon, committed by GitHub

[PyTorch] Add CUDA graph tests with FP8 weight caching (#869)



* Modify CUDA graph tests to use grad accumulation steps

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Initialize grad buffers before capturing CUDA graph in CUDA graph tests

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Only use BS=2 in CUDA graph tests

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update tests/pytorch/test_cuda_graphs.py

Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 8b210490
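As context for the diff below, here is a minimal sketch of the grad-accumulation pattern this commit exercises, assuming an FP8-capable GPU; the module choice, shapes, and hyperparameters are illustrative placeholders, not taken from this diff. With `fp8_weight_caching=True`, the graphed callable caches the FP8-cast weights, and the caller passes `is_first_microbatch` so the cast is refreshed only on the first microbatch of each grad accumulation cycle:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import fp8_autocast, make_graphed_callables

# Hypothetical dimensions for illustration.
seq_len, batch_size, hidden = 32, 2, 64

model = te.Linear(hidden, hidden, device="cuda", params_dtype=torch.float16)

# Pre-allocate grad buffers so the captured graph sees stable addresses.
for param in model.parameters():
    param.grad = torch.empty_like(param)

# Capture the graph, caching FP8-cast weights across microbatches.
sample_input = torch.ones(
    seq_len, batch_size, hidden, device="cuda", dtype=torch.float16, requires_grad=True
)
model = make_graphed_callables(
    model,
    (sample_input,),
    num_warmup_iters=10,
    fp8_enabled=True,
    fp8_weight_caching=True,
)

optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
for step in range(3):
    # set_to_none=False keeps the pre-allocated grad buffers alive.
    optimizer.zero_grad(set_to_none=False)
    for microbatch in range(2):
        x = torch.randn(
            seq_len, batch_size, hidden, device="cuda", dtype=torch.float16, requires_grad=True
        )
        with fp8_autocast(enabled=True):
            # Re-cast weights to FP8 only on the first microbatch of each step.
            output = model(x, is_first_microbatch=(microbatch == 0))
        output.backward(torch.randn_like(output))
    optimizer.step()
```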
tests/pytorch/test_cuda_graphs.py
@@ -2,6 +2,7 @@
 #
 # See LICENSE for license information.
 
+from dataclasses import dataclass
 from typing import List, Tuple
 import pytest
@@ -25,22 +26,19 @@ torch.cuda.manual_seed(seed)
 _cpu_rng_state = torch.get_rng_state()
 _cuda_rng_state = torch.cuda.get_rng_state()
 
 
+@dataclass
 class ModelConfig:
-    def __init__(self, hidden_size, nheads, kv, seq_len):
-        self.h = hidden_size
-        self.nheads = nheads
-        self.kv = kv
-        self.s = seq_len
+    """Data tensor dimensions within Transformer model"""
+    sequence_length: int
+    batch_size: int
+    hidden_size: int
+    num_heads: int
+    kv_channels: int
 
-model_configs = {
-    "small": ModelConfig(64, 2, 32, 32),
-}
+model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}
 
 modules = ["transformer", "layernorm_mlp", "layernorm_linear", "linear", "mha", "dpa"]
 
-optimizers = [torch.optim.SGD, torch.optim.Adam]
 all_boolean = [True, False]
 
 dtypes = [torch.float32, torch.float16]
@@ -66,9 +64,6 @@ def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None)
     failed = False
     failed_tensors = ""
     for i, (t1, t2) in enumerate(zip(l1, l2)):
-        with torch.no_grad():
-            t1.masked_fill_(t1.isnan(), 1.0)
-            t2.masked_fill_(t2.isnan(), 1.0)
         if not torch.equal(t1, t2):
             failed = True
             failed_tensors += f"    {names[i]}\n" if names is not None else f"    tensor at idx={i}\n"
@@ -76,21 +71,50 @@ def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None)
 def generate_data(
-    s: int, b: int, h: int, nheads: int, kv: int, dtype: torch.dtype,
-    dpa: bool = False, warmup: bool = False, gen_labels: bool = False,
+    config: ModelConfig,
+    dtype: torch.dtype,
+    dpa: bool = False,
+    warmup: bool = False,
+    return_grad_output: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Generate synthetic data."""
     gen_func = torch.ones if warmup else torch.randn
     if dpa:
-        inputs = [gen_func(s, b, nheads, kv, device="cuda", requires_grad=True, dtype=dtype) for _ in range(3)]
+        inputs = [
+            gen_func(
+                config.sequence_length,
+                config.batch_size,
+                config.num_heads,
+                config.kv_channels,
+                device="cuda",
+                requires_grad=True,
+                dtype=dtype,
+            )
+            for _ in range(3)
+        ]
     else:
-        inputs = [gen_func(s, b, h, device="cuda", requires_grad=True, dtype=dtype)]
-
-    if not gen_labels:
+        inputs = [
+            gen_func(
+                config.sequence_length,
+                config.batch_size,
+                config.hidden_size,
+                device="cuda",
+                requires_grad=True,
+                dtype=dtype,
+            )
+        ]
+    if not return_grad_output:
         return inputs
-    target = torch.randn(s, b, h, device="cuda", dtype=dtype)
-    return inputs, target
+    grad_output = torch.randn(
+        config.sequence_length,
+        config.batch_size,
+        config.hidden_size,
+        device="cuda",
+        dtype=dtype,
+    )
+    return inputs, grad_output
 
 
 def get_outputs(model, output):
@@ -104,7 +128,27 @@ def get_outputs(model, output):
     return values
 
 
-def _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, graph, module, optimizer, graph_mode=""):
+class _Sequential(torch.nn.Sequential):
+    """Sequential model that forwards keyword arguments to modules"""
+
+    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
+        x = input_
+        for module in self:
+            x = module(x, **kwargs)
+        return x
+
+
+def _test_cuda_graphs(
+    *,
+    config: ModelConfig,
+    num_layers: int,
+    dtype: torch.dtype,
+    fp8: bool,
+    fp8_params: bool,
+    fp8_weight_caching: bool,
+    module: str,
+    graph_mode: str,
+) -> List[torch.Tensor]:
     """Helper function for test."""
     reset_rng_states()
     FP8GlobalStateManager.reset()
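An aside on the `_Sequential` class added above: `torch.nn.Sequential.forward` accepts only a single input, so keyword arguments such as `is_first_microbatch` would never reach the layers; the subclass forwards them through every submodule. A toy illustration (the `Scale` module is hypothetical, not from this diff):

```python
import torch

# Hypothetical module that takes a keyword argument, mirroring how
# Transformer Engine modules accept is_first_microbatch.
class Scale(torch.nn.Module):
    def forward(self, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        return x * scale

class KwargSequential(torch.nn.Sequential):
    """Sequential that forwards keyword arguments to every submodule."""
    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        for module in self:
            x = module(x, **kwargs)
        return x

model = KwargSequential(Scale(), Scale())
y = model(torch.ones(2), scale=3.0)  # kwargs reach each layer: result is 9s
# torch.nn.Sequential(Scale(), Scale())(torch.ones(2), scale=3.0)  # TypeError
```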
@@ -114,9 +158,9 @@ def _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, graph, mod
 
     # Create modules.
     if module == "transformer":
         modules = [TransformerLayer(
-            config.h,
-            config.h,
-            config.nheads,
+            config.hidden_size,
+            config.hidden_size,
+            config.num_heads,
             hidden_dropout=0.0,
             attention_dropout=0.0,
             fuse_qkv_params=True,
@@ -124,91 +168,124 @@ def _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, graph, mod
         ) for _ in range(num_layers)]
     elif module == "layernorm_mlp":
         modules = [LayerNormMLP(
-            config.h, config.h, params_dtype=dtype
+            config.hidden_size, config.hidden_size, params_dtype=dtype
         ) for _ in range(num_layers)]
     elif module == "layernorm_linear":
         modules = [LayerNormLinear(
-            config.h, config.h, params_dtype=dtype
+            config.hidden_size, config.hidden_size, params_dtype=dtype
         ) for _ in range(num_layers)]
     elif module == "mha":
         modules = [MultiheadAttention(
-            config.h,
-            config.nheads,
+            config.hidden_size,
+            config.num_heads,
             attention_dropout=0.0,
             params_dtype=dtype,
             fuse_qkv_params=True,
         ) for _ in range(num_layers)]
     elif dpa:
-        assert config.h % config.nheads == 0, "Err."
+        assert config.hidden_size % config.num_heads == 0, "Err."
         assert num_layers == 1, "Err."
         modules = [DotProductAttention(
-            config.nheads, config.kv, attention_dropout=0.0
+            config.num_heads, config.kv_channels, attention_dropout=0.0
         ) for _ in range(num_layers)]
     else:
         modules = [Linear(
-            config.h, config.h, device="cuda", params_dtype=dtype
+            config.hidden_size, config.hidden_size, device="cuda", params_dtype=dtype
         ) for _ in range(num_layers)]
 
+    # Initialize gradient buffers.
+    for module in modules:
+        for param in module.parameters():
+            param.grad = torch.empty_like(param)
+
     # Generate model and wrap API to return graphed version.
-    if graph:
-        # Graph entire module at once.
-        if graph_mode == "full":
-            model = modules[0] if dpa else torch.nn.Sequential(*modules)
-            model = make_graphed_callables(
-                model,
-                generate_data(config.s, bs, config.h, config.nheads, config.kv, dtype, dpa=dpa, warmup=True),
-                num_warmup_iters=10,
-                fp8_enabled=fp8)
-        else:
-            modules = [make_graphed_callables(
-                module,
-                generate_data(config.s, bs, config.h, config.nheads, config.kv, dtype, dpa=dpa, warmup=True),
-                num_warmup_iters=10,
-                fp8_enabled=fp8) for module in modules]
-            model = modules[0] if dpa else torch.nn.Sequential(*modules)
+    if graph_mode == "full":
+        # Graph entire model at once.
+        model = modules[0] if dpa else torch.nn.Sequential(*modules)
+        model = make_graphed_callables(
+            model,
+            generate_data(config, dtype, dpa=dpa, warmup=True),
+            num_warmup_iters=10,
+            fp8_enabled=fp8,
+            fp8_weight_caching=fp8_weight_caching,
+        )
+    elif graph_mode == "individual":
+        # Graph individual modules
+        modules = [
+            make_graphed_callables(
+                module,
+                generate_data(config, dtype, dpa=dpa, warmup=True),
+                num_warmup_iters=10,
+                fp8_enabled=fp8,
+                fp8_weight_caching=fp8_weight_caching,
+            )
+            for module in modules
+        ]
+        model = modules[0] if dpa else _Sequential(*modules)
     else:
-        model = modules[0] if dpa else torch.nn.Sequential(*modules)
+        model = modules[0] if dpa else _Sequential(*modules)
 
     # Loss function and optimizer.
-    loss_fn = torch.nn.MSELoss()
     if not dpa:
-        optimizer = optimizer(model.parameters(), lr=0.001)
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
 
     # Launch.
-    for _ in range(10):
-        inputs, target = generate_data(config.s, bs, config.h, config.nheads, config.kv, dtype, dpa=dpa, gen_labels=True)
-        with fp8_autocast(enabled=fp8):
-            output = model(*inputs)
-            loss = loss_fn(output, target)
-        loss.backward()
+    for _ in range(3):
         if not dpa:
-            optimizer.step()
-            optimizer.zero_grad()
+            optimizer.zero_grad(set_to_none=False)
+        for grad_accumulation_step in range(2):
+            inputs, grad_output = generate_data(config, dtype, dpa=dpa, return_grad_output=True)
+            with fp8_autocast(enabled=fp8):
+                kwargs = {}
+                if fp8_weight_caching:
+                    kwargs["is_first_microbatch"] = (grad_accumulation_step == 0)
+                output = model(*inputs, **kwargs)
+            output.backward(grad_output)
+        if not dpa:
+            optimizer.step()
 
     return get_outputs(model, output)
 
 
 @pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.parametrize("bs", [1, 2])
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("num_layers", [1, 10])
+@pytest.mark.parametrize("num_layers", [1, 3])
 @pytest.mark.parametrize("fp8", all_boolean)
 @pytest.mark.parametrize("fp8_params", all_boolean)
+@pytest.mark.parametrize("fp8_weight_caching", all_boolean)
 @pytest.mark.parametrize("module", modules)
-@pytest.mark.parametrize("optimizer", optimizers)
-def test_gpt_make_graphed_callables(dtype, bs, model, num_layers, fp8, fp8_params, module, optimizer):
+def test_gpt_make_graphed_callables(
+    dtype: torch.dtype,
+    model: str,
+    num_layers: int,
+    fp8: bool,
+    fp8_params: bool,
+    fp8_weight_caching: bool,
+    module: str,
+) -> None:
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
     if fp8_params and not fp8:
         pytest.skip("FP8 needed for FP8 parameters.")
+    if fp8_weight_caching and not fp8:
+        pytest.skip("FP8 needed for FP8 parameters.")
     if module == "dpa" and num_layers > 1:
         pytest.skip("Max 1 layer for DPA.")
 
     config = model_configs[model]
-    outputs = _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, False, module, optimizer)
-    graph_outputs_mode1 = _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, True, module, optimizer, graph_mode="full")
-    graph_outputs_mode2 = _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, True, module, optimizer, graph_mode="individual")
+    kwargs = dict(
+        config=config,
+        num_layers=num_layers,
+        dtype=dtype,
+        fp8=fp8,
+        fp8_params=fp8_params,
+        fp8_weight_caching=fp8_weight_caching,
+        module=module,
+    )
+    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
+    graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
+    graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)
 
     # Check that results match
     assert_all_equal(outputs, graph_outputs_mode1)
transformer_engine/pytorch/graph.py
@@ -536,11 +536,6 @@ def make_graphed_callables(
     else:
         torch.cuda.set_rng_state(original_rng_states)
 
-    # Reset FP8 gradients.
-    for module in modules:
-        for p in module.parameters():
-            p.grad = None
-
     # Restore FP8 state.
     restore_fp8_tensors(modules, saved_fp8_tensors)
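The hunk above removes the post-capture gradient reset from `make_graphed_callables`, which pairs with the tests now pre-allocating `.grad` buffers before capture: a replayed CUDA graph writes to the addresses it saw at capture time, so gradient storage must stay stable across replays. A sketch of the same constraint using the standard plain-PyTorch whole-network capture recipe (model, shapes, and optimizer are illustrative placeholders, not from this commit):

```python
import torch

model = torch.nn.Linear(64, 64, device="cuda")
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

static_input = torch.randn(8, 64, device="cuda")
static_target = torch.randn(8, 64, device="cuda")

# Warm up on a side stream before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        optimizer.zero_grad(set_to_none=True)
        loss_fn(model(static_input), static_target).backward()
        optimizer.step()
torch.cuda.current_stream().wait_stream(s)

# Setting grads to None here lets the captured backward() allocate the
# .grad buffers inside the graph's private memory pool, fixing their addresses.
optimizer.zero_grad(set_to_none=True)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_loss = loss_fn(model(static_input), static_target)
    static_loss.backward()
    optimizer.step()

# Replays reuse the captured storage: copy new data in, never reallocate.
for _ in range(3):
    static_input.copy_(torch.randn(8, 64, device="cuda"))
    static_target.copy_(torch.randn(8, 64, device="cuda"))
    g.replay()
```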