"git@developer.sourcefind.cn:gaoqiong/pybind11.git" did not exist on "4038542b157d8e7ffd5c6f47166d9544fea03800"
Unverified commit 2c996359, authored by Kirthi Shankar Sivamani, committed by GitHub

Improve PyTorch test harness (#102)



* add layernorm1p fp8 test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* combine tests for easy maintenance
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* use torch.autocast for AMP and check grad types (see the sketch after this message)
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add test for wgrad accumulation fusion
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* rename file
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Set up numerical tests + SAR
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add test for full activation recompute
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add tests for checkpoint load/store
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* TE vs framework numerical tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix ci
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* relax thresholds
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
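A minimal illustration of the AMP behaviour the reworked test checks, in plain PyTorch without Transformer Engine (layer sizes are arbitrary, not from the commit): under torch.autocast the layer output takes the autocast dtype, while gradients of float32 inputs and parameters remain float32.

import torch

linear = torch.nn.Linear(8, 8).cuda()                      # float32 parameters
x = torch.randn(4, 8, device="cuda", requires_grad=True)   # float32 input

with torch.autocast(device_type="cuda", dtype=torch.float16):
    y = linear(x)
    loss = y.sum()

loss.backward()
assert y.dtype == torch.float16                   # output follows the autocast dtype
assert x.grad.dtype == torch.float32              # dgrad matches the input dtype
assert linear.weight.grad.dtype == torch.float32  # wgrad matches the parameter dtype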
parent d1d00b3e
@@ -7,5 +7,6 @@ set -e
 : ${TE_PATH:=/opt/transformerengine}
 pip install pytest==6.2.5 onnxruntime==1.13.1
-pytest -v -s $TE_PATH/tests/pytorch/test_transformerengine.py $TE_PATH/tests/pytorch/test_fp8.py
+pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
+PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
 NVTE_FLASH_ATTN=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
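The new CI line runs the numerics suite with the TorchScript JIT disabled (PYTORCH_JIT=0) and Transformer Engine's nondeterministic algorithms disallowed (NVTE_ALLOW_NONDETERMINISTIC_ALGO=0), so comparisons are reproducible. As a rough sketch of the kind of TE-vs-framework check such a suite performs (module choice, parameter attribute names, and tolerances are assumptions here, not taken from test_numerics.py):

import torch
import transformer_engine.pytorch as te

torch.manual_seed(0)
x = torch.randn(128, 768, device="cuda", requires_grad=True)

# TE layer, and a plain-PyTorch reference built from the same parameters
# (assumes the layer exposes .weight and .bias attributes).
te_linear = te.Linear(768, 768, bias=True).cuda()
y_te = te_linear(x)
y_ref = torch.nn.functional.linear(x, te_linear.weight, te_linear.bias)

# Relaxed thresholds: kernels and accumulation order differ between the paths.
torch.testing.assert_close(y_te, y_ref, rtol=1e-3, atol=1e-3)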
This diff is collapsed.
@@ -19,8 +19,7 @@ from transformer_engine.pytorch import (
 from transformer_engine.common import recipe
 # Only run FP8 tests on H100.
-if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
-    pytest.skip(allow_module_level=True)
+fp8_available = torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9
 def custom_amax_to_scale(
@@ -59,6 +58,7 @@ model_configs = {
 }
 fp8_recipes = [
+    None,  # Handles non-FP8 case
     recipe.DelayedScaling(0, 1, recipe.Format.E4M3),
     recipe.DelayedScaling(0, 1, recipe.Format.HYBRID),
     recipe.DelayedScaling(
@@ -86,11 +86,13 @@ fp8_recipes = [
     ),
 ]
-param_types = [torch.float32, torch.bfloat16, torch.float16]
+param_types = [torch.float32, torch.float16]
+if torch.cuda.is_bf16_supported():
+    param_types.append(torch.bfloat16)
 batch_sizes = [1, 2]
-skip_wgrad = [True, False]
+all_boolean = [True, False]
 def _disable_wgrads(block):
@@ -102,6 +104,7 @@ def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     te_inp_hidden_states = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=torch.float32, requires_grad=True
     ).cuda()
+    te_inp_hidden_states.retain_grad()
     te_inp_attn_mask = (
         torch.rand(
             (
@@ -118,15 +121,63 @@ def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     if skip_wgrad:
         _disable_wgrads(block)
-    with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
-        with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+    use_fp8 = fp8_recipe is not None
+    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
+        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
             te_out = block(te_inp_hidden_states, te_inp_attn_mask)
             loss = te_out.sum()
-    assert te_out.dtype == dtype
     loss.backward()
     torch.cuda.synchronize()
+    assert te_out.dtype == dtype, "AMP wrong output type."
+    assert te_inp_hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type."
+    for name, p in block.named_parameters():
+        if p.requires_grad:
+            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."
+def _test_sanity_e2e_gradient_accumulation_fusion(block, bs, dtype, config, fp8_recipe, skip_wgrad):
+    te_inp_hidden_states = torch.randn(
+        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
+    ).cuda()
+    te_inp_attn_mask = (
+        torch.rand(
+            (
+                1,
+                1,
+                config.seq_len,
+                config.seq_len,
+            )
+        )
+        .cuda()
+        .bool()
+    )
+    if skip_wgrad:
+        _disable_wgrads(block)
+    for name, p in block.named_parameters():
+        if "layer_norm_weight" in name:
+            continue
+        elif "weight" in name and p.requires_grad:
+            p.main_grad = torch.zeros_like(p)
+    use_fp8 = fp8_recipe is not None
+    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+        te_out = block(te_inp_hidden_states, te_inp_attn_mask)
+        loss = te_out.sum()
+    loss.backward()
+    torch.cuda.synchronize()
+    for name, p in block.named_parameters():
+        if "layer_norm_weight" in name:
+            continue
+        elif "weight" in name and p.requires_grad:
+            assert (
+                p.grad is None and torch.count_nonzero(p.main_grad) > 0
+            ), "Gradient not accumulated."
 def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     te_inp_hidden_states = torch.randn(
@@ -148,7 +199,8 @@ def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     if skip_wgrad:
         _disable_wgrads(block)
-    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+    use_fp8 = fp8_recipe is not None
+    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
         te_out = block(te_inp_hidden_states, te_inp_attn_mask)
     loss = te_out.sum()
     loss.backward()
@@ -175,7 +227,8 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     if skip_wgrad:
         _disable_wgrads(block)
-    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+    use_fp8 = fp8_recipe is not None
+    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
         te_out = block(
             te_inp_hidden_states, te_inp_attn_mask, encoder_output=te_inp_hidden_states
         )
@@ -192,7 +245,8 @@ def _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     if skip_wgrad:
         _disable_wgrads(block)
-    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+    use_fp8 = fp8_recipe is not None
+    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
         te_out = block(te_inp)
     if isinstance(te_out, tuple):
         te_out = te_out[0]
@@ -205,8 +259,12 @@ def _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad):
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
-def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
+@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
+def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
+    if fp8_recipe is not None and not fp8_available:
+        pytest.skip("FP8 device not available.")
     config = model_configs[model]
     sigma = 0.023
@@ -218,6 +276,7 @@ def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
             config.hidden_size * 3,
             eps=config.eps,
             init_method=init_method,
+            zero_centered_gamma=zero_centered_gamma,
         )
         .to(dtype=dtype)
         .cuda()
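Aside (not part of the diff): the new zero_centered_gamma flag selects the "layernorm1p" formulation mentioned in the commit message, where the learnable LayerNorm scale is stored zero-centered and applied as (1 + gamma). A minimal plain-PyTorch sketch of the equivalent math, for reference only:

import torch

def layernorm_zero_centered_gamma(x, gamma, beta, eps=1e-5):
    # Same as LayerNorm, but the weight is stored as an offset from one.
    mu = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mu) / torch.sqrt(var + eps) * (1.0 + gamma) + beta

x = torch.randn(4, 768)
gamma = torch.zeros(768)  # zero-centered scale, equivalent to a conventional weight of 1
beta = torch.zeros(768)
y = layernorm_zero_centered_gamma(x, gamma, beta)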
@@ -229,8 +288,11 @@ def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
 def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
+    if fp8_recipe is not None and not fp8_available:
+        pytest.skip("FP8 device not available.")
     config = model_configs[model]
     sigma = 0.023
@@ -250,8 +312,12 @@ def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
-def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad):
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
+@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
+def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
+    if fp8_recipe is not None and not fp8_available:
+        pytest.skip("FP8 device not available.")
     config = model_configs[model]
     sigma = 0.023
@@ -265,6 +331,7 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad):
             eps=config.eps,
             init_method=init_method,
             output_layer_init_method=output_layer_init_method,
+            zero_centered_gamma=zero_centered_gamma,
         )
         .to(dtype=dtype)
         .cuda()
@@ -276,8 +343,12 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad):
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
-def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad):
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
+@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
+def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
+    if fp8_recipe is not None and not fp8_available:
+        pytest.skip("FP8 device not available.")
     config = model_configs[model]
     sigma = 0.023
@@ -297,6 +368,7 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad):
             kv_channels=config.embed,
             apply_residual_connection_post_layernorm=False,
             output_layernorm=False,
+            zero_centered_gamma=zero_centered_gamma,
         )
         .to(dtype=dtype)
         .cuda()
@@ -309,8 +381,12 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad):
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
-def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad):
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
+@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
+def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
+    if fp8_recipe is not None and not fp8_available:
+        pytest.skip("FP8 device not available.")
     config = model_configs[model]
     sigma = 0.023
@@ -330,6 +406,7 @@ def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad):
             kv_channels=config.embed,
             apply_residual_connection_post_layernorm=True,
             output_layernorm=True,
+            zero_centered_gamma=zero_centered_gamma,
         )
         .to(dtype=dtype)
         .cuda()
@@ -342,8 +419,12 @@ def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad):
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
-def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad):
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
+@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
+def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
+    if fp8_recipe is not None and not fp8_available:
+        pytest.skip("FP8 device not available.")
     config = model_configs[model]
     sigma = 0.023
@@ -364,6 +445,7 @@ def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad):
             apply_residual_connection_post_layernorm=False,
             output_layernorm=False,
             layer_type="decoder",
+            zero_centered_gamma=zero_centered_gamma,
         )
         .to(dtype=dtype)
         .cuda()
@@ -376,8 +458,11 @@ def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad):
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
 def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model, skip_wgrad):
+    if fp8_recipe is not None and not fp8_available:
+        pytest.skip("FP8 device not available.")
     config = model_configs[model]
     sigma = 0.023
@@ -407,8 +492,11 @@ def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model, skip_wgrad):
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
 def test_sanity_drop_path(dtype, bs, fp8_recipe, model, skip_wgrad):
+    if fp8_recipe is not None and not fp8_available:
+        pytest.skip("FP8 device not available.")
     config = model_configs[model]
     sigma = 0.023
@@ -441,8 +529,11 @@ def test_sanity_drop_path(dtype, bs, fp8_recipe, model, skip_wgrad):
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
 def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model, skip_wgrad):
+    if fp8_recipe is not None and not fp8_available:
+        pytest.skip("FP8 device not available.")
     config = model_configs[model]
     sigma = 0.023
@@ -469,3 +560,43 @@ def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model, skip_wgrad):
     )
     _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("bs", batch_sizes)
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
+@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
+def test_sanity_gradient_accumulation_fusion(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
+    if fp8_recipe is not None and not fp8_available:
+        pytest.skip("FP8 device not available.")
+    config = model_configs[model]
+    sigma = 0.023
+    init_method = init_method_normal(sigma)
+    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
+    block = (
+        TransformerLayer(
+            config.hidden_size,
+            4 * config.hidden_size,
+            config.num_attention_heads,
+            layernorm_epsilon=config.eps,
+            init_method=init_method,
+            output_layer_init_method=output_layer_init_method,
+            hidden_dropout=0.1,
+            attention_dropout=0.1,
+            kv_channels=config.embed,
+            apply_residual_connection_post_layernorm=False,
+            output_layernorm=False,
+            zero_centered_gamma=zero_centered_gamma,
+            fuse_qkv_params=True,
+            fuse_wgrad_accumulation=True,
+        )
+        .to(dtype=dtype)
+        .cuda()
+    )
+    _test_sanity_e2e_gradient_accumulation_fusion(block, bs, dtype, config, fp8_recipe, skip_wgrad)
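Background on the fused path this new test exercises (an aside, not from the commit): with fuse_wgrad_accumulation=True the weight gradient is accumulated directly into each parameter's persistent main_grad buffer, so p.grad stays None, which is exactly what _test_sanity_e2e_gradient_accumulation_fusion asserts. A rough sketch of the unfused equivalent in plain PyTorch (names and shapes are illustrative only):

import torch

# Persistent FP32 accumulation buffer, as a Megatron-style optimizer would keep.
weight = torch.nn.Parameter(torch.randn(16, 16, device="cuda"))
weight.main_grad = torch.zeros_like(weight, dtype=torch.float32)

x = torch.randn(4, 16, device="cuda")
out = torch.nn.functional.linear(x, weight)
out.sum().backward()

weight.main_grad += weight.grad.float()  # unfused: explicit accumulation into main_grad
weight.grad = None                       # fused: this intermediate .grad never materializes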
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
import pytest
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
)
from transformer_engine.pytorch import (
    LayerNormLinear,
    Linear,
    LayerNormMLP,
    TransformerLayer,
)
class ModelConfig:
    def __init__(
        self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
    ):
        self.hidden_size = hidden_size
        self.eps = eps
        self.num_attention_heads = num_attention_heads
        self.embed = embed
        self.num_layers = num_layers
        self.seq_len = seq_len

model_configs = {
    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
}
param_types = [torch.float32, torch.bfloat16, torch.float16]
batch_sizes = [1, 2]
all_boolean = [True, False]
def _disable_wgrads(block):
    for p in block.parameters():
        p.requires_grad = False
def _test_sanity_e2e_amp(block, bs, dtype, config, skip_wgrad):
    if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
        return

    te_inp_hidden_states = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=torch.float32, requires_grad=True
    ).cuda()
    te_inp_attn_mask = (
        torch.rand(
            (
                1,
                1,
                config.seq_len,
                config.seq_len,
            )
        )
        .cuda()
        .bool()
    )

    if skip_wgrad:
        _disable_wgrads(block)

    with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
        te_out = block(te_inp_hidden_states, te_inp_attn_mask)
        loss = te_out.sum()

    assert te_out.dtype == dtype
    loss.backward()
    torch.cuda.synchronize()
def _test_sanity_e2e(block, bs, dtype, config, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    te_inp_attn_mask = (
        torch.rand(
            (
                1,
                1,
                config.seq_len,
                config.seq_len,
            )
        )
        .cuda()
        .bool()
    )

    if skip_wgrad:
        _disable_wgrads(block)

    te_out = block(te_inp_hidden_states, te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()
def _test_sanity_e2e_T5(block, bs, dtype, config, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    te_inp_attn_mask = (
        torch.rand(
            (
                1,
                1,
                config.seq_len,
                config.seq_len,
            )
        )
        .cuda()
        .bool()
    )

    if skip_wgrad:
        _disable_wgrads(block)

    te_out = block(
        te_inp_hidden_states, te_inp_attn_mask, encoder_output=te_inp_hidden_states
    )
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()
def _test_sanity_common(block, bs, dtype, config, skip_wgrad):
    te_inp = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()

    if skip_wgrad:
        _disable_wgrads(block)

    te_out = block(te_inp)
    if isinstance(te_out, tuple):
        te_out = te_out[0]
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_layernorm_linear(dtype, bs, model, skip_wgrad, zero_centered_gamma):
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
block = (
LayerNormLinear(
config.hidden_size,
config.hidden_size * 3,
eps=config.eps,
init_method=init_method,
zero_centered_gamma=zero_centered_gamma,
)
.to(dtype=dtype)
.cuda()
)
_test_sanity_common(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_linear(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
Linear(
config.hidden_size, config.hidden_size, init_method=output_layer_init_method
)
.to(dtype=dtype)
.cuda()
)
_test_sanity_common(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_layernorm_mlp(dtype, bs, model, skip_wgrad, zero_centered_gamma):
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
LayerNormMLP(
config.hidden_size,
4 * config.hidden_size,
eps=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
zero_centered_gamma=zero_centered_gamma,
)
.to(dtype=dtype)
.cuda()
)
_test_sanity_common(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_gpt(dtype, bs, model, skip_wgrad, zero_centered_gamma):
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
)
.to(dtype=dtype)
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_bert(dtype, bs, model, skip_wgrad, zero_centered_gamma):
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=True,
output_layernorm=True,
zero_centered_gamma=zero_centered_gamma,
)
.to(dtype=dtype)
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_T5(dtype, bs, model, skip_wgrad, zero_centered_gamma):
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
layer_type="decoder",
zero_centered_gamma=zero_centered_gamma,
)
.to(dtype=dtype)
.cuda()
)
_test_sanity_e2e_T5(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
)
.to(dtype=torch.float32)
.cuda()
)
_test_sanity_e2e_amp(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_drop_path(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
drop_path_rate=1.0,
)
.to(dtype=dtype)
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
fuse_qkv_params=True,
)
.to(dtype=dtype)
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config, skip_wgrad)