"docs/vscode:/vscode.git/clone" did not exist on "40931961e3983c23c4d23f02a3cbd281d808390d"
Unverified commit 2a0fe783, authored by Tim Moon and committed by GitHub

[PyTorch] Stop storing fused weight tensor in linear modules (#719)



* Support noop concat without providing full tensor

Stop storing fused buffers in linear modules.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug noop cat func
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Construct TE modules in tests with correct dtypes
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add tolerances to numerical tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use plain PyTorch concat when exporting to ONNX
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
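
The no-op concat at the heart of this change checks whether the per-parameter tensors already sit back-to-back in one contiguous buffer; if so it returns a view of that buffer, otherwise it falls back to an ordinary torch.cat. A minimal standalone sketch of the idea (hypothetical noop_cat helper, without the autograd handling that the real _noop_cat/_NoopCatFunc in the diff below provide):

    import torch

    def noop_cat(tensors, dim=0):
        """Concatenate, returning a view if the tensors are already adjacent in memory."""
        first = tensors[0]
        strides = first.stride()
        step = strides[dim] * first.element_size()
        next_ptr = first.data_ptr() + first.size(dim) * step
        for t in tensors[1:]:
            if (t.dtype != first.dtype or t.device != first.device
                    or t.stride() != strides or t.data_ptr() != next_ptr):
                return torch.cat(tensors, dim=dim)  # out-of-place fallback
            next_ptr += t.size(dim) * step
        out_shape = list(first.size())
        out_shape[dim] = sum(t.size(dim) for t in tensors)
        out = first.new()
        out.set_(first.untyped_storage(), first.storage_offset(), out_shape, strides)
        return out

    buf = torch.zeros(6, 4)
    q, k, v = buf[:2], buf[2:4], buf[4:]
    fused = noop_cat([q, k, v])
    assert fused.data_ptr() == buf.data_ptr()          # a view of buf, no copy made
    assert torch.equal(fused, torch.cat([q, k, v]))    # same values as a real concat
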
parent 14c1ecd0
@@ -4,7 +4,7 @@
import math
import os
from typing import List, Optional
from typing import Dict, List, Optional
import pytest
import copy
@@ -79,19 +79,26 @@ def get_causal_attn_mask(sq: int) -> torch.Tensor:
return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> bool:
"""Ensures two lists are equal."""
assert len(l1) == len(l2), "Unequal number of outputs."
failed = False
failed_tensors = ""
for i, (t1, t2) in enumerate(zip(l1, l2)):
if not torch.equal(t1, t2):
failed = True
failed_tensors += f" {names[i]}\n" if names is not None else f" tensor at idx={i}\n"
assert not failed, "Output mismatches in:\n" + failed_tensors
def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
"""Estimated numerical error for a datatype
Based on tolerances for torch.testing.assert_close.
def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float) -> bool:
"""
if dtype == torch.float32:
return dict(rtol=1.3e-6, atol=1e-5)
if dtype == torch.float16:
return dict(rtol=1e-3, atol=1e-5)
if dtype == torch.bfloat16:
return dict(rtol=1.6e-2, atol=1e-5)
raise ValueError(f"Unsuppored dtype ({dtype})")
def assert_allclose(
l1: List[torch.Tensor],
l2: List[torch.Tensor],
atol: float,
) -> bool:
"""Ensures two lists are equal."""
assert len(l1) == len(l2), "Unequal number of outputs."
for i, (t1, t2) in enumerate(zip(l1, l2)):
@@ -424,13 +431,16 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False
output_layernorm=False,
params_dtype=dtype,
fuse_qkv_params=True,
device="cuda",
)
.cuda()
)
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
(config.seq_len, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
@@ -464,7 +474,20 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_par
outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False)
outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True)
assert_all_equal(outputs, outputs_recompute)
# Check that results match
tols = dtype_tols(dtype)
if dtype in (torch.float16, torch.bfloat16):
tols["atol"] = 1e-4
if fp8 or fp8_model_params:
tols.update(dict(rtol=0.125, atol=0.0675))
for i, (ref, test) in enumerate(zip(outputs, outputs_recompute)):
torch.testing.assert_close(
test,
ref,
msg=f"Mismatch in tensor {i}",
**tols,
)
def _test_e2e_full_recompute(
@@ -481,8 +504,7 @@ def _test_e2e_full_recompute(
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
with fp8_model_init(enabled=fp8 and fp8_model_params):
block = (
TransformerLayer(
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
@@ -496,13 +518,15 @@ def _test_e2e_full_recompute(
output_layernorm=False,
params_dtype=dtype,
fuse_qkv_params=True,
)
.cuda()
device="cuda",
)
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=use_reentrant
).cuda()
(config.seq_len, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=use_reentrant,
)
if use_reentrant:
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
@@ -566,7 +590,19 @@ def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params,
# Reset bias+GELU fusion flag to avoid contaminating other tests
del os.environ["NVTE_BIAS_GELU_NVFUSION"]
assert_all_equal(outputs, outputs_recompute, names=names)
# Check that results match
tols = dtype_tols(dtype)
if dtype in (torch.float16, torch.bfloat16):
tols["atol"] = 1e-3
if fp8 or fp8_model_params:
tols.update(dict(rtol=0.125, atol=0.0675))
for i, (ref, test) in enumerate(zip(outputs, outputs_recompute)):
torch.testing.assert_close(
test,
ref,
msg=f"Mismatch in tensor {i}",
**tols,
)
def _test_e2e_checkpointing_get_model(config, dtype):
@@ -574,22 +610,20 @@ def _test_e2e_checkpointing_get_model(config, dtype):
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
return (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
params_dtype=dtype,
)
.cuda()
return TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
params_dtype=dtype,
device="cuda",
)
@@ -597,8 +631,11 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
reset_rng_states()
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
(config.seq_len, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_hidden_states.retain_grad()
block = _test_e2e_checkpointing_get_model(config, dtype)
@@ -666,15 +703,29 @@ def test_gpt_checkpointing(dtype, bs, model):
config = model_configs[model]
outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
assert_all_equal(outputs, outputs_checkpoint)
# Check that results match
tols = dtype_tols(dtype)
if dtype in (torch.float16, torch.bfloat16):
tols.update(dict(rtol=2e-2, atol=2e-3))
for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
torch.testing.assert_close(
test,
ref,
msg=f"Mismatch in tensor {i}",
**tols,
)
def _test_e2e_gpt_accuracy(block, bs, dtype, config):
reset_rng_states()
inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
(config.seq_len, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad()
inp_attn_mask = get_causal_attn_mask(config.seq_len)
@@ -705,12 +756,12 @@ def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
layernorm_epsilon=config.eps,
attention_dropout=0.1,
hidden_dropout=0.1,
params_dtype=dtype,
fuse_qkv_params=True,
qkv_weight_interleaved=False,
parallel_attention_mlp=parallel_attention_mlp,
device="cuda",
)
.to(dtype=dtype)
.cuda()
.eval()
)
@@ -765,8 +816,11 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
reset_rng_states()
inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
(config.seq_len, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad()
inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None
@@ -799,11 +853,11 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
config.hidden_size,
config.num_attention_heads,
fuse_qkv_params=True,
params_dtype=dtype,
qkv_weight_interleaved=False,
input_layernorm=False,
device="cuda",
)
.to(dtype=dtype)
.cuda()
.eval()
)
@@ -838,8 +892,11 @@ def _test_granular_accuracy(block, bs, dtype, config):
reset_rng_states()
inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
(config.seq_len, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad()
out = block(inp_hidden_states)
@@ -857,10 +914,16 @@ def _test_granular_accuracy(block, bs, dtype, config):
def _test_dpa_accuracy(block, bs, dtype, config):
reset_rng_states()
mask = torch.triu(torch.ones(config.seq_len, config.seq_len, device="cuda"), diagonal=1).bool()
mask = torch.triu(torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda"), diagonal=1)
query, key, value = [
torch.randn(config.seq_len, bs, config.num_attention_heads,
config.embed, dtype=dtype, requires_grad=True).cuda() for _ in range(3)]
torch.randn(
(config.seq_len, bs, config.num_attention_heads, config.embed),
dtype=dtype,
device="cuda",
requires_grad=True,
)
for _ in range(3)
]
query.retain_grad()
key.retain_grad()
@@ -921,9 +984,9 @@ def test_linear_accuracy(dtype, bs, model):
config.hidden_size,
4 * config.hidden_size,
bias=True,
params_dtype=dtype,
device="cuda",
)
.to(dtype=dtype)
.cuda()
.eval()
)
@@ -932,9 +995,9 @@ def test_linear_accuracy(dtype, bs, model):
config.hidden_size,
4 * config.hidden_size,
bias=True,
device="cuda",
dtype=dtype,
)
.to(dtype=dtype)
.cuda()
.eval()
)
@@ -965,10 +1028,10 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
RMSNorm(
config.hidden_size,
eps=eps,
zero_centered_gamma=zero_centered_gamma
params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
device="cuda",
)
.to(dtype=dtype)
.cuda()
.eval()
)
@@ -1009,10 +1072,10 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
LayerNorm(
config.hidden_size,
eps=eps,
zero_centered_gamma=zero_centered_gamma
params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
device="cuda",
)
.to(dtype=dtype)
.cuda()
.eval()
)
@@ -1058,10 +1121,10 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere
config.eps,
bias=True,
normalization=normalization,
params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
device="cuda",
)
.to(dtype=dtype)
.cuda()
.eval()
)
@@ -1112,9 +1175,9 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
4 * config.hidden_size,
activation=activation,
normalization=normalization,
params_dtype=dtype,
device="cuda",
)
.to(dtype=dtype)
.cuda()
.eval()
)
@@ -1229,11 +1292,11 @@ def test_gpt_cuda_graph(dtype, bs, model):
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
device="cuda",
)
.to(dtype=dtype)
.cuda()
)
graphed_block = copy.deepcopy(block)
@@ -1257,28 +1320,29 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params):
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
with fp8_model_init(enabled=fp8_model_params):
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
params_dtype=dtype,
fuse_qkv_params=True,
)
.cuda()
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
params_dtype=dtype,
fuse_qkv_params=True,
device="cuda",
)
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
(config.seq_len, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
@@ -1306,7 +1370,18 @@ def test_gpt_fp8_parameters(dtype, bs, model):
outputs = _test_gpt_fp8_parameters(bs, dtype, config, False)
outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True)
assert_all_equal(outputs, outputs_fp8_params)
# Check that results match
tols = dict(rtol=0.125, atol=0.0675)
for i, (ref, test) in enumerate(zip(outputs, outputs_fp8_params)):
torch.testing.assert_close(
test,
ref,
msg=f"Mismatch in tensor {i}",
**tols,
)
@pytest.mark.parametrize("dtype", param_types)
@@ -1323,54 +1398,53 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
# other layer. Set `*dropout` values to 0 to make sure the forward pass
# is identical to the other layer.
torch.manual_seed(0)
block_sbhd = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0,
attention_dropout=0,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
attn_input_format="sbhd"
)
.to(dtype=dtype)
.cuda()
block_sbhd = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0,
attention_dropout=0,
kv_channels=config.embed,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
device="cuda",
attn_input_format="sbhd",
)
# Set `torch.manual_seed` to make sure the weights are identical to the
# other layer. Set `*dropout` values to 0 to make sure the forward pass
# is identical to the other layer.
torch.manual_seed(0)
block_bshd = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0,
attention_dropout=0,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
attn_input_format="bshd"
)
.to(dtype=dtype)
.cuda()
block_bshd = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0,
attention_dropout=0,
kv_channels=config.embed,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
device="cuda",
attn_input_format="bshd",
)
for (n1, p1), (n2, p2) in zip(block_bshd.named_parameters(), block_sbhd.named_parameters()):
assert torch.all(torch.eq(p1, p2)), f"{n1}, {n2} not identical"
x_sbhd = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).to(dtype).cuda()
(config.seq_len, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
x_bshd = x_sbhd.transpose(0,1).contiguous()
@@ -1384,7 +1458,11 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
torch.manual_seed(0)
y_bshd = block_bshd(x_bshd)
assert_all_equal([y_bshd], [y_sbhd.transpose(0,1).contiguous()])
# Check that results match
torch.testing.assert_close(
y_bshd,
y_sbhd.transpose(0,1).contiguous(),
)
@pytest.mark.parametrize("dtype", param_types)
@@ -1424,10 +1502,10 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
num_attention_heads=H,
attn_input_format=input_format,
layer_number=layer_number,
attention_dropout = 0.0
attention_dropout = 0.0,
params_dtype=dtype,
device="cuda",
)
.to(dtype=dtype)
.cuda()
.eval()
)
else:
@@ -1437,9 +1515,9 @@
num_attention_heads=H,
qkv_format=input_format,
layer_number=layer_number,
attention_dropout = 0.0
attention_dropout = 0.0,
params_dtype=dtype,
)
.to(dtype=dtype)
.cuda()
.eval()
)
......
@@ -172,10 +172,18 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, config.batch_size, config.hidden_size, dtype=torch.float32, requires_grad=True
).cuda()
(config.seq_len, config.batch_size, config.hidden_size),
dtype=torch.float32,
device="cuda",
requires_grad=True,
)
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()
te_inp_attn_mask = torch.randint(
2,
(1, 1, config.seq_len, config.seq_len),
dtype=torch.bool,
device="cuda",
)
if skip_wgrad:
_disable_wgrads(block)
@@ -198,9 +206,17 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()
(config.seq_len, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_attn_mask = torch.randint(
2,
(1, 1, config.seq_len, config.seq_len),
dtype=torch.bool,
device="cuda",
)
if skip_wgrad:
_disable_wgrads(block)
@@ -227,8 +243,11 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
te_inp_hidden_states = torch.randn(
config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
(config.seq_len, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
if skip_wgrad:
_disable_wgrads(block)
@@ -250,10 +269,18 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
(config.seq_len, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_attn_mask = torch.rand(torch.Size([config.batch_size, 1, 1, config.seq_len])).cuda() > 0.5
te_inp_attn_mask = torch.randint(
2,
(config.batch_size, 1, 1, config.seq_len),
dtype=torch.bool,
device="cuda",
)
if skip_wgrad:
_disable_wgrads(block)
@@ -268,10 +295,24 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()
enc_dec_attn_mask = torch.rand(torch.Size([config.batch_size, 1, 1, config.seq_len])).cuda() > 0.5
(config.seq_len, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_attn_mask = torch.randint(
2,
(1, 1, config.seq_len, config.seq_len),
dtype=torch.bool,
device="cuda",
)
enc_dec_attn_mask = torch.randint(
2,
(config.batch_size, 1, 1, config.seq_len),
dtype=torch.bool,
device="cuda",
)
if skip_wgrad:
_disable_wgrads(block)
@@ -294,8 +335,11 @@ def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad
pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
te_inp = torch.randn(
config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=not skip_dgrad
).cuda()
(config.seq_len, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=not skip_dgrad,
)
if skip_wgrad:
_disable_wgrads(block)
@@ -315,8 +359,10 @@ def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)
pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
te_inp = torch.randn(
config.seq_len, config.batch_size, config.hidden_size, requires_grad=True
).cuda()
(config.seq_len, config.batch_size, config.hidden_size),
device="cuda",
requires_grad=True,
)
te_inp.retain_grad()
with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
@@ -371,16 +417,14 @@ def test_sanity_layernorm_linear(dtype, fp8_recipe, model, skip_wgrad,
sigma = 0.023
init_method = init_method_normal(sigma)
block = (
LayerNormLinear(
config.hidden_size,
config.hidden_size * 3,
init_method=init_method,
zero_centered_gamma=zero_centered_gamma,
normalization=normalization,
)
.to(dtype=dtype)
.cuda()
block = LayerNormLinear(
config.hidden_size,
config.hidden_size * 3,
init_method=init_method,
zero_centered_gamma=zero_centered_gamma,
normalization=normalization,
params_dtype=dtype,
device="cuda",
)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
@@ -402,12 +446,12 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
sigma = 0.023
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
Linear(
config.hidden_size, config.hidden_size, init_method=output_layer_init_method
)
.to(dtype=dtype)
.cuda()
block = Linear(
config.hidden_size,
config.hidden_size,
init_method=output_layer_init_method,
params_dtype=dtype,
device="cuda",
)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
@@ -435,18 +479,16 @@ def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad,
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
LayerNormMLP(
config.hidden_size,
4 * config.hidden_size,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
zero_centered_gamma=zero_centered_gamma,
activation=activation,
normalization=normalization,
)
.to(dtype=dtype)
.cuda()
block = LayerNormMLP(
config.hidden_size,
4 * config.hidden_size,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
zero_centered_gamma=zero_centered_gamma,
activation=activation,
normalization=normalization,
params_dtype=dtype,
device="cuda",
)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
@@ -477,26 +519,24 @@ def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad,
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
bias=bias,
activation=activation,
normalization=normalization,
parallel_attention_mlp=parallel_attention_mlp,
)
.to(dtype=dtype)
.cuda()
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
bias=bias,
activation=activation,
normalization=normalization,
device="cuda",
parallel_attention_mlp=parallel_attention_mlp,
)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload)
@@ -546,24 +586,22 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=True,
output_layernorm=True,
zero_centered_gamma=zero_centered_gamma,
self_attn_mask_type="padding",
normalization=normalization,
)
.to(dtype=dtype)
.cuda()
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=True,
output_layernorm=True,
zero_centered_gamma=zero_centered_gamma,
self_attn_mask_type="padding",
normalization=normalization,
device="cuda",
)
_test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad)
@@ -607,24 +645,22 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
layer_type="decoder",
zero_centered_gamma=zero_centered_gamma,
normalization=normalization,
)
.to(dtype=dtype)
.cuda()
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
layer_type="decoder",
zero_centered_gamma=zero_centered_gamma,
normalization=normalization,
device="cuda",
)
_test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad)
@@ -665,19 +701,17 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
)
.to(dtype=torch.float32)
.cuda()
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
params_dtype=torch.float32,
device="cuda",
)
_test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad)
@@ -700,22 +734,20 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
drop_path_rate=1.0,
)
.to(dtype=dtype)
.cuda()
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
drop_path_rate=1.0,
device="cuda",
)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)
@@ -738,22 +770,20 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
fuse_qkv_params=True,
)
.to(dtype=dtype)
.cuda()
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
fuse_qkv_params=True,
device="cuda",
)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)
@@ -777,24 +807,22 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
fuse_qkv_params=True,
fuse_wgrad_accumulation=True,
)
.to(dtype=dtype)
.cuda()
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
fuse_qkv_params=True,
fuse_wgrad_accumulation=True,
device="cuda",
)
_test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)
@@ -820,30 +848,28 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
fuse_qkv_params=True,
normalization=normalization,
)
.to(dtype=dtype)
.cuda()
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
fuse_qkv_params=True,
normalization=normalization,
device="cuda",
)
_test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad)
def test_model_multiple_cast():
a = torch.zeros((16,16)).cuda()
a = torch.zeros((16,16), device="cuda")
m = Linear(16,32)
y = m(a)
......
@@ -10,6 +10,7 @@ from dataclasses import dataclass
import torch
from .. import cpp_extensions as tex
from ..export import is_in_onnx_export_mode
from ..fp8 import get_fp8_te_dtype
from ..utils import get_default_init_method
@@ -99,32 +100,79 @@ def _apply_normalization(inputmat:torch.Tensor,
class _NoopCatFunc(torch.autograd.Function):
"""No-op concatenate tensors along dim 0
"""Concatenate tensors, doing a no-op if possible
`full_tensor` is assumed to already be the concatenation of
`tensors`, i.e. they occupy the same memory with the correct
offsets.
See _noop_cat.
"""
@staticmethod
def forward(
ctx,
split_ranges: List[Tuple[int, int]],
full_tensor: torch.Tensor,
ctx: Any,
dim: int,
*tensors: Tuple[torch.Tensor, ...],
) -> torch.Tensor:
# pylint: disable=unused-argument
# Check first tensor
if not tensors:
raise ValueError("Attempted to concatenate 0 tensors")
num_dims = tensors[0].dim()
if not -num_dims <= dim < num_dims:
raise ValueError(
"Attempted to concatenate tensor "
f"with shape {list(tensors[0].size())} along dim {dim}"
)
dim %= num_dims
# Check remaining tensors
out_shape = list(tensors[0].size())
split_ranges = [(0, tensors[0].size(dim))]
for tensor in tensors[1:]:
in_shape = list(tensor.size())
if (
len(in_shape) != num_dims
or in_shape[:dim] != out_shape[:dim]
or in_shape[dim+1:] != out_shape[dim+1:]
):
raise ValueError(
"Attempted to concatenate tensors with shapes "
f"{[list(tensor.size()) for tensor in tensors]} "
f"along dim {dim}"
)
split_start = out_shape[dim]
split_end = split_start + in_shape[dim]
out_shape[dim] = split_end
split_ranges.append((split_start, split_end))
# Save state for backward
ctx.dim = dim
ctx.split_ranges = split_ranges
assert not full_tensor.requires_grad, "Concatenated tensor should not require gradient"
out = full_tensor.new()
# Out-of-place concatenation if needed
dtype = tensors[0].dtype
device = tensors[0].device
strides = tensors[0].stride()
data_ptr_stride = strides[dim] * tensors[0].element_size()
data_ptr = tensors[0].data_ptr() + tensors[0].size(dim) * data_ptr_stride
for tensor in tensors[1:]:
if (
tensor.dtype != dtype
or tensor.device != device
or tensor.stride() != strides
or tensor.data_ptr() != data_ptr
):
return torch.cat(tensors, dim=dim)
data_ptr += tensor.size(dim) * data_ptr_stride
# No-op concatenation
out = tensors[0].new()
out.set_(
full_tensor.untyped_storage(),
full_tensor.storage_offset(),
full_tensor.size(),
full_tensor.stride(),
tensors[0].untyped_storage(),
tensors[0].storage_offset(),
out_shape,
strides,
)
out.requires_grad = True
out.requires_grad = any(tensor.requires_grad for tensor in tensors)
return out
@staticmethod
@@ -132,64 +180,32 @@ class _NoopCatFunc(torch.autograd.Function):
ctx,
grad_output: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
grads = [
grad_output[split_start:split_end]
for split_start, split_end in ctx.split_ranges
]
return None, None, *grads
grad_inputs = []
for split_start, split_end in ctx.split_ranges:
slices = [slice(None)] * grad_output.dim()
slices[ctx.dim] = slice(split_start, split_end)
grad_inputs.append(grad_output[tuple(slices)])
return None, *grad_inputs
def _noop_cat(
tensors: List[torch.Tensor],
full_tensor: torch.Tensor,
dim: int = 0,
) -> torch.Tensor:
"""Concatenate tensors along dim 0, doing a no-op if possible
If `full_tensor` is already the concatenation of `tensors`, i.e.
they occupy the same memory region with the correct offsets, then
no copies are performed. Otherwise the buffers in all the tensors
are reallocated so that another call would result in a no-op.
"""Concatenate tensors, doing a no-op if possible
In the backward pass, gradients to `partial_tensors` will just be
tensor views.
If tensors are already concatenated in memory, a tensor view of
that memory region will be returned. Otherwise the tensors will be
concatenated out-of-place, as usual.
"""
# Determine split points
split_ranges = []
full_tensor_shape = full_tensor.size()
offset = 0
for tensor in tensors:
tensor_shape = tensor.size()
if tensor_shape[1:] != full_tensor_shape[1:]:
raise ValueError(
f"Attempting to concatenate tensor with shape={list(tensor_shape)} "
f"into a tensor with shape={list(full_tensor_shape)}"
)
split_start = offset
offset += tensor_shape[0]
split_end = offset
split_ranges.append((split_start, split_end))
if offset != full_tensor_shape[0]:
raise ValueError(
f"Attempting to concatenate tensors with total shape[0]={offset} "
f"into a tensor with shape[0]={full_tensor_shape[0]}"
)
# Reallocate buffers if no-op concat isn't possible
need_to_reallocate = False
for tensor, (split_start, _) in zip(tensors, split_ranges):
if tensor.data_ptr() != full_tensor[split_start].data_ptr():
need_to_reallocate = True
break
if need_to_reallocate:
with torch.no_grad():
full_tensor.data = torch.cat(tensors)
for tensor, (split_start, split_end) in zip(tensors, split_ranges):
tensor.data = full_tensor[split_start:split_end]
# Perform no-op concat
return _NoopCatFunc.apply(split_ranges, full_tensor, *tensors)
if not tensors:
raise ValueError("Attempted to concatenate 0 tensors")
if len(tensors) == 1:
return tensors[0]
if is_in_onnx_export_mode():
return torch.cat(tensors, dim=dim)
return _NoopCatFunc.apply(dim, *tensors)
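# Example (illustrative): given tensors carved out of one contiguous buffer,
# e.g. ``buf = torch.zeros(6, 4)`` with ``a, b = buf[:2], buf[2:]``,
# ``_noop_cat([a, b])`` returns a view of ``buf`` (same data_ptr, no copy),
# while ``_noop_cat([a, torch.zeros(2, 4)])`` falls back to ``torch.cat``.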
@dataclass
......
@@ -926,17 +926,20 @@ class LayerNormLinear(TransformerEngineBaseModule):
else:
self.layer_norm_bias = None
self.weight_tensor = torch.empty(
self.out_features, self.in_features,
device=device, dtype=params_dtype)
# Contiguous buffers for params
weight_tensor = torch.empty(
self.out_features,
self.in_features,
device=device,
dtype=params_dtype,
)
bias_tensor = None
if self.use_bias:
self.bias_tensor = torch.empty(
bias_tensor = torch.empty(
self.out_features,
device=device,
dtype=params_dtype)
else:
self.bias_tensor = torch.Tensor().to(dtype=params_dtype, device=device)
dtype=params_dtype,
)
# Configure parameter splits
self.weight_names = []
@@ -982,7 +985,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
)
self.parameter_split_sizes[i] = size // self.tp_size
# Construct parameters from weight and bias buffers
# Construct weight parameters
# Note: Register weights together so that they are adjacent to
# each other in LayerNormLinear.parameters(). This makes it
# more likely that they will stay contiguous if the weights
# are manipulated externally, e.g. by FSDP.
offset = 0
for i, split_size in enumerate(self.parameter_split_sizes):
split_start = offset
@@ -998,32 +1005,30 @@ class LayerNormLinear(TransformerEngineBaseModule):
)
# Construct weight parameter
weight = self.weight_tensor
if is_subview:
weight = weight[split_start:split_end]
weight = torch.nn.Parameter(weight)
self.register_parameter(self.weight_names[i], weight,
init_fn=init_method,
get_rng_state_tracker=get_rng_state_tracker,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT)
# Construct bias parameter if needed
if self.use_bias:
bias = self.bias_tensor
if is_subview:
bias = bias[split_start:split_end]
bias = torch.nn.Parameter(bias)
self.register_parameter(self.bias_names[i], bias,
init_fn=init_method_constant(0.0))
else:
bias = torch.Tensor().to(dtype=params_dtype, device=device)
setattr(self, self.bias_names[i], bias)
self.register_parameter(
self.weight_names[i],
torch.nn.Parameter(weight_tensor[split_start:split_end]),
init_fn=init_method,
get_rng_state_tracker=get_rng_state_tracker,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)
# Concatenated tensors are not needed if not splitting
# into multiple parameters
if not is_subview:
del self.weight_tensor
del self.bias_tensor
# Construct bias parameters if needed
if self.use_bias:
offset = 0
for i, split_size in enumerate(self.parameter_split_sizes):
split_start = offset
offset += split_size
split_end = offset
self.register_parameter(
self.bias_names[i],
torch.nn.Parameter(bias_tensor[split_start:split_end]),
init_fn=init_method_constant(0.0),
)
else:
for name in self.bias_names:
bias = torch.Tensor().to(dtype=params_dtype, device=device)
setattr(self, name, bias)
if self.primary_weights_in_fp8:
self.init_fp8_metadata()
@@ -1150,24 +1155,15 @@ class LayerNormLinear(TransformerEngineBaseModule):
"Need to run inside fp8_autocast region when weights are stored in FP8."
# Get concatenated weight and bias tensors
if len(self.parameter_split_sizes) == 1:
weight_tensor = getattr(self, self.weight_names[0])
bias_tensor = getattr(self, self.bias_names[0])
elif torch.is_grad_enabled():
weight_tensor = _noop_cat(
[getattr(self, name) for name in self.weight_names],
self.weight_tensor,
weight_tensor = _noop_cat(
[getattr(self, name) for name in self.weight_names],
)
if self.use_bias:
bias_tensor = _noop_cat(
[getattr(self, name) for name in self.bias_names],
)
if self.use_bias:
bias_tensor = _noop_cat(
[getattr(self, name) for name in self.bias_names],
self.bias_tensor,
)
else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused
else:
weight_tensor = self.weight_tensor
bias_tensor = self.bias_tensor
bias_tensor = getattr(self, self.bias_names[0]) # Unused
# Fetch the fp8 weights placeholders (for linear/gemm)
weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(
......
@@ -777,14 +777,20 @@ class Linear(TransformerEngineBaseModule):
self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
self.weight_tensor = torch.empty(
self.out_features, self.in_features,
device=device, dtype=params_dtype)
# Contiguous buffers for params
weight_tensor = torch.empty(
self.out_features,
self.in_features,
device=device,
dtype=params_dtype,
)
bias_tensor = None
if self.use_bias:
self.bias_tensor = torch.empty(self.out_features, device=device, dtype=params_dtype)
else:
self.bias_tensor = torch.Tensor().to(dtype=params_dtype, device=device)
bias_tensor = torch.empty(
self.out_features,
device=device,
dtype=params_dtype,
)
# Configure parameter splits
self.weight_names = []
@@ -830,7 +836,11 @@ class Linear(TransformerEngineBaseModule):
)
self.parameter_split_sizes[i] = size // self.tp_size
# Construct parameters from weight and bias buffers
# Construct weight parameters
# Note: Register weights together so that they are adjacent to
# each other in Linear.parameters(). This makes it more likely
# that they will stay contiguous if the weights are
# manipulated externally, e.g. by FSDP.
offset = 0
for i, split_size in enumerate(self.parameter_split_sizes):
split_start = offset
@@ -846,32 +856,30 @@ class Linear(TransformerEngineBaseModule):
)
# Construct weight parameter
weight = self.weight_tensor
if is_subview:
weight = weight[split_start:split_end]
weight = torch.nn.Parameter(weight)
self.register_parameter(self.weight_names[i], weight,
init_fn=init_method,
get_rng_state_tracker=get_rng_state_tracker,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT)
# Construct bias parameter if needed
if self.use_bias:
bias = self.bias_tensor
if is_subview:
bias = bias[split_start:split_end]
bias = torch.nn.Parameter(bias)
self.register_parameter(self.bias_names[i], bias,
init_fn=init_method_constant(0.0))
else:
bias = torch.Tensor().to(dtype=params_dtype, device=device)
setattr(self, self.bias_names[i], bias)
self.register_parameter(
self.weight_names[i],
torch.nn.Parameter(weight_tensor[split_start:split_end]),
init_fn=init_method,
get_rng_state_tracker=get_rng_state_tracker,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)
# Concatenated tensors are not needed if not splitting
# into multiple parameters
if not is_subview:
del self.weight_tensor
del self.bias_tensor
# Construct bias parameters if needed
if self.use_bias:
offset = 0
for i, split_size in enumerate(self.parameter_split_sizes):
split_start = offset
offset += split_size
split_end = offset
self.register_parameter(
self.bias_names[i],
torch.nn.Parameter(bias_tensor[split_start:split_end]),
init_fn=init_method_constant(0.0),
)
else:
for name in self.bias_names:
bias = torch.Tensor().to(dtype=params_dtype, device=device)
setattr(self, name, bias)
if self.primary_weights_in_fp8:
self.init_fp8_metadata()
@@ -974,24 +982,15 @@ class Linear(TransformerEngineBaseModule):
is_first_module_in_mha = is_first_module_in_mha and self.fp8_meta["recipe"].fp8_mha
# Get concatenated weight and bias tensors
if len(self.parameter_split_sizes) == 1:
weight_tensor = getattr(self, self.weight_names[0])
bias_tensor = getattr(self, self.bias_names[0])
elif torch.is_grad_enabled():
weight_tensor = _noop_cat(
[getattr(self, name) for name in self.weight_names],
self.weight_tensor,
weight_tensor = _noop_cat(
[getattr(self, name) for name in self.weight_names],
)
if self.use_bias:
bias_tensor = _noop_cat(
[getattr(self, name) for name in self.bias_names],
)
if self.use_bias:
bias_tensor = _noop_cat(
[getattr(self, name) for name in self.bias_names],
self.bias_tensor,
)
else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused
else:
weight_tensor = self.weight_tensor
bias_tensor = self.bias_tensor
bias_tensor = getattr(self, self.bias_names[0]) # Unused
# Fetch the fp8 weights placeholders (for linear/gemm)
weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(
......