Unverified Commit 2c996359 authored by Kirthi Shankar Sivamani, committed by GitHub

Improve PyTorch test harness (#102)



* add layernorm1p fp8 test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* combine tests for easy maintenance
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* using torch.autocast for AMP and check grad types
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add test for wgrad accumulation fusion
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* rename file
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Setup numerical tests + SAR
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add test for full activation recompute
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add tests for checkpoint load/store
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* TE vs framework numerical tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix ci
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* relax thresholds
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d1d00b3e
@@ -7,5 +7,6 @@ set -e
 
 : ${TE_PATH:=/opt/transformerengine}
 pip install pytest==6.2.5 onnxruntime==1.13.1
-pytest -v -s $TE_PATH/tests/pytorch/test_transformerengine.py $TE_PATH/tests/pytorch/test_fp8.py
+pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
+PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
 NVTE_FLASH_ATTN=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import contextlib
from typing import List, Optional
import pytest
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
)
from transformer_engine.pytorch import Linear, LayerNormLinear, TransformerLayer
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
seed = 1234
rng_str = "rng_state"
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record initial RNG state from script run.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
class ModelConfig:
def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
self.hidden_size = hidden_size
self.eps = eps
self.num_attention_heads = num_attention_heads
self.embed = embed
self.num_layers = num_layers
self.seq_len = seq_len
model_configs = {
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
}
param_types = [torch.float32, torch.float16]
if torch.cuda.is_bf16_supported():
param_types.append(torch.bfloat16)
batch_sizes = [1, 2]
all_boolean = [True, False]
def get_causal_attn_mask(sq: int) -> torch.Tensor:
return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor]) -> None:
    """Ensures two lists of tensors are bitwise equal."""
    assert len(l1) == len(l2), "Unequal number of outputs."
    for t1, t2 in zip(l1, l2):
        assert torch.equal(t1, t2), "Output mismatch."
def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float) -> None:
    """Ensures two lists of tensors are elementwise close within the given absolute tolerance."""
    assert len(l1) == len(l2), "Unequal number of outputs."
    for t1, t2 in zip(l1, l2):
        assert torch.allclose(t1, t2, atol=atol), "Outputs not close enough."
def _set_cuda_rng_state(new_state, device=-1):
"""Sets the random number generator state of the current GPU.
Argumentss:
new_state (torch.ByteTensor): The desired state
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
with a single change: the input state is not cloned. Cloning caused
major performance issues for +4 GPU cases.
"""
if hasattr(_C, "_cuda_setRNGState") and callable(_C._cuda_setRNGState):
# older PyTorch
def cb():
with device_ctx_manager(device):
_C._cuda_setRNGState(new_state)
else:
# newer PyTorch
if device == -1:
device = torch.device("cuda")
elif isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device("cuda", device)
def cb():
idx = device.index
if idx is None:
idx = torch.cuda.current_device()
default_generator = torch.cuda.default_generators[idx]
default_generator.set_state(new_state)
_lazy_call(cb)
def reset_rng_states() -> None:
# revert back to initial RNG state.
torch.set_rng_state(_cpu_rng_state)
_set_cuda_rng_state(_cuda_rng_state)
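# Illustrative sketch (not part of the original change): resetting the RNG
# states makes RNG-dependent ops reproducible across calls, which is what lets
# the recompute and checkpointing tests below compare outputs bitwise. The
# helper name is hypothetical.
def _example_rng_reset_roundtrip():
    reset_rng_states()
    first = torch.randn(4, device="cuda")
    reset_rng_states()
    second = torch.randn(4, device="cuda")
    # Identical starting state, so the two draws match exactly.
    assert torch.equal(first, second)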
class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
Using the `add` method, a cuda rng state is initialized based on
the input `seed` and is assigned to `name`. Later, by forking the
rng state, we can perform operations and return to our starting
cuda state.
"""
def __init__(self):
# Map from a string name to the cuda rng state.
self.states_ = {}
        # Seeds are just for bookkeeping and to ensure no seed is set twice.
self.seeds_ = set()
def reset(self):
"""Set to the initial state (no tracker)."""
self.states_ = {}
self.seeds_ = set()
def get_states(self):
"""Get rng states. Copy the dictionary so we have direct
pointers to the states, not just a pointer to the dictionary."""
states = {}
for name in self.states_:
states[name] = self.states_[name]
return states
def set_states(self, states):
"""Set the rng states. For efficiency purposes, we do not check
the size of seed for compatibility."""
self.states_ = states
def add(self, name, seed):
"""Track the rng state."""
# Check seed is not already used.
if seed in self.seeds_:
raise Exception("seed {} already exists".format(seed))
self.seeds_.add(seed)
# Check that state is not already defined.
if name in self.states_:
raise Exception("cuda rng state {} already exists".format(name))
# Get the current rng state.
orig_rng_state = torch.cuda.get_rng_state()
# Set the new state and store it.
torch.cuda.manual_seed(seed)
self.states_[name] = torch.cuda.get_rng_state()
# Reset rng state to what it was.
_set_cuda_rng_state(orig_rng_state)
@contextlib.contextmanager
def fork(self, name=rng_str):
"""Fork the cuda rng state, perform operations, and exit with
the original state."""
# Check if we have added the state
if name not in self.states_:
raise Exception("cuda rng state {} is not added".format(name))
# Store current rng state.
orig_cuda_rng_state = torch.cuda.get_rng_state()
# Set rng state to the desired one
_set_cuda_rng_state(self.states_[name])
# Do the stuff we wanted to do.
try:
yield
finally:
# Update the current rng state for later use.
self.states_[name] = torch.cuda.get_rng_state()
# And set the state to the original state we started with.
_set_cuda_rng_state(orig_cuda_rng_state)
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add(rng_str, seed)
def get_dummy_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
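# Illustrative usage sketch (not part of the original change): ops performed
# inside ``fork()`` consume the tracked "rng_state" stream, while the global
# CUDA RNG state is restored on exit. The helper name is hypothetical.
def _example_rng_tracker_fork():
    tracker = get_dummy_cuda_rng_tracker()
    global_state_before = torch.cuda.get_rng_state()
    with tracker.fork():
        _ = torch.randn(4, device="cuda")  # advances only the tracked stream
    # Work inside the fork leaves the global generator untouched.
    assert torch.equal(global_state_before, torch.cuda.get_rng_state())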
class TorchLayerNormLinear(nn.Module):
def __init__(self, in_features: int, out_features: int, eps: float, bias: bool = True):
super().__init__()
self.layernorm = nn.LayerNorm(in_features, eps=eps)
self.linear = nn.Linear(in_features, out_features)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(self.layernorm(x))
class TorchMHA(nn.Module):
def __init__(self, hidden_size: int, num_attention_heads: int):
super().__init__()
self.mhsa = nn.MultiheadAttention(
embed_dim=hidden_size,
num_heads=num_attention_heads,
dropout=0.1,
bias=True,
batch_first=False,
)
def forward(self, x, attn_mask=None):
return self.mhsa(x, x, x, attn_mask=attn_mask, need_weights=False)
class TorchMLP(nn.Module):
def __init__(self, hidden_size: int):
super().__init__()
self.fc1 = nn.Linear(hidden_size, 4 * hidden_size)
self.gelu = nn.GELU(approximate="tanh")
self.fc2 = nn.Linear(4 * hidden_size, hidden_size)
def forward(self, x):
return self.fc2(self.gelu(self.fc1(x)))
class TorchGPT(nn.Module):
def __init__(self, hidden_size: int, eps: float, num_attention_heads: int):
super().__init__()
self.ln_1 = nn.LayerNorm(hidden_size, eps=eps)
self.causal_attn = TorchMHA(hidden_size, num_attention_heads)
self.ln_2 = nn.LayerNorm(hidden_size, eps=eps)
self.mlp = TorchMLP(hidden_size)
self.resid_attn_dropout = nn.Dropout(0.1)
self.resid_mlp_dropout = nn.Dropout(0.1)
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
a = self.ln_1(x)
b, _ = self.causal_attn(a, attn_mask)
x = x + self.resid_attn_dropout(b)
m = self.ln_2(x)
n = self.mlp(m)
x = x + self.resid_mlp_dropout(n)
return x
def _test_e2e_selective_recompute(block, bs, dtype, config, recompute=False):
reset_rng_states()
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
te_out = block(
te_inp_hidden_states,
te_inp_attn_mask,
checkpoint_core_attention=recompute,
)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
outputs = [te_out, te_inp_hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_gpt_selective_activation_recompute(dtype, bs, model):
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
params_dtype=dtype,
)
.cuda()
.eval()
)
outputs = _test_e2e_selective_recompute(block, bs, dtype, config, recompute=False)
outputs_recompute = _test_e2e_selective_recompute(block, bs, dtype, config, recompute=True)
assert_all_equal(outputs, outputs_recompute)
def _test_e2e_full_recompute(block, bs, dtype, config, recompute=False):
reset_rng_states()
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
if recompute:
te_out = te_checkpoint(
block,
False, # distribute_saved_activations
get_dummy_cuda_rng_tracker,
None, # tp_group
te_inp_hidden_states,
te_inp_attn_mask,
checkpoint_core_attention=False,
)
else:
te_out = block(
te_inp_hidden_states,
te_inp_attn_mask,
checkpoint_core_attention=False,
)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
outputs = [te_out, te_inp_hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_gpt_full_activation_recompute(dtype, bs, model):
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
params_dtype=dtype,
)
.cuda()
.eval()
)
outputs = _test_e2e_full_recompute(block, bs, dtype, config, recompute=False)
outputs_recompute = _test_e2e_full_recompute(block, bs, dtype, config, recompute=True)
assert_all_equal(outputs, outputs_recompute)
def _test_e2e_checkpointing_get_model(config, dtype):
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
return (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
params_dtype=dtype,
)
.cuda()
.eval()
)
def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path="checkpoint.pt"):
reset_rng_states()
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
block = _test_e2e_checkpointing_get_model(config, dtype)
for _ in range(steps // 2):
te_out = block(
te_inp_hidden_states,
te_inp_attn_mask,
)
loss = te_out.sum()
loss.backward()
if checkpoint:
        # Save and then reload the model so that we start afresh with a new
        # instance, erasing all internal state, to ensure that loading from a
        # checkpoint gives bitwise identical results. Since gradients are being
        # accumulated, they must be restored after loading the checkpoint.
torch.save(block.state_dict(), path)
param_grads = []
for p in block.parameters():
if p.requires_grad:
param_grads.append(p.grad.clone())
del block
block = _test_e2e_checkpointing_get_model(config, dtype)
block.load_state_dict(torch.load(path))
for p in block.parameters():
if p.requires_grad:
p.grad = param_grads.pop(0)
        assert not param_grads, "Leftover gradients after checkpoint restore."
for _ in range(steps // 2):
te_out = block(
te_inp_hidden_states,
te_inp_attn_mask,
)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
if os.path.exists(path):
os.remove(path)
outputs = [te_out, te_inp_hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_gpt_checkpointing(dtype, bs, model):
config = model_configs[model]
outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
outputs_recompute = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
assert_all_equal(outputs, outputs_recompute)
def _test_e2e_gpt_accuracy(block, bs, dtype, config):
reset_rng_states()
inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
inp_hidden_states.retain_grad()
inp_attn_mask = get_causal_attn_mask(config.seq_len)
out = block(inp_hidden_states, inp_attn_mask)
loss = out.sum()
loss.backward()
torch.cuda.synchronize()
outputs = [out, inp_hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_gpt_accuracy(dtype, bs, model):
config = model_configs[model]
te_gpt = (
TransformerLayer(
hidden_size=config.hidden_size,
ffn_hidden_size=4 * config.hidden_size,
num_attention_heads=config.num_attention_heads,
layernorm_epsilon=config.eps,
attention_dropout=0.1,
hidden_dropout=0.1,
fuse_qkv_params=True,
qkv_weight_interleaved=False,
)
.to(dtype=dtype)
.cuda()
.eval()
)
torch_gpt = (
TorchGPT(
config.hidden_size,
config.eps,
config.num_attention_heads,
)
.to(dtype=dtype)
.cuda()
.eval()
)
# Share params
with torch.no_grad():
torch_gpt.ln_1.weight = Parameter(
te_gpt.self_attention.layernorm_qkv.layer_norm_weight.clone()
)
torch_gpt.ln_1.bias = Parameter(te_gpt.self_attention.layernorm_qkv.layer_norm_bias.clone())
torch_gpt.causal_attn.mhsa.in_proj_weight = Parameter(
te_gpt.self_attention.layernorm_qkv.weight.clone()
)
torch_gpt.causal_attn.mhsa.in_proj_bias = Parameter(
te_gpt.self_attention.layernorm_qkv.bias.clone()
)
torch_gpt.causal_attn.mhsa.out_proj.weight = Parameter(
te_gpt.self_attention.proj.weight.clone()
)
torch_gpt.causal_attn.mhsa.out_proj.bias = Parameter(
te_gpt.self_attention.proj.bias.clone()
)
torch_gpt.ln_2.weight = Parameter(te_gpt.layernorm_mlp.layer_norm_weight.clone())
torch_gpt.ln_2.bias = Parameter(te_gpt.layernorm_mlp.layer_norm_bias.clone())
torch_gpt.mlp.fc1.weight = Parameter(te_gpt.layernorm_mlp.fc1_weight.clone())
torch_gpt.mlp.fc1.bias = Parameter(te_gpt.layernorm_mlp.fc1_bias.clone())
torch_gpt.mlp.fc2.weight = Parameter(te_gpt.layernorm_mlp.fc2_weight.clone())
torch_gpt.mlp.fc2.bias = Parameter(te_gpt.layernorm_mlp.fc2_bias.clone())
te_outputs = _test_e2e_gpt_accuracy(te_gpt, bs, dtype, config)
torch_outputs = _test_e2e_gpt_accuracy(torch_gpt, bs, dtype, config)
# Check output.
if dtype == torch.float32:
assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
else:
assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)
def _test_granular_accuracy(block, bs, dtype, config):
reset_rng_states()
inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
inp_hidden_states.retain_grad()
out = block(inp_hidden_states)
loss = out.sum()
loss.backward()
torch.cuda.synchronize()
outputs = [out, inp_hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_linear_accuracy(dtype, bs, model):
config = model_configs[model]
te_linear = (
Linear(
config.hidden_size,
4 * config.hidden_size,
bias=True,
)
.to(dtype=dtype)
.cuda()
.eval()
)
torch_linear = (
torch.nn.Linear(
config.hidden_size,
4 * config.hidden_size,
bias=True,
)
.to(dtype=dtype)
.cuda()
.eval()
)
# Share params
with torch.no_grad():
torch_linear.weight = Parameter(te_linear.weight.clone())
torch_linear.bias = Parameter(te_linear.bias.clone())
te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config)
torch_outputs = _test_granular_accuracy(torch_linear, bs, dtype, config)
# Check output.
if dtype == torch.float32:
assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
else:
assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_layernorm_linear_accuracy(dtype, bs, model):
config = model_configs[model]
te_ln_linear = (
LayerNormLinear(
config.hidden_size,
4 * config.hidden_size,
config.eps,
bias=True,
)
.to(dtype=dtype)
.cuda()
.eval()
)
torch_ln_linear = (
TorchLayerNormLinear(
config.hidden_size,
4 * config.hidden_size,
config.eps,
bias=True,
)
.to(dtype=dtype)
.cuda()
.eval()
)
# Share params
with torch.no_grad():
torch_ln_linear.layernorm.weight = Parameter(te_ln_linear.layer_norm_weight.clone())
torch_ln_linear.layernorm.bias = Parameter(te_ln_linear.layer_norm_bias.clone())
torch_ln_linear.linear.weight = Parameter(te_ln_linear.weight.clone())
torch_ln_linear.linear.bias = Parameter(te_ln_linear.bias.clone())
te_outputs = _test_granular_accuracy(te_ln_linear, bs, dtype, config)
torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)
# Check output.
if dtype == torch.float32:
assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
else:
assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)
@@ -19,8 +19,7 @@ from transformer_engine.pytorch import (
 from transformer_engine.common import recipe
 
 # Only run FP8 tests on H100.
-if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
-    pytest.skip(allow_module_level=True)
+fp8_available = torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9
 
 
 def custom_amax_to_scale(
@@ -59,6 +58,7 @@ model_configs = {
 }
 
 fp8_recipes = [
+    None, # Handles non-FP8 case
     recipe.DelayedScaling(0, 1, recipe.Format.E4M3),
     recipe.DelayedScaling(0, 1, recipe.Format.HYBRID),
     recipe.DelayedScaling(
@@ -86,11 +86,13 @@ fp8_recipes = [
     ),
 ]
 
-param_types = [torch.float32, torch.bfloat16, torch.float16]
+param_types = [torch.float32, torch.float16]
+if torch.cuda.is_bf16_supported():
+    param_types.append(torch.bfloat16)
 
 batch_sizes = [1, 2]
 
-skip_wgrad = [True, False]
+all_boolean = [True, False]
 
 
 def _disable_wgrads(block):
@@ -102,6 +104,7 @@ def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     te_inp_hidden_states = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=torch.float32, requires_grad=True
     ).cuda()
+    te_inp_hidden_states.retain_grad()
     te_inp_attn_mask = (
         torch.rand(
             (
@@ -118,15 +121,63 @@ def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     if skip_wgrad:
         _disable_wgrads(block)
 
-    with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
-        with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+    use_fp8 = fp8_recipe is not None
+    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
+        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
             te_out = block(te_inp_hidden_states, te_inp_attn_mask)
             loss = te_out.sum()
-            assert te_out.dtype == dtype
             loss.backward()
     torch.cuda.synchronize()
 
+    assert te_out.dtype == dtype, "AMP wrong output type."
+    assert te_inp_hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type."
+    for name, p in block.named_parameters():
+        if p.requires_grad:
+            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."
+
+
+def _test_sanity_e2e_gradient_accumulation_fusion(block, bs, dtype, config, fp8_recipe, skip_wgrad):
+    te_inp_hidden_states = torch.randn(
+        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
+    ).cuda()
+    te_inp_attn_mask = (
+        torch.rand(
+            (
+                1,
+                1,
+                config.seq_len,
+                config.seq_len,
+            )
+        )
+        .cuda()
+        .bool()
+    )
+
+    if skip_wgrad:
+        _disable_wgrads(block)
+
+    for name, p in block.named_parameters():
+        if "layer_norm_weight" in name:
+            continue
+        elif "weight" in name and p.requires_grad:
+            p.main_grad = torch.zeros_like(p)
+
+    use_fp8 = fp8_recipe is not None
+    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+        te_out = block(te_inp_hidden_states, te_inp_attn_mask)
+        loss = te_out.sum()
+        loss.backward()
+    torch.cuda.synchronize()
+
+    for name, p in block.named_parameters():
+        if "layer_norm_weight" in name:
+            continue
+        elif "weight" in name and p.requires_grad:
+            assert (
+                p.grad is None and torch.count_nonzero(p.main_grad) > 0
+            ), "Gradient not accumulated."
+
 
 def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     te_inp_hidden_states = torch.randn(
@@ -148,7 +199,8 @@ def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     if skip_wgrad:
         _disable_wgrads(block)
 
-    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+    use_fp8 = fp8_recipe is not None
+    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
         te_out = block(te_inp_hidden_states, te_inp_attn_mask)
         loss = te_out.sum()
         loss.backward()
@@ -175,7 +227,8 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     if skip_wgrad:
         _disable_wgrads(block)
 
-    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+    use_fp8 = fp8_recipe is not None
+    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
         te_out = block(
             te_inp_hidden_states, te_inp_attn_mask, encoder_output=te_inp_hidden_states
         )
@@ -192,7 +245,8 @@ def _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     if skip_wgrad:
         _disable_wgrads(block)
 
-    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+    use_fp8 = fp8_recipe is not None
+    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
         te_out = block(te_inp)
     if isinstance(te_out, tuple):
         te_out = te_out[0]
@@ -205,8 +259,12 @@ def _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad):
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
-def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
+@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
+def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
+    if fp8_recipe is not None and not fp8_available:
+        pytest.skip("FP8 device not available.")
+
     config = model_configs[model]
     sigma = 0.023
@@ -218,6 +276,7 @@ def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
             config.hidden_size * 3,
             eps=config.eps,
             init_method=init_method,
+            zero_centered_gamma=zero_centered_gamma,
         )
         .to(dtype=dtype)
         .cuda()
@@ -229,8 +288,11 @@ def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
 def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
+    if fp8_recipe is not None and not fp8_available:
+        pytest.skip("FP8 device not available.")
+
     config = model_configs[model]
     sigma = 0.023
@@ -250,8 +312,12 @@ def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
-def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad):
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
+@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
+def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
+    if fp8_recipe is not None and not fp8_available:
+        pytest.skip("FP8 device not available.")
+
     config = model_configs[model]
     sigma = 0.023
@@ -265,6 +331,7 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad):
             eps=config.eps,
             init_method=init_method,
             output_layer_init_method=output_layer_init_method,
+            zero_centered_gamma=zero_centered_gamma,
         )
         .to(dtype=dtype)
         .cuda()
@@ -276,8 +343,12 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad):
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
-def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad):
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
+@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
+def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
+    if fp8_recipe is not None and not fp8_available:
+        pytest.skip("FP8 device not available.")
+
     config = model_configs[model]
     sigma = 0.023
@@ -297,6 +368,7 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad):
             kv_channels=config.embed,
             apply_residual_connection_post_layernorm=False,
             output_layernorm=False,
+            zero_centered_gamma=zero_centered_gamma,
         )
         .to(dtype=dtype)
         .cuda()
@@ -309,8 +381,12 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad):
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
-def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad):
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
+@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
+def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
+    if fp8_recipe is not None and not fp8_available:
+        pytest.skip("FP8 device not available.")
+
     config = model_configs[model]
     sigma = 0.023
@@ -330,6 +406,7 @@ def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad):
             kv_channels=config.embed,
             apply_residual_connection_post_layernorm=True,
             output_layernorm=True,
+            zero_centered_gamma=zero_centered_gamma,
         )
         .to(dtype=dtype)
         .cuda()
@@ -342,8 +419,12 @@ def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad):
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
-def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad):
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
+@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
+def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
+    if fp8_recipe is not None and not fp8_available:
+        pytest.skip("FP8 device not available.")
+
     config = model_configs[model]
     sigma = 0.023
@@ -364,6 +445,7 @@ def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad):
             apply_residual_connection_post_layernorm=False,
             output_layernorm=False,
             layer_type="decoder",
+            zero_centered_gamma=zero_centered_gamma,
         )
         .to(dtype=dtype)
         .cuda()
@@ -376,8 +458,11 @@ def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad):
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
 def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model, skip_wgrad):
+    if fp8_recipe is not None and not fp8_available:
+        pytest.skip("FP8 device not available.")
+
     config = model_configs[model]
     sigma = 0.023
@@ -407,8 +492,11 @@ def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model, skip_wgrad):
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
 def test_sanity_drop_path(dtype, bs, fp8_recipe, model, skip_wgrad):
+    if fp8_recipe is not None and not fp8_available:
+        pytest.skip("FP8 device not available.")
+
     config = model_configs[model]
     sigma = 0.023
@@ -441,8 +529,11 @@ def test_sanity_drop_path(dtype, bs, fp8_recipe, model, skip_wgrad):
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
 def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model, skip_wgrad):
+    if fp8_recipe is not None and not fp8_available:
+        pytest.skip("FP8 device not available.")
+
     config = model_configs[model]
     sigma = 0.023
@@ -469,3 +560,43 @@ def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model, skip_wgrad):
     ),
 
     _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
+
+
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("bs", batch_sizes)
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
+@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
+def test_sanity_gradient_accumulation_fusion(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
+    if fp8_recipe is not None and not fp8_available:
+        pytest.skip("FP8 device not available.")
+
+    config = model_configs[model]
+
+    sigma = 0.023
+    init_method = init_method_normal(sigma)
+    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
+
+    block = (
+        TransformerLayer(
+            config.hidden_size,
+            4 * config.hidden_size,
+            config.num_attention_heads,
+            layernorm_epsilon=config.eps,
+            init_method=init_method,
+            output_layer_init_method=output_layer_init_method,
+            hidden_dropout=0.1,
+            attention_dropout=0.1,
+            kv_channels=config.embed,
+            apply_residual_connection_post_layernorm=False,
+            output_layernorm=False,
+            zero_centered_gamma=zero_centered_gamma,
+            fuse_qkv_params=True,
+            fuse_wgrad_accumulation=True,
+        )
+        .to(dtype=dtype)
+        .cuda()
+    )
+
+    _test_sanity_e2e_gradient_accumulation_fusion(block, bs, dtype, config, fp8_recipe, skip_wgrad)
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
import pytest
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
)
from transformer_engine.pytorch import (
LayerNormLinear,
Linear,
LayerNormMLP,
TransformerLayer,
)
class ModelConfig:
def __init__(
self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
):
self.hidden_size = hidden_size
self.eps = eps
self.num_attention_heads = num_attention_heads
self.embed = embed
self.num_layers = num_layers
self.seq_len = seq_len
model_configs = {
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
}
param_types = [torch.float32, torch.bfloat16, torch.float16]
batch_sizes = [1, 2]
all_boolean = [True, False]
def _disable_wgrads(block):
for p in block.parameters():
p.requires_grad = False
def _test_sanity_e2e_amp(block, bs, dtype, config, skip_wgrad):
if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
return
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=torch.float32, requires_grad=True
).cuda()
te_inp_attn_mask = (
torch.rand(
(
1,
1,
config.seq_len,
config.seq_len,
)
)
.cuda()
.bool()
)
if skip_wgrad:
_disable_wgrads(block)
with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
te_out = block(te_inp_hidden_states, te_inp_attn_mask)
loss = te_out.sum()
assert te_out.dtype == dtype
loss.backward()
torch.cuda.synchronize()
def _test_sanity_e2e(block, bs, dtype, config, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_attn_mask = (
torch.rand(
(
1,
1,
config.seq_len,
config.seq_len,
)
)
.cuda()
.bool()
)
if skip_wgrad:
_disable_wgrads(block)
te_out = block(te_inp_hidden_states, te_inp_attn_mask)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
def _test_sanity_e2e_T5(block, bs, dtype, config, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_attn_mask = (
torch.rand(
(
1,
1,
config.seq_len,
config.seq_len,
)
)
.cuda()
.bool()
)
if skip_wgrad:
_disable_wgrads(block)
te_out = block(
te_inp_hidden_states, te_inp_attn_mask, encoder_output=te_inp_hidden_states
)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
def _test_sanity_common(block, bs, dtype, config, skip_wgrad):
te_inp = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
if skip_wgrad:
_disable_wgrads(block)
te_out = block(te_inp)
if isinstance(te_out, tuple):
te_out = te_out[0]
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_layernorm_linear(dtype, bs, model, skip_wgrad, zero_centered_gamma):
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
block = (
LayerNormLinear(
config.hidden_size,
config.hidden_size * 3,
eps=config.eps,
init_method=init_method,
zero_centered_gamma=zero_centered_gamma,
)
.to(dtype=dtype)
.cuda()
)
_test_sanity_common(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_linear(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
Linear(
config.hidden_size, config.hidden_size, init_method=output_layer_init_method
)
.to(dtype=dtype)
.cuda()
)
_test_sanity_common(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_layernorm_mlp(dtype, bs, model, skip_wgrad, zero_centered_gamma):
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
LayerNormMLP(
config.hidden_size,
4 * config.hidden_size,
eps=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
zero_centered_gamma=zero_centered_gamma,
)
.to(dtype=dtype)
.cuda()
)
_test_sanity_common(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_gpt(dtype, bs, model, skip_wgrad, zero_centered_gamma):
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
)
.to(dtype=dtype)
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_bert(dtype, bs, model, skip_wgrad, zero_centered_gamma):
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=True,
output_layernorm=True,
zero_centered_gamma=zero_centered_gamma,
)
.to(dtype=dtype)
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_T5(dtype, bs, model, skip_wgrad, zero_centered_gamma):
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
layer_type="decoder",
zero_centered_gamma=zero_centered_gamma,
)
.to(dtype=dtype)
.cuda()
)
_test_sanity_e2e_T5(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
)
.to(dtype=torch.float32)
.cuda()
)
_test_sanity_e2e_amp(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_drop_path(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
drop_path_rate=1.0,
)
.to(dtype=dtype)
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
fuse_qkv_params=True,
)
.to(dtype=dtype)
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config, skip_wgrad)