Unverified commit e2a75314 authored by Tim Moon, committed by GitHub

[PyTorch] Reduce size of sanity tests (#510)



* Reduce size of PyT sanity tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add test cases with 126M model and weird dimensions
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add missing arg in GPT 126M test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 1bb8b6eb
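
Before the diff, a quick orientation: the hand-written ModelConfig class is replaced by a dataclass that now carries batch_size (previously a separate bs test parameter) and an optional kv_channels, and gains an is_fp8_supported() helper so that configurations with odd dimensions skip the FP8 cases. A minimal, self-contained sketch of that configuration logic, with field order and values taken from the diff below (the divisibility-by-16 comments are an inference from the check itself):

```python
# Sketch assembled from the diff below; values are taken verbatim from it.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelConfig:
    """Transformer model configuration used by the sanity tests."""
    num_layers: int
    seq_len: int
    batch_size: int
    hidden_size: int
    num_attention_heads: int
    kv_channels: Optional[int] = None

    def is_fp8_supported(self):
        # FP8 cases only run when the flattened token count and the hidden
        # size are divisible by 16; otherwise the test skips the FP8 recipe.
        if self.seq_len * self.batch_size % 16:
            return False
        if self.hidden_size % 16:
            return False
        return True


model_configs = {
    "126m": ModelConfig(12, 2048, 2, 768, 12),
    "small": ModelConfig(2, 32, 2, 64, 2),
    "weird": ModelConfig(2, 37, 3, 69, 3),
}

assert model_configs["126m"].is_fp8_supported()       # 2048 * 2 and 768 are multiples of 16
assert model_configs["small"].is_fp8_supported()      # 32 * 2 and 64 are multiples of 16
assert not model_configs["weird"].is_fp8_supported()  # 37 * 3 = 111 is not
```

The "weird" entry deliberately uses shapes that are not multiples of 16, so it exercises only the non-FP8 code paths.
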
@@ -2,6 +2,9 @@
#
# See LICENSE for license information.
from dataclasses import dataclass
from typing import Optional
import torch
import pytest
@@ -42,21 +45,28 @@ def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
"""Custom func to test recipe."""
return torch.min(amax_history, dim=0).values
@dataclass
class ModelConfig:
def __init__(
self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
):
self.hidden_size = hidden_size
self.eps = eps
self.num_attention_heads = num_attention_heads
self.embed = embed
self.num_layers = num_layers
self.seq_len = seq_len
"""Transformer model configuration"""
num_layers: int
seq_len: int
batch_size: int
hidden_size: int
num_attention_heads: int
kv_channels: Optional[int] = None
def is_fp8_supported(self):
if self.seq_len * self.batch_size % 16:
return False
if self.hidden_size % 16:
return False
return True
model_configs = {
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
"126m": ModelConfig(12, 2048, 2, 768, 12),
"small": ModelConfig(2, 32, 2, 64, 2),
"weird": ModelConfig(2, 37, 3, 69, 3),
}
fp8_recipes = [
@@ -92,8 +102,6 @@ param_types = [torch.float32, torch.float16]
if torch.cuda.is_bf16_supported():
param_types.append(torch.bfloat16)
batch_sizes = [1, 2]
all_boolean = [True, False]
all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
@@ -104,14 +112,14 @@ def _disable_wgrads(block):
p.requires_grad = False
def _test_sanity_e2e_cuda_graph(block, bs, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
# Initialize loss function and optimizer.
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(block.parameters(), lr=0.1)
# Placeholders used for capture.
static_input = torch.randn(config.seq_len, bs, config.hidden_size, device='cuda', dtype=dtype, requires_grad=True)
static_target = torch.randn(config.seq_len, bs, config.hidden_size, device='cuda', dtype=dtype)
static_input = torch.randn(config.seq_len, config.batch_size, config.hidden_size, device='cuda', dtype=dtype, requires_grad=True)
static_target = torch.randn(config.seq_len, config.batch_size, config.hidden_size, device='cuda', dtype=dtype)
real_input = torch.rand_like(static_input)
real_target = torch.rand_like(static_target)
@@ -152,9 +160,9 @@ def _test_sanity_e2e_cuda_graph(block, bs, dtype, config, fp8_recipe, skip_wgrad
torch.cuda.synchronize()
def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=torch.float32, requires_grad=True
config.seq_len, config.batch_size, config.hidden_size, dtype=torch.float32, requires_grad=True
).cuda()
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()
@@ -178,9 +186,9 @@ def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad):
assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."
def _test_sanity_e2e_gradient_accumulation_fusion(block, bs, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()
@@ -207,9 +215,9 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, bs, dtype, config, fp8_
assert torch.count_nonzero(p.main_grad) > 0, "Gradient not accumulated."
def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
if skip_wgrad:
@@ -223,12 +231,12 @@ def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad):
torch.cuda.synchronize()
def _test_sanity_e2e_bert(block, bs, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_attn_mask = torch.rand(torch.Size([bs, 1, 1, config.seq_len])).cuda() > 0.5
te_inp_attn_mask = torch.rand(torch.Size([config.batch_size, 1, 1, config.seq_len])).cuda() > 0.5
if skip_wgrad:
_disable_wgrads(block)
@@ -241,12 +249,12 @@ def _test_sanity_e2e_bert(block, bs, dtype, config, fp8_recipe, skip_wgrad):
torch.cuda.synchronize()
def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()
enc_dec_attn_mask = torch.rand(torch.Size([bs, 1, 1, config.seq_len])).cuda() > 0.5
enc_dec_attn_mask = torch.rand(torch.Size([config.batch_size, 1, 1, config.seq_len])).cuda() > 0.5
if skip_wgrad:
_disable_wgrads(block)
@@ -264,12 +272,12 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad):
torch.cuda.synchronize()
def _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad):
def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad):
if skip_dgrad and skip_wgrad:
pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
te_inp = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=not skip_dgrad
config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=not skip_dgrad
).cuda()
if skip_wgrad:
@@ -285,12 +293,12 @@ def _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad, skip_d
torch.cuda.synchronize()
def _test_sanity_normalization_amp(block, bs, dtype, config, skip_wgrad, skip_dgrad):
def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
if skip_dgrad and skip_wgrad:
pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
te_inp = torch.randn(
config.seq_len, bs, config.hidden_size, requires_grad=True
config.seq_len, config.batch_size, config.hidden_size, requires_grad=True
).cuda()
te_inp.retain_grad()
@@ -309,45 +317,43 @@ def _test_sanity_normalization_amp(block, bs, dtype, config, skip_wgrad, skip_dg
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, bs, model, skip_wgrad, skip_dgrad, normalization):
def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normalization):
config = model_configs[model]
module = RMSNorm if normalization == "RMSNorm" else LayerNorm
block = (
module(
config.hidden_size,
eps=config.eps,
)
module(config.hidden_size)
.to(dtype=torch.float32)
.cuda()
)
_test_sanity_normalization_amp(block, bs, dtype, config, skip_wgrad, skip_dgrad)
_test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad,
def test_sanity_layernorm_linear(dtype, fp8_recipe, model, skip_wgrad,
zero_centered_gamma, skip_dgrad,
normalization):
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if normalization == "RMSNorm" and zero_centered_gamma:
pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
@@ -355,7 +361,6 @@ def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad,
LayerNormLinear(
config.hidden_size,
config.hidden_size * 3,
eps=config.eps,
init_method=init_method,
zero_centered_gamma=zero_centered_gamma,
normalization=normalization,
@@ -363,21 +368,23 @@ def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad,
.to(dtype=dtype)
.cuda()
)
_test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad, skip_dgrad):
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
@@ -388,29 +395,31 @@ def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad, skip_dgrad):
.to(dtype=dtype)
.cuda()
)
_test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad,
def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad,
zero_centered_gamma, skip_dgrad, activation,
normalization):
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if normalization == "RMSNorm" and zero_centered_gamma:
pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
@@ -419,7 +428,6 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad,
LayerNormMLP(
config.hidden_size,
4 * config.hidden_size,
eps=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
zero_centered_gamma=zero_centered_gamma,
@@ -429,30 +437,32 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad,
.to(dtype=dtype)
.cuda()
)
_test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad,
def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad,
zero_centered_gamma, bias, activation,
normalization, parallel_attention_mlp):
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if normalization == "RMSNorm" and zero_centered_gamma:
pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
@@ -462,12 +472,11 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad,
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
@@ -480,26 +489,51 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad,
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
def test_sanity_gpt_126m():
fp8_recipe = None
if fp8_available:
fp8_recipe = recipe.DelayedScaling(
0,
1,
recipe.Format.E4M3,
amax_history_len=16,
amax_compute_algo="most_recent",
)
test_sanity_gpt(
dtype=param_types[-1],
fp8_recipe=fp8_recipe,
model="126m",
skip_wgrad=False,
zero_centered_gamma=True,
bias=True,
activation="gelu",
normalization="LayerNorm",
parallel_attention_mlp=False,
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
normalization):
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if normalization == "RMSNorm" and zero_centered_gamma:
pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
@@ -509,12 +543,11 @@ def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gam
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=True,
output_layernorm=True,
zero_centered_gamma=zero_centered_gamma,
@@ -525,26 +558,46 @@ def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gam
.cuda()
)
_test_sanity_e2e_bert(block, bs, dtype, config, fp8_recipe, skip_wgrad)
_test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad)
def test_sanity_bert_126m():
fp8_recipe = recipe.DelayedScaling(
0,
1,
recipe.Format.E4M3,
amax_history_len=1,
amax_compute_algo="most_recent",
)
test_sanity_bert(
dtype=param_types[-1],
fp8_recipe=fp8_recipe,
model="126m",
skip_wgrad=False,
zero_centered_gamma=False,
normalization="LayerNorm",
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
normalization):
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if normalization == "RMSNorm" and zero_centered_gamma:
pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
@@ -554,12 +607,11 @@ def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
layer_type="decoder",
@@ -570,20 +622,40 @@ def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma
.cuda()
)
_test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad)
_test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad)
def test_sanity_T5_126m():
fp8_recipe = recipe.DelayedScaling(
0,
1,
recipe.Format.E4M3,
amax_history_len=1,
amax_compute_algo="most_recent",
)
test_sanity_T5(
dtype=param_types[-1],
fp8_recipe=fp8_recipe,
model="126m",
skip_wgrad=False,
zero_centered_gamma=False,
normalization="LayerNorm",
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model, skip_wgrad):
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
@@ -593,31 +665,32 @@ def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model, skip_wgrad):
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
kv_channels=config.kv_channels,
)
.to(dtype=torch.float32)
.cuda()
)
_test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad)
_test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_drop_path(dtype, bs, fp8_recipe, model, skip_wgrad):
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
@@ -627,12 +700,11 @@ def test_sanity_drop_path(dtype, bs, fp8_recipe, model, skip_wgrad):
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
drop_path_rate=1.0,
@@ -641,20 +713,22 @@ def test_sanity_drop_path(dtype, bs, fp8_recipe, model, skip_wgrad):
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model, skip_wgrad):
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
@@ -664,12 +738,11 @@ def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model, skip_wgrad):
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
fuse_qkv_params=True,
@@ -678,21 +751,23 @@ def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model, skip_wgrad):
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_gradient_accumulation_fusion(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
@@ -702,12 +777,11 @@ def test_sanity_gradient_accumulation_fusion(dtype, bs, fp8_recipe, model, skip_
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
@@ -718,26 +792,28 @@ def test_sanity_gradient_accumulation_fusion(dtype, bs, fp8_recipe, model, skip_
.cuda()
)
_test_sanity_e2e_gradient_accumulation_fusion(block, bs, dtype, config, fp8_recipe, skip_wgrad)
_test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_gpt_cuda_graph(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
normalization):
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if normalization == "RMSNorm" and zero_centered_gamma:
pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
@@ -747,12 +823,11 @@ def test_gpt_cuda_graph(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
@@ -763,7 +838,7 @@ def test_gpt_cuda_graph(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_
.cuda()
)
_test_sanity_e2e_cuda_graph(block, bs, dtype, config, fp8_recipe, skip_wgrad)
_test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad)
def test_model_multiple_cast():
a = torch.zeros((16,16)).cuda()
......
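
Two changes account for most of the reduction in test count: the separate bs parametrization (batch_sizes = [1, 2]) is removed because the batch size now lives in ModelConfig, and the large 126M configuration is exercised through the standalone test_sanity_gpt_126m, test_sanity_bert_126m, and test_sanity_T5_126m wrappers instead of the full parameter grid. A rough back-of-the-envelope sketch of the effect for one test, test_sanity_bert; the counts marked as assumed are not fully visible in this diff:

```python
# Illustrative only, not part of the commit. Counts marked "assumed" are guesses
# for parameter lists whose contents are not shown in this diff.
dtypes = 3               # param_types: float32, float16, plus bfloat16 when supported
fp8_recipe_count = 2     # assumed length of fp8_recipes
skip_wgrad = 2           # all_boolean
zero_centered_gamma = 2  # all_boolean
normalizations = 2       # assumed: LayerNorm and RMSNorm
batch_sizes = 2          # the old bs parametrization removed by this commit

per_model_before = (dtypes * batch_sizes * fp8_recipe_count
                    * skip_wgrad * zero_centered_gamma * normalizations)
per_model_after = (dtypes * fp8_recipe_count
                   * skip_wgrad * zero_centered_gamma * normalizations)

print(per_model_before)     # 96 grid cases per model config before the change
print(per_model_after + 1)  # 49 cases after: the "small" grid plus one 126M wrapper case
```

Several of these combinations still skip at runtime (no FP8 support, RMSNorm with zero_centered_gamma), so executed counts are lower; the sketch only shows the order-of-magnitude effect of dropping bs and pulling 126M out of the grid.
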