Unverified commit 2a0fe783, authored by Tim Moon and committed by GitHub
Browse files

[PyTorch] Stop storing fused weight tensor in linear modules (#719)



* Support noop concat without providing full tensor

Stop storing fused buffers in linear modules.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug noop cat func
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Construct TE modules in tests with correct dtypes
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add tolerances to numerical tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use plain PyTorch concat when exporting to ONNX
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 14c1ecd0
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import math import math
import os import os
from typing import List, Optional from typing import Dict, List, Optional
import pytest import pytest
import copy import copy
...@@ -79,19 +79,26 @@ def get_causal_attn_mask(sq: int) -> torch.Tensor: ...@@ -79,19 +79,26 @@ def get_causal_attn_mask(sq: int) -> torch.Tensor:
return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool() return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> bool: def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
"""Ensures two lists are equal.""" """Estimated numerical error for a datatype
assert len(l1) == len(l2), "Unequal number of outputs."
failed = False
failed_tensors = ""
for i, (t1, t2) in enumerate(zip(l1, l2)):
if not torch.equal(t1, t2):
failed = True
failed_tensors += f" {names[i]}\n" if names is not None else f" tensor at idx={i}\n"
assert not failed, "Output mismatches in:\n" + failed_tensors
Based on tolerances for torch.testing.assert_close.
def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float) -> bool: """
if dtype == torch.float32:
return dict(rtol=1.3e-6, atol=1e-5)
if dtype == torch.float16:
return dict(rtol=1e-3, atol=1e-5)
if dtype == torch.bfloat16:
return dict(rtol=1.6e-2, atol=1e-5)
raise ValueError(f"Unsuppored dtype ({dtype})")
def assert_allclose(
l1: List[torch.Tensor],
l2: List[torch.Tensor],
atol: float,
) -> bool:
"""Ensures two lists are equal.""" """Ensures two lists are equal."""
assert len(l1) == len(l2), "Unequal number of outputs." assert len(l1) == len(l2), "Unequal number of outputs."
for i, (t1, t2) in enumerate(zip(l1, l2)): for i, (t1, t2) in enumerate(zip(l1, l2)):
...@@ -424,13 +431,16 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False ...@@ -424,13 +431,16 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False
output_layernorm=False, output_layernorm=False,
params_dtype=dtype, params_dtype=dtype,
fuse_qkv_params=True, fuse_qkv_params=True,
device="cuda",
) )
.cuda()
) )
te_inp_hidden_states = torch.randn( te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True (config.seq_len, bs, config.hidden_size),
).cuda() dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_hidden_states.retain_grad() te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len) te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
...@@ -464,7 +474,20 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_par ...@@ -464,7 +474,20 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_par
outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False) outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False)
outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True) outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True)
assert_all_equal(outputs, outputs_recompute)
# Check that results match
tols = dtype_tols(dtype)
if dtype in (torch.float16, torch.bfloat16):
tols["atol"] = 1e-4
if fp8 or fp8_model_params:
tols.update(dict(rtol=0.125, atol=0.0675))
for i, (ref, test) in enumerate(zip(outputs, outputs_recompute)):
torch.testing.assert_close(
test,
ref,
msg=f"Mismatch in tensor {i}",
**tols,
)
def _test_e2e_full_recompute( def _test_e2e_full_recompute(
...@@ -481,8 +504,7 @@ def _test_e2e_full_recompute( ...@@ -481,8 +504,7 @@ def _test_e2e_full_recompute(
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
with fp8_model_init(enabled=fp8 and fp8_model_params): with fp8_model_init(enabled=fp8 and fp8_model_params):
block = ( block = TransformerLayer(
TransformerLayer(
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
config.num_attention_heads, config.num_attention_heads,
...@@ -496,13 +518,15 @@ def _test_e2e_full_recompute( ...@@ -496,13 +518,15 @@ def _test_e2e_full_recompute(
output_layernorm=False, output_layernorm=False,
params_dtype=dtype, params_dtype=dtype,
fuse_qkv_params=True, fuse_qkv_params=True,
) device="cuda",
.cuda()
) )
te_inp_hidden_states = torch.randn( te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=use_reentrant (config.seq_len, bs, config.hidden_size),
).cuda() dtype=dtype,
device="cuda",
requires_grad=use_reentrant,
)
if use_reentrant: if use_reentrant:
te_inp_hidden_states.retain_grad() te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len) te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
...@@ -566,7 +590,19 @@ def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params, ...@@ -566,7 +590,19 @@ def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params,
# Reset bias+GELU fusion flag to avoid contaminating other tests # Reset bias+GELU fusion flag to avoid contaminating other tests
del os.environ["NVTE_BIAS_GELU_NVFUSION"] del os.environ["NVTE_BIAS_GELU_NVFUSION"]
assert_all_equal(outputs, outputs_recompute, names=names) # Check that results match
tols = dtype_tols(dtype)
if dtype in (torch.float16, torch.bfloat16):
tols["atol"] = 1e-3
if fp8 or fp8_model_params:
tols.update(dict(rtol=0.125, atol=0.0675))
for i, (ref, test) in enumerate(zip(outputs, outputs_recompute)):
torch.testing.assert_close(
test,
ref,
msg=f"Mismatch in tensor {i}",
**tols,
)
def _test_e2e_checkpointing_get_model(config, dtype): def _test_e2e_checkpointing_get_model(config, dtype):
...@@ -574,8 +610,7 @@ def _test_e2e_checkpointing_get_model(config, dtype): ...@@ -574,8 +610,7 @@ def _test_e2e_checkpointing_get_model(config, dtype):
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
return ( return TransformerLayer(
TransformerLayer(
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
config.num_attention_heads, config.num_attention_heads,
...@@ -588,8 +623,7 @@ def _test_e2e_checkpointing_get_model(config, dtype): ...@@ -588,8 +623,7 @@ def _test_e2e_checkpointing_get_model(config, dtype):
apply_residual_connection_post_layernorm=False, apply_residual_connection_post_layernorm=False,
output_layernorm=False, output_layernorm=False,
params_dtype=dtype, params_dtype=dtype,
) device="cuda",
.cuda()
) )
...@@ -597,8 +631,11 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= ...@@ -597,8 +631,11 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
reset_rng_states() reset_rng_states()
te_inp_hidden_states = torch.randn( te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True (config.seq_len, bs, config.hidden_size),
).cuda() dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_hidden_states.retain_grad() te_inp_hidden_states.retain_grad()
block = _test_e2e_checkpointing_get_model(config, dtype) block = _test_e2e_checkpointing_get_model(config, dtype)
...@@ -666,15 +703,29 @@ def test_gpt_checkpointing(dtype, bs, model): ...@@ -666,15 +703,29 @@ def test_gpt_checkpointing(dtype, bs, model):
config = model_configs[model] config = model_configs[model]
outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
assert_all_equal(outputs, outputs_checkpoint)
# Check that results match
tols = dtype_tols(dtype)
if dtype in (torch.float16, torch.bfloat16):
tols.update(dict(rtol=2e-2, atol=2e-3))
for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
torch.testing.assert_close(
test,
ref,
msg=f"Mismatch in tensor {i}",
**tols,
)
def _test_e2e_gpt_accuracy(block, bs, dtype, config): def _test_e2e_gpt_accuracy(block, bs, dtype, config):
reset_rng_states() reset_rng_states()
inp_hidden_states = torch.randn( inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True (config.seq_len, bs, config.hidden_size),
).cuda() dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad() inp_hidden_states.retain_grad()
inp_attn_mask = get_causal_attn_mask(config.seq_len) inp_attn_mask = get_causal_attn_mask(config.seq_len)
...@@ -705,12 +756,12 @@ def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): ...@@ -705,12 +756,12 @@ def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
layernorm_epsilon=config.eps, layernorm_epsilon=config.eps,
attention_dropout=0.1, attention_dropout=0.1,
hidden_dropout=0.1, hidden_dropout=0.1,
params_dtype=dtype,
fuse_qkv_params=True, fuse_qkv_params=True,
qkv_weight_interleaved=False, qkv_weight_interleaved=False,
parallel_attention_mlp=parallel_attention_mlp, parallel_attention_mlp=parallel_attention_mlp,
device="cuda",
) )
.to(dtype=dtype)
.cuda()
.eval() .eval()
) )
...@@ -765,8 +816,11 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): ...@@ -765,8 +816,11 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
reset_rng_states() reset_rng_states()
inp_hidden_states = torch.randn( inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True (config.seq_len, bs, config.hidden_size),
).cuda() dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad() inp_hidden_states.retain_grad()
inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None
...@@ -799,11 +853,11 @@ def test_mha_accuracy(dtype, bs, model, mask_type): ...@@ -799,11 +853,11 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
config.hidden_size, config.hidden_size,
config.num_attention_heads, config.num_attention_heads,
fuse_qkv_params=True, fuse_qkv_params=True,
params_dtype=dtype,
qkv_weight_interleaved=False, qkv_weight_interleaved=False,
input_layernorm=False, input_layernorm=False,
device="cuda",
) )
.to(dtype=dtype)
.cuda()
.eval() .eval()
) )
...@@ -838,8 +892,11 @@ def _test_granular_accuracy(block, bs, dtype, config): ...@@ -838,8 +892,11 @@ def _test_granular_accuracy(block, bs, dtype, config):
reset_rng_states() reset_rng_states()
inp_hidden_states = torch.randn( inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True (config.seq_len, bs, config.hidden_size),
).cuda() dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad() inp_hidden_states.retain_grad()
out = block(inp_hidden_states) out = block(inp_hidden_states)
...@@ -857,10 +914,16 @@ def _test_granular_accuracy(block, bs, dtype, config): ...@@ -857,10 +914,16 @@ def _test_granular_accuracy(block, bs, dtype, config):
def _test_dpa_accuracy(block, bs, dtype, config): def _test_dpa_accuracy(block, bs, dtype, config):
reset_rng_states() reset_rng_states()
mask = torch.triu(torch.ones(config.seq_len, config.seq_len, device="cuda"), diagonal=1).bool() mask = torch.triu(torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda"), diagonal=1)
query, key, value = [ query, key, value = [
torch.randn(config.seq_len, bs, config.num_attention_heads, torch.randn(
config.embed, dtype=dtype, requires_grad=True).cuda() for _ in range(3)] (config.seq_len, bs, config.num_attention_heads, config.embed),
dtype=dtype,
device="cuda",
requires_grad=True,
)
for _ in range(3)
]
query.retain_grad() query.retain_grad()
key.retain_grad() key.retain_grad()
...@@ -921,9 +984,9 @@ def test_linear_accuracy(dtype, bs, model): ...@@ -921,9 +984,9 @@ def test_linear_accuracy(dtype, bs, model):
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
bias=True, bias=True,
params_dtype=dtype,
device="cuda",
) )
.to(dtype=dtype)
.cuda()
.eval() .eval()
) )
...@@ -932,9 +995,9 @@ def test_linear_accuracy(dtype, bs, model): ...@@ -932,9 +995,9 @@ def test_linear_accuracy(dtype, bs, model):
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
bias=True, bias=True,
device="cuda",
dtype=dtype,
) )
.to(dtype=dtype)
.cuda()
.eval() .eval()
) )
...@@ -965,10 +1028,10 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): ...@@ -965,10 +1028,10 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
RMSNorm( RMSNorm(
config.hidden_size, config.hidden_size,
eps=eps, eps=eps,
zero_centered_gamma=zero_centered_gamma params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
device="cuda",
) )
.to(dtype=dtype)
.cuda()
.eval() .eval()
) )
...@@ -1009,10 +1072,10 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): ...@@ -1009,10 +1072,10 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
LayerNorm( LayerNorm(
config.hidden_size, config.hidden_size,
eps=eps, eps=eps,
zero_centered_gamma=zero_centered_gamma params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
device="cuda",
) )
.to(dtype=dtype)
.cuda()
.eval() .eval()
) )
...@@ -1058,10 +1121,10 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere ...@@ -1058,10 +1121,10 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere
config.eps, config.eps,
bias=True, bias=True,
normalization=normalization, normalization=normalization,
params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
device="cuda",
) )
.to(dtype=dtype)
.cuda()
.eval() .eval()
) )
...@@ -1112,9 +1175,9 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization): ...@@ -1112,9 +1175,9 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
4 * config.hidden_size, 4 * config.hidden_size,
activation=activation, activation=activation,
normalization=normalization, normalization=normalization,
params_dtype=dtype,
device="cuda",
) )
.to(dtype=dtype)
.cuda()
.eval() .eval()
) )
...@@ -1229,11 +1292,11 @@ def test_gpt_cuda_graph(dtype, bs, model): ...@@ -1229,11 +1292,11 @@ def test_gpt_cuda_graph(dtype, bs, model):
hidden_dropout=0.1, hidden_dropout=0.1,
attention_dropout=0.1, attention_dropout=0.1,
kv_channels=config.embed, kv_channels=config.embed,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False, apply_residual_connection_post_layernorm=False,
output_layernorm=False, output_layernorm=False,
device="cuda",
) )
.to(dtype=dtype)
.cuda()
) )
graphed_block = copy.deepcopy(block) graphed_block = copy.deepcopy(block)
...@@ -1257,8 +1320,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): ...@@ -1257,8 +1320,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params):
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
with fp8_model_init(enabled=fp8_model_params): with fp8_model_init(enabled=fp8_model_params):
block = ( block = TransformerLayer(
TransformerLayer(
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
config.num_attention_heads, config.num_attention_heads,
...@@ -1272,13 +1334,15 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): ...@@ -1272,13 +1334,15 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params):
output_layernorm=False, output_layernorm=False,
params_dtype=dtype, params_dtype=dtype,
fuse_qkv_params=True, fuse_qkv_params=True,
) device="cuda",
.cuda()
) )
te_inp_hidden_states = torch.randn( te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True (config.seq_len, bs, config.hidden_size),
).cuda() dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_hidden_states.retain_grad() te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len) te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
...@@ -1306,7 +1370,18 @@ def test_gpt_fp8_parameters(dtype, bs, model): ...@@ -1306,7 +1370,18 @@ def test_gpt_fp8_parameters(dtype, bs, model):
outputs = _test_gpt_fp8_parameters(bs, dtype, config, False) outputs = _test_gpt_fp8_parameters(bs, dtype, config, False)
outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True) outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True)
assert_all_equal(outputs, outputs_fp8_params)
# Check that results match
tols = dict(rtol=0.125, atol=0.0675)
for i, (ref, test) in enumerate(zip(outputs, outputs_fp8_params)):
torch.testing.assert_close(
test,
ref,
msg=f"Mismatch in tensor {i}",
rtol=0.125,
atol=0.0675,
)
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
...@@ -1323,8 +1398,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): ...@@ -1323,8 +1398,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
# other layer. Set `*dropout` values to 0 to make sure the forward pass # other layer. Set `*dropout` values to 0 to make sure the forward pass
# is identical to the other layer. # is identical to the other layer.
torch.manual_seed(0) torch.manual_seed(0)
block_sbhd = ( block_sbhd = TransformerLayer(
TransformerLayer(
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
config.num_attention_heads, config.num_attention_heads,
...@@ -1334,20 +1408,18 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): ...@@ -1334,20 +1408,18 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
hidden_dropout=0, hidden_dropout=0,
attention_dropout=0, attention_dropout=0,
kv_channels=config.embed, kv_channels=config.embed,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False, apply_residual_connection_post_layernorm=False,
output_layernorm=False, output_layernorm=False,
attn_input_format="sbhd" device="cuda",
) attn_input_format="sbhd",
.to(dtype=dtype)
.cuda()
) )
# Set `torch.manual_seed` to make sure the weights are identical to the # Set `torch.manual_seed` to make sure the weights are identical to the
# other layer. Set `*dropout` values to 0 to make sure the forward pass # other layer. Set `*dropout` values to 0 to make sure the forward pass
# is identical to the other layer. # is identical to the other layer.
torch.manual_seed(0) torch.manual_seed(0)
block_bshd = ( block_bshd = TransformerLayer(
TransformerLayer(
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
config.num_attention_heads, config.num_attention_heads,
...@@ -1357,20 +1429,22 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): ...@@ -1357,20 +1429,22 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
hidden_dropout=0, hidden_dropout=0,
attention_dropout=0, attention_dropout=0,
kv_channels=config.embed, kv_channels=config.embed,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False, apply_residual_connection_post_layernorm=False,
output_layernorm=False, output_layernorm=False,
attn_input_format="bshd" device="cuda",
) attn_input_format="bshd",
.to(dtype=dtype)
.cuda()
) )
for (n1, p1), (n2, p2) in zip(block_bshd.named_parameters(), block_sbhd.named_parameters()): for (n1, p1), (n2, p2) in zip(block_bshd.named_parameters(), block_sbhd.named_parameters()):
assert torch.all(torch.eq(p1, p2)), f"{n1}, {n2} not identical" assert torch.all(torch.eq(p1, p2)), f"{n1}, {n2} not identical"
x_sbhd = torch.randn( x_sbhd = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True (config.seq_len, bs, config.hidden_size),
).to(dtype).cuda() dtype=dtype,
device="cuda",
requires_grad=True,
)
x_bshd = x_sbhd.transpose(0,1).contiguous() x_bshd = x_sbhd.transpose(0,1).contiguous()
...@@ -1384,7 +1458,11 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): ...@@ -1384,7 +1458,11 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
torch.manual_seed(0) torch.manual_seed(0)
y_bshd = block_bshd(x_bshd) y_bshd = block_bshd(x_bshd)
assert_all_equal([y_bshd], [y_sbhd.transpose(0,1).contiguous()]) # Check that results match
torch.testing.assert_close(
y_bshd,
y_sbhd.transpose(0,1).contiguous(),
)
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
...@@ -1424,10 +1502,10 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, ...@@ -1424,10 +1502,10 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
num_attention_heads=H, num_attention_heads=H,
attn_input_format=input_format, attn_input_format=input_format,
layer_number=layer_number, layer_number=layer_number,
attention_dropout = 0.0 attention_dropout = 0.0,
params_dtype=dtype,
device="cuda",
) )
.to(dtype=dtype)
.cuda()
.eval() .eval()
) )
else: else:
...@@ -1437,9 +1515,9 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, ...@@ -1437,9 +1515,9 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
num_attention_heads=H, num_attention_heads=H,
qkv_format=input_format, qkv_format=input_format,
layer_number=layer_number, layer_number=layer_number,
attention_dropout = 0.0 attention_dropout = 0.0,
params_dtype=dtype,
) )
.to(dtype=dtype)
.cuda() .cuda()
.eval() .eval()
) )
......
...@@ -172,10 +172,18 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): ...@@ -172,10 +172,18 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn( te_inp_hidden_states = torch.randn(
config.seq_len, config.batch_size, config.hidden_size, dtype=torch.float32, requires_grad=True (config.seq_len, config.batch_size, config.hidden_size),
).cuda() dtype=torch.float32,
device="cuda",
requires_grad=True,
)
te_inp_hidden_states.retain_grad() te_inp_hidden_states.retain_grad()
te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool() te_inp_attn_mask = torch.randint(
2,
(1, 1, config.seq_len, config.seq_len),
dtype=torch.bool,
device="cuda",
)
if skip_wgrad: if skip_wgrad:
_disable_wgrads(block) _disable_wgrads(block)
...@@ -198,9 +206,17 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): ...@@ -198,9 +206,17 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn( te_inp_hidden_states = torch.randn(
config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True (config.seq_len, config.batch_size, config.hidden_size),
).cuda() dtype=dtype,
te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool() device="cuda",
requires_grad=True,
)
te_inp_attn_mask = torch.randint(
2,
(1, 1, config.seq_len, config.seq_len),
dtype=torch.bool,
device="cuda",
)
if skip_wgrad: if skip_wgrad:
_disable_wgrads(block) _disable_wgrads(block)
...@@ -227,8 +243,11 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci ...@@ -227,8 +243,11 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
te_inp_hidden_states = torch.randn( te_inp_hidden_states = torch.randn(
config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True (config.seq_len, config.batch_size, config.hidden_size),
).cuda() dtype=dtype,
device="cuda",
requires_grad=True,
)
if skip_wgrad: if skip_wgrad:
_disable_wgrads(block) _disable_wgrads(block)
...@@ -250,10 +269,18 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): ...@@ -250,10 +269,18 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn( te_inp_hidden_states = torch.randn(
config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True (config.seq_len, config.batch_size, config.hidden_size),
).cuda() dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_attn_mask = torch.rand(torch.Size([config.batch_size, 1, 1, config.seq_len])).cuda() > 0.5 te_inp_attn_mask = torch.randint(
2,
(config.batch_size, 1, 1, config.seq_len),
dtype=torch.bool,
device="cuda",
)
if skip_wgrad: if skip_wgrad:
_disable_wgrads(block) _disable_wgrads(block)
...@@ -268,10 +295,24 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad): ...@@ -268,10 +295,24 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn( te_inp_hidden_states = torch.randn(
config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True (config.seq_len, config.batch_size, config.hidden_size),
).cuda() dtype=dtype,
te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool() device="cuda",
enc_dec_attn_mask = torch.rand(torch.Size([config.batch_size, 1, 1, config.seq_len])).cuda() > 0.5 requires_grad=True,
)
te_inp_attn_mask = torch.randint(
2,
(1, 1, config.seq_len, config.seq_len),
dtype=torch.bool,
device="cuda",
)
enc_dec_attn_mask = torch.randint(
2,
(config.batch_size, 1, 1, config.seq_len),
dtype=torch.bool,
device="cuda",
)
if skip_wgrad: if skip_wgrad:
_disable_wgrads(block) _disable_wgrads(block)
...@@ -294,8 +335,11 @@ def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad ...@@ -294,8 +335,11 @@ def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad
pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.") pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
te_inp = torch.randn( te_inp = torch.randn(
config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=not skip_dgrad (config.seq_len, config.batch_size, config.hidden_size),
).cuda() dtype=dtype,
device="cuda",
requires_grad=not skip_dgrad,
)
if skip_wgrad: if skip_wgrad:
_disable_wgrads(block) _disable_wgrads(block)
...@@ -315,8 +359,10 @@ def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad) ...@@ -315,8 +359,10 @@ def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)
pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.") pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
te_inp = torch.randn( te_inp = torch.randn(
config.seq_len, config.batch_size, config.hidden_size, requires_grad=True (config.seq_len, config.batch_size, config.hidden_size),
).cuda() device="cuda",
requires_grad=True,
)
te_inp.retain_grad() te_inp.retain_grad()
with torch.autocast(device_type="cuda", enabled=True, dtype=dtype): with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
...@@ -371,16 +417,14 @@ def test_sanity_layernorm_linear(dtype, fp8_recipe, model, skip_wgrad, ...@@ -371,16 +417,14 @@ def test_sanity_layernorm_linear(dtype, fp8_recipe, model, skip_wgrad,
sigma = 0.023 sigma = 0.023
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
block = ( block = LayerNormLinear(
LayerNormLinear(
config.hidden_size, config.hidden_size,
config.hidden_size * 3, config.hidden_size * 3,
init_method=init_method, init_method=init_method,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
normalization=normalization, normalization=normalization,
) params_dtype=dtype,
.to(dtype=dtype) device="cuda",
.cuda()
) )
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad) _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
...@@ -402,12 +446,12 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad): ...@@ -402,12 +446,12 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
sigma = 0.023 sigma = 0.023
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = ( block = Linear(
Linear( config.hidden_size,
config.hidden_size, config.hidden_size, init_method=output_layer_init_method config.hidden_size,
) init_method=output_layer_init_method,
.to(dtype=dtype) params_dtype=dtype,
.cuda() device="cuda",
) )
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad) _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
...@@ -435,8 +479,7 @@ def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad, ...@@ -435,8 +479,7 @@ def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad,
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = ( block = LayerNormMLP(
LayerNormMLP(
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
init_method=init_method, init_method=init_method,
...@@ -444,9 +487,8 @@ def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad, ...@@ -444,9 +487,8 @@ def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
activation=activation, activation=activation,
normalization=normalization, normalization=normalization,
) params_dtype=dtype,
.to(dtype=dtype) device="cuda",
.cuda()
) )
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad) _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
...@@ -477,8 +519,7 @@ def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad, ...@@ -477,8 +519,7 @@ def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad,
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = ( block = TransformerLayer(
TransformerLayer(
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
config.num_attention_heads, config.num_attention_heads,
...@@ -487,17 +528,16 @@ def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad, ...@@ -487,17 +528,16 @@ def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad,
hidden_dropout=0.1, hidden_dropout=0.1,
attention_dropout=0.1, attention_dropout=0.1,
kv_channels=config.kv_channels, kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False, apply_residual_connection_post_layernorm=False,
output_layernorm=False, output_layernorm=False,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
bias=bias, bias=bias,
activation=activation, activation=activation,
normalization=normalization, normalization=normalization,
device="cuda",
parallel_attention_mlp=parallel_attention_mlp, parallel_attention_mlp=parallel_attention_mlp,
) )
.to(dtype=dtype)
.cuda()
)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload) _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload)
...@@ -546,8 +586,7 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, ...@@ -546,8 +586,7 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = ( block = TransformerLayer(
TransformerLayer(
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
config.num_attention_heads, config.num_attention_heads,
...@@ -556,14 +595,13 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, ...@@ -556,14 +595,13 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
hidden_dropout=0.1, hidden_dropout=0.1,
attention_dropout=0.1, attention_dropout=0.1,
kv_channels=config.kv_channels, kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=True, apply_residual_connection_post_layernorm=True,
output_layernorm=True, output_layernorm=True,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
self_attn_mask_type="padding", self_attn_mask_type="padding",
normalization=normalization, normalization=normalization,
) device="cuda",
.to(dtype=dtype)
.cuda()
) )
_test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad) _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad)
...@@ -607,8 +645,7 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, ...@@ -607,8 +645,7 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = ( block = TransformerLayer(
TransformerLayer(
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
config.num_attention_heads, config.num_attention_heads,
...@@ -617,14 +654,13 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, ...@@ -617,14 +654,13 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
hidden_dropout=0.1, hidden_dropout=0.1,
attention_dropout=0.1, attention_dropout=0.1,
kv_channels=config.kv_channels, kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False, apply_residual_connection_post_layernorm=False,
output_layernorm=False, output_layernorm=False,
layer_type="decoder", layer_type="decoder",
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
normalization=normalization, normalization=normalization,
) device="cuda",
.to(dtype=dtype)
.cuda()
) )
_test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad) _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad)
...@@ -665,8 +701,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): ...@@ -665,8 +701,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = ( block = TransformerLayer(
TransformerLayer(
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
config.num_attention_heads, config.num_attention_heads,
...@@ -675,9 +710,8 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): ...@@ -675,9 +710,8 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
hidden_dropout=0.1, hidden_dropout=0.1,
attention_dropout=0.1, attention_dropout=0.1,
kv_channels=config.kv_channels, kv_channels=config.kv_channels,
) params_dtype=torch.float32,
.to(dtype=torch.float32) device="cuda",
.cuda()
) )
_test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad) _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad)
...@@ -700,8 +734,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): ...@@ -700,8 +734,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = ( block = TransformerLayer(
TransformerLayer(
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
config.num_attention_heads, config.num_attention_heads,
...@@ -710,12 +743,11 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): ...@@ -710,12 +743,11 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
hidden_dropout=0.1, hidden_dropout=0.1,
attention_dropout=0.1, attention_dropout=0.1,
kv_channels=config.kv_channels, kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False, apply_residual_connection_post_layernorm=False,
output_layernorm=False, output_layernorm=False,
drop_path_rate=1.0, drop_path_rate=1.0,
) device="cuda",
.to(dtype=dtype)
.cuda()
) )
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False) _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)
...@@ -738,8 +770,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): ...@@ -738,8 +770,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = ( block = TransformerLayer(
TransformerLayer(
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
config.num_attention_heads, config.num_attention_heads,
...@@ -748,12 +779,11 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): ...@@ -748,12 +779,11 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
hidden_dropout=0.1, hidden_dropout=0.1,
attention_dropout=0.1, attention_dropout=0.1,
kv_channels=config.kv_channels, kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False, apply_residual_connection_post_layernorm=False,
output_layernorm=False, output_layernorm=False,
fuse_qkv_params=True, fuse_qkv_params=True,
) device="cuda",
.to(dtype=dtype)
.cuda()
) )
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False) _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)
...@@ -777,8 +807,7 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra ...@@ -777,8 +807,7 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = ( block = TransformerLayer(
TransformerLayer(
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
config.num_attention_heads, config.num_attention_heads,
...@@ -787,14 +816,13 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra ...@@ -787,14 +816,13 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra
hidden_dropout=0.1, hidden_dropout=0.1,
attention_dropout=0.1, attention_dropout=0.1,
kv_channels=config.kv_channels, kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False, apply_residual_connection_post_layernorm=False,
output_layernorm=False, output_layernorm=False,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
fuse_qkv_params=True, fuse_qkv_params=True,
fuse_wgrad_accumulation=True, fuse_wgrad_accumulation=True,
) device="cuda",
.to(dtype=dtype)
.cuda()
) )
_test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad) _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)
...@@ -820,8 +848,7 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm ...@@ -820,8 +848,7 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = ( block = TransformerLayer(
TransformerLayer(
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
config.num_attention_heads, config.num_attention_heads,
...@@ -830,20 +857,19 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm ...@@ -830,20 +857,19 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm
hidden_dropout=0.1, hidden_dropout=0.1,
attention_dropout=0.1, attention_dropout=0.1,
kv_channels=config.kv_channels, kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False, apply_residual_connection_post_layernorm=False,
output_layernorm=False, output_layernorm=False,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
fuse_qkv_params=True, fuse_qkv_params=True,
normalization=normalization, normalization=normalization,
) device="cuda",
.to(dtype=dtype)
.cuda()
) )
_test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad) _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad)
def test_model_multiple_cast(): def test_model_multiple_cast():
a = torch.zeros((16,16)).cuda() a = torch.zeros((16,16), device="cuda")
m = Linear(16,32) m = Linear(16,32)
y = m(a) y = m(a)
......
...@@ -10,6 +10,7 @@ from dataclasses import dataclass ...@@ -10,6 +10,7 @@ from dataclasses import dataclass
import torch import torch
from .. import cpp_extensions as tex from .. import cpp_extensions as tex
from ..export import is_in_onnx_export_mode
from ..fp8 import get_fp8_te_dtype from ..fp8 import get_fp8_te_dtype
from ..utils import get_default_init_method from ..utils import get_default_init_method
...@@ -99,32 +100,79 @@ def _apply_normalization(inputmat:torch.Tensor, ...@@ -99,32 +100,79 @@ def _apply_normalization(inputmat:torch.Tensor,
class _NoopCatFunc(torch.autograd.Function): class _NoopCatFunc(torch.autograd.Function):
"""No-op concatenate tensors along dim 0 """Concatenate tensors, doing a no-op if possible
`full_tensor` is assumed to already be the concatenation of See _noop_cat.
`tensors`, i.e. they occupy the same memory with the correct
offsets.
""" """
@staticmethod @staticmethod
def forward( def forward(
ctx, ctx: Any,
split_ranges: List[Tuple[int, int]], dim: int,
full_tensor: torch.Tensor,
*tensors: Tuple[torch.Tensor, ...], *tensors: Tuple[torch.Tensor, ...],
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=unused-argument
# Check first tensor
if not tensors:
raise ValueError("Attempted to concatenate 0 tensors")
num_dims = tensors[0].dim()
if not -num_dims <= dim < num_dims:
raise ValueError(
"Attempted to concatenate tensor "
f"with shape {list(tensors[0].size())} along dim {dim}"
)
dim %= num_dims
# Check remaining tensors
out_shape = list(tensors[0].size())
split_ranges = [(0, tensors[0].size(dim))]
for tensor in tensors[1:]:
in_shape = list(tensor.size())
if (
len(in_shape) != num_dims
or in_shape[:dim] != out_shape[:dim]
or in_shape[dim+1:] != out_shape[dim+1:]
):
raise ValueError(
"Attempted to concatenate tensors with shapes "
f"{[list(tensor.size()) for tensor in tensors]} "
f"along dim {dim}"
)
split_start = out_shape[dim]
split_end = split_start + in_shape[dim]
out_shape[dim] = split_end
split_ranges.append((split_start, split_end))
# Save state for backward
ctx.dim = dim
ctx.split_ranges = split_ranges ctx.split_ranges = split_ranges
assert not full_tensor.requires_grad, "Concatenated tensor should not require gradient"
out = full_tensor.new() # Out-of-place concatenation if needed
dtype = tensors[0].dtype
device = tensors[0].device
strides = tensors[0].stride()
data_ptr_stride = strides[dim] * tensors[0].element_size()
data_ptr = tensors[0].data_ptr() + tensors[0].size(dim) * data_ptr_stride
for tensor in tensors[1:]:
if (
tensor.dtype != dtype
or tensor.device != device
or tensor.stride() != strides
or tensor.data_ptr() != data_ptr
):
return torch.cat(tensors, dim=dim)
data_ptr += tensor.size(dim) * data_ptr_stride
# No-op concatenation
out = tensors[0].new()
out.set_( out.set_(
full_tensor.untyped_storage(), tensors[0].untyped_storage(),
full_tensor.storage_offset(), tensors[0].storage_offset(),
full_tensor.size(), out_shape,
full_tensor.stride(), strides,
) )
out.requires_grad = True out.requires_grad = any(tensor.requires_grad for tensor in tensors)
return out return out
@staticmethod @staticmethod
...@@ -132,64 +180,32 @@ class _NoopCatFunc(torch.autograd.Function): ...@@ -132,64 +180,32 @@ class _NoopCatFunc(torch.autograd.Function):
ctx, ctx,
grad_output: torch.Tensor, grad_output: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]: ) -> Tuple[Optional[torch.Tensor], ...]:
grads = [ grad_inputs = []
grad_output[split_start:split_end] for split_start, split_end in ctx.split_ranges:
for split_start, split_end in ctx.split_ranges slices = [slice(None)] * grad_output.dim()
] slices[ctx.dim] = slice(split_start, split_end)
return None, None, *grads grad_inputs.append(grad_output[tuple(slices)])
return None, *grad_inputs
def _noop_cat( def _noop_cat(
tensors: List[torch.Tensor], tensors: List[torch.Tensor],
full_tensor: torch.Tensor, dim: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
"""Concatenate tensors along dim 0, doing a no-op if possible """Concatenate tensors, doing a no-op if possible
If `full_tensor` is already the concatenation of `tensors`, i.e. If tensors are already concatenated in memory, a tensor view of
they occupy the same memory region with the correct offsets, then that memory region will be returned. Otherwise the tensors will be
no copies are performed. Otherwise the buffers in all the tensors concatenated out-of-place, as usual.
are reallocated so that another call would result in a no-op.
In the backward pass, gradients to `partial_tensors` will just be
tensor views.
""" """
if not tensors:
# Determine split points raise ValueError("Attempted to concatenate 0 tensors")
split_ranges = [] if len(tensors) == 1:
full_tensor_shape = full_tensor.size() return tensors[0]
offset = 0 if is_in_onnx_export_mode():
for tensor in tensors: return torch.cat(tensors, dim=dim)
tensor_shape = tensor.size() return _NoopCatFunc.apply(dim, *tensors)
if tensor_shape[1:] != full_tensor_shape[1:]:
raise ValueError(
f"Attempting to concatenate tensor with shape={list(tensor_shape)} "
f"into a tensor with shape={list(full_tensor_shape)}"
)
split_start = offset
offset += tensor_shape[0]
split_end = offset
split_ranges.append((split_start, split_end))
if offset != full_tensor_shape[0]:
raise ValueError(
f"Attempting to concatenate tensors with total shape[0]={offset} "
f"into a tensor with shape[0]={full_tensor_shape[0]}"
)
# Reallocate buffers if no-op concat isn't possible
need_to_reallocate = False
for tensor, (split_start, _) in zip(tensors, split_ranges):
if tensor.data_ptr() != full_tensor[split_start].data_ptr():
need_to_reallocate = True
break
if need_to_reallocate:
with torch.no_grad():
full_tensor.data = torch.cat(tensors)
for tensor, (split_start, split_end) in zip(tensors, split_ranges):
tensor.data = full_tensor[split_start:split_end]
# Perform no-op concat
return _NoopCatFunc.apply(split_ranges, full_tensor, *tensors)
@dataclass @dataclass
......
...@@ -926,17 +926,20 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -926,17 +926,20 @@ class LayerNormLinear(TransformerEngineBaseModule):
else: else:
self.layer_norm_bias = None self.layer_norm_bias = None
self.weight_tensor = torch.empty( # Contiguous buffers for params
self.out_features, self.in_features, weight_tensor = torch.empty(
device=device, dtype=params_dtype) self.out_features,
self.in_features,
device=device,
dtype=params_dtype,
)
bias_tensor = None
if self.use_bias: if self.use_bias:
self.bias_tensor = torch.empty( bias_tensor = torch.empty(
self.out_features, self.out_features,
device=device, device=device,
dtype=params_dtype) dtype=params_dtype,
else: )
self.bias_tensor = torch.Tensor().to(dtype=params_dtype, device=device)
# Configure parameter splits # Configure parameter splits
self.weight_names = [] self.weight_names = []
...@@ -982,7 +985,11 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -982,7 +985,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
) )
self.parameter_split_sizes[i] = size // self.tp_size self.parameter_split_sizes[i] = size // self.tp_size
# Construct parameters from weight and bias buffers # Construct weight parameters
# Note: Register weights together so that they are adjacent to
# each other in LayerNormLinear.parameters(). This makes it
# more likely that they will stay contiguous if the weights
# are manipulated externally, e.g. by FSDP.
offset = 0 offset = 0
for i, split_size in enumerate(self.parameter_split_sizes): for i, split_size in enumerate(self.parameter_split_sizes):
split_start = offset split_start = offset
...@@ -998,32 +1005,30 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -998,32 +1005,30 @@ class LayerNormLinear(TransformerEngineBaseModule):
) )
# Construct weight parameter # Construct weight parameter
weight = self.weight_tensor self.register_parameter(
if is_subview: self.weight_names[i],
weight = weight[split_start:split_end] torch.nn.Parameter(weight_tensor[split_start:split_end]),
weight = torch.nn.Parameter(weight)
self.register_parameter(self.weight_names[i], weight,
init_fn=init_method, init_fn=init_method,
get_rng_state_tracker=get_rng_state_tracker, get_rng_state_tracker=get_rng_state_tracker,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT) fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)
# Construct bias parameter if needed # Construct bias parameters if needed
if self.use_bias: if self.use_bias:
bias = self.bias_tensor offset = 0
if is_subview: for i, split_size in enumerate(self.parameter_split_sizes):
bias = bias[split_start:split_end] split_start = offset
bias = torch.nn.Parameter(bias) offset += split_size
self.register_parameter(self.bias_names[i], bias, split_end = offset
init_fn=init_method_constant(0.0)) self.register_parameter(
self.bias_names[i],
torch.nn.Parameter(bias_tensor[split_start:split_end]),
init_fn=init_method_constant(0.0),
)
else: else:
for name in self.bias_names:
bias = torch.Tensor().to(dtype=params_dtype, device=device) bias = torch.Tensor().to(dtype=params_dtype, device=device)
setattr(self, self.bias_names[i], bias) setattr(self, name, bias)
# Concatenated tensors are not needed if not splitting
# into multiple parameters
if not is_subview:
del self.weight_tensor
del self.bias_tensor
if self.primary_weights_in_fp8: if self.primary_weights_in_fp8:
self.init_fp8_metadata() self.init_fp8_metadata()
...@@ -1150,24 +1155,15 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1150,24 +1155,15 @@ class LayerNormLinear(TransformerEngineBaseModule):
"Need to run inside fp8_autocast region when weights are stored in FP8." "Need to run inside fp8_autocast region when weights are stored in FP8."
# Get concatenated weight and bias tensors # Get concatenated weight and bias tensors
if len(self.parameter_split_sizes) == 1:
weight_tensor = getattr(self, self.weight_names[0])
bias_tensor = getattr(self, self.bias_names[0])
elif torch.is_grad_enabled():
weight_tensor = _noop_cat( weight_tensor = _noop_cat(
[getattr(self, name) for name in self.weight_names], [getattr(self, name) for name in self.weight_names],
self.weight_tensor,
) )
if self.use_bias: if self.use_bias:
bias_tensor = _noop_cat( bias_tensor = _noop_cat(
[getattr(self, name) for name in self.bias_names], [getattr(self, name) for name in self.bias_names],
self.bias_tensor,
) )
else: else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused bias_tensor = getattr(self, self.bias_names[0]) # Unused
else:
weight_tensor = self.weight_tensor
bias_tensor = self.bias_tensor
# Fetch the fp8 weights placeholders (for linear/gemm) # Fetch the fp8 weights placeholders (for linear/gemm)
weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad( weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(
......
...@@ -777,14 +777,20 @@ class Linear(TransformerEngineBaseModule): ...@@ -777,14 +777,20 @@ class Linear(TransformerEngineBaseModule):
self.sequence_parallel = (self.tp_size > 1) and sequence_parallel self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
self.weight_tensor = torch.empty( # Contiguous buffers for params
self.out_features, self.in_features, weight_tensor = torch.empty(
device=device, dtype=params_dtype) self.out_features,
self.in_features,
device=device,
dtype=params_dtype,
)
bias_tensor = None
if self.use_bias: if self.use_bias:
self.bias_tensor = torch.empty(self.out_features, device=device, dtype=params_dtype) bias_tensor = torch.empty(
else: self.out_features,
self.bias_tensor = torch.Tensor().to(dtype=params_dtype, device=device) device=device,
dtype=params_dtype,
)
# Configure parameter splits # Configure parameter splits
self.weight_names = [] self.weight_names = []
...@@ -830,7 +836,11 @@ class Linear(TransformerEngineBaseModule): ...@@ -830,7 +836,11 @@ class Linear(TransformerEngineBaseModule):
) )
self.parameter_split_sizes[i] = size // self.tp_size self.parameter_split_sizes[i] = size // self.tp_size
# Construct parameters from weight and bias buffers # Construct weight parameters
# Note: Register weights together so that they are adjacent to
# each other in Linear.parameters(). This makes it more likely
# that they will stay contiguous if the weights are
# manipulated externally, e.g. by FSDP.
offset = 0 offset = 0
for i, split_size in enumerate(self.parameter_split_sizes): for i, split_size in enumerate(self.parameter_split_sizes):
split_start = offset split_start = offset
...@@ -846,32 +856,30 @@ class Linear(TransformerEngineBaseModule): ...@@ -846,32 +856,30 @@ class Linear(TransformerEngineBaseModule):
) )
# Construct weight parameter # Construct weight parameter
weight = self.weight_tensor self.register_parameter(
if is_subview: self.weight_names[i],
weight = weight[split_start:split_end] torch.nn.Parameter(weight_tensor[split_start:split_end]),
weight = torch.nn.Parameter(weight)
self.register_parameter(self.weight_names[i], weight,
init_fn=init_method, init_fn=init_method,
get_rng_state_tracker=get_rng_state_tracker, get_rng_state_tracker=get_rng_state_tracker,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT) fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)
# Construct bias parameter if needed # Construct bias parameters if needed
if self.use_bias: if self.use_bias:
bias = self.bias_tensor offset = 0
if is_subview: for i, split_size in enumerate(self.parameter_split_sizes):
bias = bias[split_start:split_end] split_start = offset
bias = torch.nn.Parameter(bias) offset += split_size
self.register_parameter(self.bias_names[i], bias, split_end = offset
init_fn=init_method_constant(0.0)) self.register_parameter(
self.bias_names[i],
torch.nn.Parameter(bias_tensor[split_start:split_end]),
init_fn=init_method_constant(0.0),
)
else: else:
for name in self.bias_names:
bias = torch.Tensor().to(dtype=params_dtype, device=device) bias = torch.Tensor().to(dtype=params_dtype, device=device)
setattr(self, self.bias_names[i], bias) setattr(self, name, bias)
# Concatenated tensors are not needed if not splitting
# into multiple parameters
if not is_subview:
del self.weight_tensor
del self.bias_tensor
if self.primary_weights_in_fp8: if self.primary_weights_in_fp8:
self.init_fp8_metadata() self.init_fp8_metadata()
...@@ -974,24 +982,15 @@ class Linear(TransformerEngineBaseModule): ...@@ -974,24 +982,15 @@ class Linear(TransformerEngineBaseModule):
is_first_module_in_mha = is_first_module_in_mha and self.fp8_meta["recipe"].fp8_mha is_first_module_in_mha = is_first_module_in_mha and self.fp8_meta["recipe"].fp8_mha
# Get concatenated weight and bias tensors # Get concatenated weight and bias tensors
if len(self.parameter_split_sizes) == 1:
weight_tensor = getattr(self, self.weight_names[0])
bias_tensor = getattr(self, self.bias_names[0])
elif torch.is_grad_enabled():
weight_tensor = _noop_cat( weight_tensor = _noop_cat(
[getattr(self, name) for name in self.weight_names], [getattr(self, name) for name in self.weight_names],
self.weight_tensor,
) )
if self.use_bias: if self.use_bias:
bias_tensor = _noop_cat( bias_tensor = _noop_cat(
[getattr(self, name) for name in self.bias_names], [getattr(self, name) for name in self.bias_names],
self.bias_tensor,
) )
else: else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused bias_tensor = getattr(self, self.bias_names[0]) # Unused
else:
weight_tensor = self.weight_tensor
bias_tensor = self.bias_tensor
# Fetch the fp8 weights placeholders (for linear/gemm) # Fetch the fp8 weights placeholders (for linear/gemm)
weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad( weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment