Unverified Commit 9416519d authored by Kirthi Shankar Sivamani, committed by GitHub

Apply formatting (#929)



* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d99142a0
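
The commit is a pure reformatting pass; it does not say which tool produced the new style. A minimal sketch of how such a pass could be reproduced, assuming the formatter is black with a 100-character line limit (both are assumptions, not stated in the commit):

import black  # hypothetical reproduction of the reformatting, not part of the commit

src = "dtypes={'fp16' : torch.float16, 'bf16' : torch.bfloat16}"
print(black.format_str(src, mode=black.FileMode(line_length=100)))
# -> dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16}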
@@ -9,9 +9,10 @@ from transformer_engine.pytorch.attention import DotProductAttention
import transformer_engine_torch as tex
from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
dtypes={'fp16' : torch.float16, 'bf16' : torch.bfloat16}
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16}
def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend='FlashAttention'):
def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention"):
"""Test DotProductAttention module with context parallelism"""
os.environ["NVTE_FLASH_ATTN"] = "0"
@@ -22,11 +23,13 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
if kernel_backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_fused_attn[model]
if qkv_format == 'thd' and (config.num_heads != config.num_gqa_groups or config.attn_bias_type == "post_scale_bias"):
if qkv_format == "thd" and (
config.num_heads != config.num_gqa_groups or config.attn_bias_type == "post_scale_bias"
):
return
rank = int(os.getenv('RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
if dist.is_initialized():
world_size = dist.get_world_size()
@@ -38,51 +41,76 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
print(f"[INFO] world_size:{world_size}, rank:{rank}")
dist.init_process_group(backend='nccl', world_size=world_size, rank=rank)
dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
# create flash attn comm group for CP
cp_comm_ranks = range(world_size)
assert(rank in cp_comm_ranks)
cp_comm_group = dist.new_group(cp_comm_ranks, backend='nccl')
assert rank in cp_comm_ranks
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
assert config.attn_mask_type in ['causal', 'no_mask'], f"{config.attn_mask_type} is an unsupported attention mask type!"
assert config.attn_mask_type in [
"causal",
"no_mask",
], f"{config.attn_mask_type} is an unsupported attention mask type!"
if kernel_backend == 'FusedAttention' and qkv_format == 'thd':
if 'causal' in config.attn_mask_type:
config.attn_mask_type = 'padding_causal'
if kernel_backend == "FusedAttention" and qkv_format == "thd":
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
config.attn_mask_type = 'padding'
config.attn_mask_type = "padding"
# instantiate core attn module
core_attn = DotProductAttention(config.num_heads,
config.head_dim,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=config.attn_mask_type)
core_attn = DotProductAttention(
config.num_heads,
config.head_dim,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=config.attn_mask_type,
)
core_attn = core_attn.cuda()
# create flash attn inputs
if qkv_format == "bshd":
q_input_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim)
kv_input_shape = (config.batch_size, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim)
attn_output_shape = (config.batch_size, config.max_seqlen_q, config.num_heads*config.head_dim)
kv_input_shape = (
config.batch_size,
config.max_seqlen_kv,
config.num_gqa_groups,
config.head_dim,
)
attn_output_shape = (
config.batch_size,
config.max_seqlen_q,
config.num_heads * config.head_dim,
)
cu_seqlens_q = None
cu_seqlens_kv = None
elif qkv_format == "sbhd":
q_input_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim)
kv_input_shape = (config.max_seqlen_kv, config.batch_size, config.num_gqa_groups, config.head_dim)
attn_output_shape = (config.max_seqlen_q, config.batch_size, config.num_heads*config.head_dim)
kv_input_shape = (
config.max_seqlen_kv,
config.batch_size,
config.num_gqa_groups,
config.head_dim,
)
attn_output_shape = (
config.max_seqlen_q,
config.batch_size,
config.num_heads * config.head_dim,
)
cu_seqlens_q = None
cu_seqlens_kv = None
elif qkv_format == "thd":
seqlens_q = torch.randint(world_size * 2, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
seqlens_q = torch.randint(world_size * 2, config.max_seqlen_q + 1, [config.batch_size]).to(
torch.int32
)
seqlens_q = seqlens_q - seqlens_q % (world_size * 2)
cu_seqlens_q = torch.cat([torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0)])
cu_seqlens_kv = cu_seqlens_q
q_input_shape = (cu_seqlens_q[-1], config.num_heads, config.head_dim)
kv_input_shape = (cu_seqlens_kv[-1], config.num_gqa_groups, config.head_dim)
attn_output_shape = (cu_seqlens_q[-1], config.num_heads*config.head_dim)
attn_output_shape = (cu_seqlens_q[-1], config.num_heads * config.head_dim)
cu_seqlens_q = cu_seqlens_q.to(torch.int32).cuda()
cu_seqlens_kv = cu_seqlens_kv.to(torch.int32).cuda()
else:
@@ -111,7 +139,9 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
for x in [q, k, v]:
x.requires_grad = True
out = core_attn(
q, k, v,
q,
k,
v,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
cu_seqlens_q=cu_seqlens_q,
@@ -120,17 +150,28 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
out.backward(dout)
# run core_attn with CP
q_, k_, v_, dout_, *rest = [x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])]
q_, k_, v_, dout_, *rest = [
x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])
]
bias_ = rest[0] if len(rest) else None
if qkv_format == "bshd" or qkv_format == "sbhd":
seq_dim = qkv_format.index('s')
q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
for x in [q_, k_, v_, dout_]]
seq_idx = torch.tensor([rank, 2*world_size-rank-1], device=q_.device)
seq_dim = qkv_format.index("s")
q_, k_, v_, dout_ = [
x.view(
*x.shape[:seq_dim],
2 * world_size,
x.shape[seq_dim] // (2 * world_size),
*x.shape[(seq_dim + 1) :],
)
for x in [q_, k_, v_, dout_]
]
seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=q_.device)
q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim+2):]) for x in [q_, k_, v_, dout_]]
q_, k_, v_, dout_ = [
x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) for x in [q_, k_, v_, dout_]
]
elif qkv_format == "thd":
seq_idx_q = tex.thd_get_partitioned_indices(cu_seqlens_q, q_.size(0), world_size, rank)
seq_idx_q = tex.thd_get_partitioned_indices(cu_seqlens_q, q_.size(0), world_size, rank)
seq_idx_kv = tex.thd_get_partitioned_indices(cu_seqlens_kv, k_.size(0), world_size, rank)
q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]]
k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
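# Standalone sketch (not part of this diff) of the load-balanced sequence split
# used above: with a causal mask, pairing chunk `rank` with chunk
# `2*world_size - 1 - rank` gives every rank a similar amount of work.
# world_size, rank and the tensor sizes below are assumed toy values.
import torch

world_size, rank = 4, 1
b, s, h, d = 2, 16, 3, 8                       # s must be divisible by 2 * world_size
x = torch.randn(b, s, h, d)

seq_dim = 1                                    # "bshd" layout: sequence is dim 1
x_chunks = x.view(b, 2 * world_size, s // (2 * world_size), h, d)
seq_idx = torch.tensor([rank, 2 * world_size - rank - 1])
x_local = x_chunks.index_select(seq_dim, seq_idx).view(b, -1, h, d)
print(x_local.shape)                           # torch.Size([2, 4, 3, 8])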
@@ -140,14 +181,18 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
assert False, f"{qkv_format} is an unsupported qkv_format!"
q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
if bias_ is not None:
bias_ = bias_.view(*bias_.shape[:-2], 2*world_size, bias_.shape[-2]//(2*world_size), bias_.shape[-1])
bias_ = bias_.view(
*bias_.shape[:-2], 2 * world_size, bias_.shape[-2] // (2 * world_size), bias_.shape[-1]
)
bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream())
max_seqlen_q = config.max_seqlen_q
max_seqlen_q = config.max_seqlen_q
max_seqlen_kv = config.max_seqlen_kv
out_ = core_attn(
q_, k_, v_,
q_,
k_,
v_,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias_,
cu_seqlens_q=cu_seqlens_q,
@@ -158,23 +203,32 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
out_.backward(dout_)
for x in [out_, q_.grad, k_.grad, v_.grad]:
assert(torch.all(~torch.isnan(x)))
assert(torch.all(~torch.isinf(x)))
assert torch.all(~torch.isnan(x))
assert torch.all(~torch.isinf(x))
# compare results with and without CP
tols = dict(atol=5e-3, rtol=5e-3)
if dtype == 'bf16':
if dtype == "bf16":
if config.num_heads == config.num_gqa_groups:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
else:
tols = dict(atol=3.5e-2, rtol=3.5e-2)
if qkv_format == "bshd" or qkv_format == "sbhd":
dq, dk, dv, out = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
for x in [q.grad, k.grad, v.grad, out]]
dq, dk, dv, out = [
x.view(
*x.shape[:seq_dim],
2 * world_size,
x.shape[seq_dim] // (2 * world_size),
*x.shape[(seq_dim + 1) :],
)
for x in [q.grad, k.grad, v.grad, out]
]
dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]]
dq_, dk_, dv_, out_ = [x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim]//2, *x.shape[(seq_dim+1):]) \
for x in [q_.grad, k_.grad, v_.grad, out_]]
dq_, dk_, dv_, out_ = [
x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :])
for x in [q_.grad, k_.grad, v_.grad, out_]
]
elif qkv_format == "thd":
dq, out = [x.index_select(0, seq_idx_q).contiguous().view(-1) for x in [q.grad, out]]
dk, dv = [x.index_select(0, seq_idx_kv).contiguous().view(-1) for x in [k.grad, v.grad]]
@@ -208,9 +262,11 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
def main(**kwargs):
run_dpa_with_cp(**kwargs)
if __name__ == "__main__":
kwargs = dict(arg.split('=') for arg in sys.argv[2:])
kwargs = dict(arg.split("=") for arg in sys.argv[2:])
main(**kwargs)
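# Standalone usage sketch (not part of this diff): the script takes trailing
# "key=value" arguments and skips sys.argv[1]; the file and model names below
# are assumed examples, not taken from the commit.
argv = ["run_dpa_with_cp.py", "unused", "dtype=bf16", "model=cp_1_0", "qkv_format=thd"]
kwargs = dict(arg.split("=") for arg in argv[2:])
print(kwargs)   # {'dtype': 'bf16', 'model': 'cp_1_0', 'qkv_format': 'thd'}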
@@ -91,10 +91,10 @@ class ModelConfig:
self.max_seqlen_q = max_seqlen_q
self.max_seqlen_kv = max_seqlen_kv
self.dropout_p = dropout_p
self.attn_mask_type = attn_mask_type
self.attn_bias_type = attn_bias_type
self.alibi_type = alibi_type
self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
self.attn_mask_type = attn_mask_type
self.attn_bias_type = attn_bias_type
self.alibi_type = alibi_type
self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
self.num_layers = num_layers
self.bias_shape = bias_shape
@@ -184,28 +184,29 @@ def _is_flash_attention_supported(config: ModelConfig) -> bool:
return False
return True
def _is_unfused_attention_supported(
config: ModelConfig,
qkv_format: str,
) -> bool:
) -> bool:
"""Check if UnfusedDotProductAttention supports a model configuration"""
if ("padding" in config.attn_mask_type):
if "padding" in config.attn_mask_type:
return False
if ("causal" in config.attn_mask_type and config.attn_type == 'cross'):
if "causal" in config.attn_mask_type and config.attn_type == "cross":
return False
if qkv_format == 'thd':
if qkv_format == "thd":
return False
return True
model_configs_base = {
# test: b, h, hg, d, sq, skv, p, mask, bias # attn , backend
"base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0
"base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), # cross, 0
"base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), # self , 1
"base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1
"base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
"base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
"base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0
"base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), # cross, 0
"base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), # self , 1
"base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1
"base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
"base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
}
@@ -221,13 +222,13 @@ def get_swa(seq_q, seq_kv, w=None):
if w is None:
w = torch.randint(0, seq_kv, [2], dtype=torch.int32, device="cuda")
m = torch.ones(seq_q, seq_kv, dtype=torch.bool, device="cuda")
mu = torch.triu(m, diagonal=seq_kv-seq_q-w[0])
ml = torch.tril(mu, diagonal=seq_kv-seq_q+w[1])
ml = ~ ml
mu = torch.triu(m, diagonal=seq_kv - seq_q - w[0])
ml = torch.tril(mu, diagonal=seq_kv - seq_q + w[1])
ml = ~ml
return w, ml
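# Standalone sketch (not part of this diff) of the mask built by get_swa above,
# with a fixed window instead of a random one: True marks positions outside the
# window [i - w[0], i + w[1]], with queries aligned to the last kv tokens.
import torch

seq_q, seq_kv = 4, 6
w = torch.tensor([1, 1], dtype=torch.int32)
m = torch.ones(seq_q, seq_kv, dtype=torch.bool)
mu = torch.triu(m, diagonal=seq_kv - seq_q - int(w[0]))
ml = torch.tril(mu, diagonal=seq_kv - seq_q + int(w[1]))
print(~ml)   # True = masked out, False = inside the sliding window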
@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", model_configs_base.keys())
@@ -236,8 +237,9 @@ def get_swa(seq_q, seq_kv, w=None):
@pytest.mark.parametrize("qkv_layout", [None])
@pytest.mark.parametrize("swa", [False])
@pytest.mark.parametrize("pad_between_seqs", [False])
def test_dot_product_attention(dtype, model_configs, model, ckpt_attn,
workspace_opt, qkv_layout, swa, pad_between_seqs):
def test_dot_product_attention(
dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa, pad_between_seqs
):
"""Test DotProductAttention module"""
# Get configs
@@ -251,36 +253,43 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn,
else:
qkv_layout = "sbhd_sb2hd"
if "3" in qkv_layout and config.attn_type == "cross":
pytest.skip(
"No need to test this layout for cross attention"
)
pytest.skip("No need to test this layout for cross attention")
# Skip if only unfused backend is supported
qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
unfused_attn_supported = _is_unfused_attention_supported(config, qkv_format)
if config.max_seqlen_q <= 512 and config.max_seqlen_kv <= 512:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
config, dtype, qkv_layout=qkv_layout,
config,
dtype,
qkv_layout=qkv_layout,
)
if swa:
fused_attn_supported = False
flash_attn_supported = _is_flash_attention_supported(config)
if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
if (qkv_format == 'thd' and 'padding' not in config.attn_mask_type):
if qkv_format == "thd" and "padding" not in config.attn_mask_type:
pytest.skip("THD layout requires padding/padding_causal mask type.")
# d=256 is supported by cuDNN 9.0+ for inference but not training
is_training = (config.head_dim <= 128)
is_training = config.head_dim <= 128
# UnfusedDotProductAttention backend
if unfused_attn_supported:
if swa:
attn_mask_type = config.attn_mask_type
config.attn_mask_type = "arbitrary"
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
dtype, config, "UnfusedDotProductAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
dtype,
config,
"UnfusedDotProductAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
if swa:
config.attn_mask_type = attn_mask_type
@@ -289,51 +298,79 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn,
if fused_attn_supported:
if len(fused_attn_backend) == 1:
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype, config, "FusedAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
if len(fused_attn_backend) == 2:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype, config, "FusedAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
dtype, config, "FusedAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
# FlashAttention backend
if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype, config, "FlashAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
dtype,
config,
"FlashAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
if unfused_attn_supported and fused_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
for i,_ in enumerate(unfused_attn_bwd):
for i, _ in enumerate(unfused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
if unfused_attn_supported and flash_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
for i,_ in enumerate(flash_attn_bwd):
for i, _ in enumerate(flash_attn_bwd):
torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols)
if fused_attn_supported and flash_attn_supported:
logging.info("[test_dot_product_attention]: fused attn vs flash attn")
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
for i,_ in enumerate(flash_attn_bwd):
for i, _ in enumerate(flash_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
if fused_attn_supported and len(fused_attn_backend) == 2:
logging.info("[test_dot_product_attention]: fused attn backend 0 vs 1")
torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
for i,_ in enumerate(fused_attn_bwd):
for i, _ in enumerate(fused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
@@ -344,22 +381,22 @@ def test_dpa_checkpoint(dtype, model_configs, model):
model_configs_mask = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
"mask_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"),
"mask_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"mask_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"mask_3_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
"mask_3_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
"mask_4_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"mask_4_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"mask_5_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
"mask_5_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding_causal", "no_bias"),
"mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
"mask_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"),
"mask_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"mask_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"mask_3_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
"mask_3_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
"mask_4_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"mask_4_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"mask_5_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
"mask_5_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding_causal", "no_bias"),
"mask_6_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"mask_6_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
}
@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_mask])
@pytest.mark.parametrize("model", model_configs_mask.keys())
@@ -370,34 +407,48 @@ def test_dpa_mask(dtype, model_configs, model):
model_configs_bias = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
"bias_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "post_scale_bias"),
"bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias"),
"bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "post_scale_bias"),
"bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "alibi"), # skipped
"bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "alibi"), # skipped
"bias_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), # skipped
"bias_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "post_scale_bias"), # skipped
"bias_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "post_scale_bias"), # skipped
"bias_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "post_scale_bias"), # skipped
"bias_2_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "alibi"), # skipped
"bias_2_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "alibi"), # skipped
"bias_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
"bias_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "post_scale_bias"),
"bias_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
"bias_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "post_scale_bias"), # skipped
"bias_3_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi"),
"bias_3_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "alibi"), # skipped
"bias_4_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "padding_causal", "post_scale_bias"), # skipped
"bias_4_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"), # skipped
"bias_4_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "post_scale_bias"), # skipped
"bias_4_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"), # skipped
"bias_4_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "alibi"), # skipped
"bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "alibi"), # skipped
"bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
"bias_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "post_scale_bias"),
"bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias"),
"bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "post_scale_bias"),
"bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "alibi"), # skipped
"bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "alibi"), # skipped
"bias_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), # skipped
"bias_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "post_scale_bias"), # skipped
"bias_2_2": ModelConfig(
4, 24, 24, 128, 2048, 2048, 0.0, "padding", "post_scale_bias"
), # skipped
"bias_2_3": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding", "post_scale_bias"
), # skipped
"bias_2_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "alibi"), # skipped
"bias_2_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "alibi"), # skipped
"bias_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
"bias_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "post_scale_bias"),
"bias_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
"bias_3_3": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "causal", "post_scale_bias"
), # skipped
"bias_3_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi"),
"bias_3_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "alibi"), # skipped
"bias_4_0": ModelConfig(
4, 16, 16, 64, 128, 128, 0.0, "padding_causal", "post_scale_bias"
), # skipped
"bias_4_1": ModelConfig(
2, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"
), # skipped
"bias_4_2": ModelConfig(
4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "post_scale_bias"
), # skipped
"bias_4_3": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"
), # skipped
"bias_4_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "alibi"), # skipped
"bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "alibi"), # skipped
}
@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias])
@pytest.mark.parametrize("model", model_configs_bias.keys())
@@ -408,23 +459,38 @@ def test_dpa_bias(dtype, model_configs, model):
model_configs_bias_shapes = {
# test: b, h, hg, d, sq, skv, p,
"bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0,
# mask, bias, bias_shape,
"no_mask", "post_scale_bias", bias_shape='11ss'),
"bias_1_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0,
"no_mask", "post_scale_bias", bias_shape='1hss'),
"bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,
"no_mask", "post_scale_bias", bias_shape='b1ss'),
"bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0,
"no_mask", "post_scale_bias", bias_shape='bhss'),
"bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,
"causal", "alibi", bias_shape='1hss', alibi_type='custom'),
"bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0,
"causal", "alibi", bias_shape='bhss', alibi_type='custom'),
"bias_1_0": ModelConfig(
4,
16,
16,
64,
128,
128,
0.0,
# mask, bias, bias_shape,
"no_mask",
"post_scale_bias",
bias_shape="11ss",
),
"bias_1_1": ModelConfig(
2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias", bias_shape="1hss"
),
"bias_1_2": ModelConfig(
4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="b1ss"
),
"bias_1_3": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="bhss"
),
"bias_1_4": ModelConfig(
4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="1hss", alibi_type="custom"
),
"bias_1_5": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="bhss", alibi_type="custom"
),
}
@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias_shapes])
@pytest.mark.parametrize("model", model_configs_bias_shapes.keys())
@@ -435,10 +501,10 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
model_configs_swa = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
"swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
"swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
"swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
"swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
"swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
}
@@ -453,10 +519,14 @@ def test_dpa_sliding_window(dtype, model_configs, model):
model_configs_alibi_slopes = {
# test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type
"alibi_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "alibi", alibi_type="vanilla"),
"alibi_1_1": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "causal", "alibi", alibi_type="vanilla"),
"alibi_2_0": ModelConfig(2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type= "custom"),
"alibi_2_1": ModelConfig(1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type= "custom"),
"alibi_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "alibi", alibi_type="vanilla"),
"alibi_1_1": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "causal", "alibi", alibi_type="vanilla"),
"alibi_2_0": ModelConfig(
2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type="custom"
),
"alibi_2_1": ModelConfig(
1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type="custom"
),
}
@@ -470,27 +540,35 @@ def test_dpa_alibi_slopes(dtype, model_configs, model):
qkv_layouts = [
'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd',
'bs3hd', 'bsh3d', 'bshd_bs2hd', 'bshd_bsh2d', 'bshd_bshd_bshd',
]
"sb3hd",
"sbh3d",
"sbhd_sb2hd",
"sbhd_sbh2d",
"sbhd_sbhd_sbhd",
"bs3hd",
"bsh3d",
"bshd_bs2hd",
"bshd_bsh2d",
"bshd_bshd_bshd",
]
model_configs_layout = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"layout_0_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
"layout_0_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
"layout_0_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
"layout_0_3": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"),
"layout_1_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"layout_1_1": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
"layout_1_2": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"layout_0_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
"layout_0_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
"layout_0_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
"layout_0_3": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"),
"layout_1_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"layout_1_1": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
"layout_1_2": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"),
"layout_2_0": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"),
"layout_2_1": ModelConfig(2, 24, 24, 256, 2048, 2048, 0.0, "causal", "post_scale_bias"),
"layout_2_0": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"),
"layout_2_1": ModelConfig(2, 24, 24, 256, 2048, 2048, 0.0, "causal", "post_scale_bias"),
}
@pytest.mark.skipif(get_cudnn_version() < (8,9,5), reason="cuDNN 8.9.5+ is required.")
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 5), reason="cuDNN 8.9.5+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout])
@pytest.mark.parametrize("model", model_configs_layout.keys())
@@ -500,26 +578,28 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
qkv_layouts_thd = ['t3hd', 'th3d', 'thd_t2hd', 'thd_th2d', 'thd_thd_thd']
qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
model_configs_layout_thd = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"layout_0_1": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
"layout_0_2": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
"layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
"layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
"layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"layout_1_3": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_1_4": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"),
"layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
"layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_2_4": ModelConfig(8, 16, 16, 64, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"layout_0_1": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
"layout_0_2": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
"layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
"layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
"layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"layout_1_3": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_1_4": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"),
"layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
"layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_2_4": ModelConfig(8, 16, 16, 64, 2048, 4096, 0.0, "padding_causal", "no_bias"),
}
@pytest.mark.skipif(get_cudnn_version() < (9,0,0), reason="cuDNN 9.0.0+ is required.")
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+.")
@pytest.mark.skipif(get_cudnn_version() < (9, 0, 0), reason="cuDNN 9.0.0+ is required.")
@pytest.mark.skipif(
get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+."
)
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout_thd])
@pytest.mark.parametrize("model", model_configs_layout_thd.keys())
@@ -527,24 +607,26 @@ model_configs_layout_thd = {
def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with different QKV layouts"""
pad_between_seqs = False
test_dot_product_attention(dtype, model_configs, model, False, True,
qkv_layout, False, pad_between_seqs)
test_dot_product_attention(
dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
)
pad_between_seqs = True
test_dot_product_attention(dtype, model_configs, model, False, True,
qkv_layout, False, pad_between_seqs)
test_dot_product_attention(
dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
)
def _run_dot_product_attention(
dtype: torch.dtype,
config: ModelConfig,
backend: str,
ckpt_attn: bool,
qkv_layout: str,
workspace_opt: bool,
swa: bool,
pad_between_seqs: bool,
is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
dtype: torch.dtype,
config: ModelConfig,
backend: str,
ckpt_attn: bool,
qkv_layout: str,
workspace_opt: bool,
swa: bool,
pad_between_seqs: bool,
is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run DotProductAttention module with one forward pass and one backward pass"""
# Set RNG and environment variables
@@ -558,22 +640,27 @@ def _run_dot_product_attention(
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
# Create seqlens
qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
if "padding" in config.attn_mask_type or qkv_format == 'thd':
if config.attn_type == 'self':
seqlens_q = torch.randint(1, config.max_seqlen_q, [config.batch_size],
dtype=torch.int32, device="cuda")
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
if "padding" in config.attn_mask_type or qkv_format == "thd":
if config.attn_type == "self":
seqlens_q = torch.randint(
1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
)
seqlens_kv = seqlens_q
if config.attn_type == 'cross':
seqlens_q = torch.randint(1, config.max_seqlen_q, [config.batch_size],
dtype=torch.int32, device="cuda")
seqlens_kv = torch.randint(1, config.max_seqlen_kv, [config.batch_size],
dtype=torch.int32, device="cuda")
if config.attn_type == "cross":
seqlens_q = torch.randint(
1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
)
seqlens_kv = torch.randint(
1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
)
else:
seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
dtype=torch.int32, device="cuda")
seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv,
dtype=torch.int32, device="cuda")
seqlens_q = torch.full(
[config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
)
seqlens_kv = torch.full(
[config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
)
cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
@@ -586,7 +673,7 @@ def _run_dot_product_attention(
pad_len = [0] * config.batch_size
if pad_between_seqs:
max_pad_len = 3
pad_len = torch.randint(0, max_pad_len+1, [config.batch_size], device="cuda") #3
pad_len = torch.randint(0, max_pad_len + 1, [config.batch_size], device="cuda") # 3
seqlens_q_after_pad = seqlens_q + pad_len
seqlens_kv_after_pad = seqlens_kv + pad_len
cu_seqlens_q_after_pad[1:] = torch.cumsum(seqlens_q_after_pad, dim=0)
@@ -595,25 +682,58 @@ def _run_dot_product_attention(
# Create attention mask if padding
attention_mask = None
if "padding" in config.attn_mask_type:
if config.attn_type == 'self':
if config.attn_type == "self":
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
for i in range(config.batch_size):
attention_mask_q = torch.cat([attention_mask_q,
torch.Tensor([False]*seqlens_q[i] + [True]*(config.max_seqlen_q-seqlens_q[i]))
.to(dtype=torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
attention_mask_q = torch.cat(
[
attention_mask_q,
torch.Tensor(
[False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
)
.to(dtype=torch.bool)
.unsqueeze(0)
.unsqueeze(0)
.unsqueeze(0),
],
dim=0,
)
attention_mask = attention_mask_q.to(device="cuda")
if config.attn_type == 'cross':
if config.attn_type == "cross":
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool)
for i in range(config.batch_size):
attention_mask_q = torch.cat([attention_mask_q,
torch.Tensor([False]*seqlens_q[i] + [True]*(config.max_seqlen_q-seqlens_q[i]))
.to(dtype=torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
attention_mask_kv = torch.cat([attention_mask_kv, torch.Tensor(
[False]*seqlens_kv[i] + [True]*(config.max_seqlen_kv-seqlens_kv[i]))
.to(dtype=torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
attention_mask_q = torch.cat(
[
attention_mask_q,
torch.Tensor(
[False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
)
.to(dtype=torch.bool)
.unsqueeze(0)
.unsqueeze(0)
.unsqueeze(0),
],
dim=0,
)
attention_mask_kv = torch.cat(
[
attention_mask_kv,
torch.Tensor(
[False] * seqlens_kv[i]
+ [True] * (config.max_seqlen_kv - seqlens_kv[i])
)
.to(dtype=torch.bool)
.unsqueeze(0)
.unsqueeze(0)
.unsqueeze(0),
],
dim=0,
)
attention_mask = (
attention_mask_q.to(device="cuda"), attention_mask_kv.to(device="cuda"))
attention_mask_q.to(device="cuda"),
attention_mask_kv.to(device="cuda"),
)
window_size = None
if swa:
window_size, attention_mask = get_swa(config.max_seqlen_q, config.max_seqlen_kv)
@@ -623,62 +743,84 @@ def _run_dot_product_attention(
alibi_slopes = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
if config.bias_shape == "1hss":
alibi_slopes = torch.randn(
config.num_heads).abs().to(dtype=torch.float32, device="cuda")
alibi_slopes = (
torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda")
)
if config.bias_shape == "bhss":
alibi_slopes = torch.randn(
config.batch_size, config.num_heads).abs().to(dtype=torch.float32, device="cuda")
alibi_slopes = (
torch.randn(config.batch_size, config.num_heads)
.abs()
.to(dtype=torch.float32, device="cuda")
)
# Create input tensors
dim_to_num = {
'b' : config.batch_size,
'sq' : config.max_seqlen_q,
'skv': config.max_seqlen_kv,
'h' : config.num_heads,
'hg' : config.num_gqa_groups,
'd' : config.head_dim,
't' : cu_seqlens_q_after_pad[-1],
'tg' : cu_seqlens_kv_after_pad[-1],
'3' : 3,
'2' : 2,
'1' : 1,
}
"b": config.batch_size,
"sq": config.max_seqlen_q,
"skv": config.max_seqlen_kv,
"h": config.num_heads,
"hg": config.num_gqa_groups,
"d": config.head_dim,
"t": cu_seqlens_q_after_pad[-1],
"tg": cu_seqlens_kv_after_pad[-1],
"3": 3,
"2": 2,
"1": 1,
}
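# Standalone sketch (not part of this diff) of how a qkv_layout fragment such as
# "bs3hd" becomes a tensor shape via the dim_to_num mapping above (the sizes here
# are assumed toy values).
dim_to_num = {"b": 2, "sq": 128, "h": 16, "d": 64, "3": 3}
layout = "_".join("bs3hd").replace("s", "sq")      # -> "b_sq_3_h_d"
shape = [dim_to_num[j] for j in layout.split("_")]
print(shape)                                       # [2, 128, 3, 16, 64]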
inp = []
inp_orig = []
for i,layout in enumerate(qkv_layout.split('_')):
layout = '_'.join(layout)
for i, layout in enumerate(qkv_layout.split("_")):
layout = "_".join(layout)
if i == 0:
layout = layout.replace('s', 'sq')
layout = layout.replace("s", "sq")
else:
layout = layout.replace('s', 'skv')
layout = layout.replace('h', 'hg')
layout = layout.replace('t', 'tg')
tensor_shape = [dim_to_num[j] for j in layout.split('_')]
layout = layout.replace("s", "skv")
layout = layout.replace("h", "hg")
layout = layout.replace("t", "tg")
tensor_shape = [dim_to_num[j] for j in layout.split("_")]
tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
tensor_orig = tensor
if qkv_format == 'thd' and pad_between_seqs:
tensor_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
if layout in ['t_h_d', 't_3_h_d', 't_h_3_d']:
for i in range(1, config.batch_size+1):
valid_range = (cu_seqlens_q_after_pad[i-1], cu_seqlens_q_after_pad[i] - pad_len[i-1])
pad_range = (cu_seqlens_q_after_pad[i] - pad_len[i-1], cu_seqlens_q_after_pad[i])
tensor[pad_range[0]:pad_range[1]] = 0.0
tensor_orig = torch.cat([tensor_orig, tensor[valid_range[0]:valid_range[1]]], dim=0)
if layout in ['tg_hg_d', 'tg_2_hg_d', 'tg_hg_2_d']:
for i in range(1, config.batch_size+1):
valid_range = (cu_seqlens_kv_after_pad[i-1], cu_seqlens_kv_after_pad[i] - pad_len[i-1])
pad_range = (cu_seqlens_kv_after_pad[i] - pad_len[i-1], cu_seqlens_kv_after_pad[i])
tensor[pad_range[0]:pad_range[1]] = 0.0
tensor_orig = torch.cat([tensor_orig, tensor[valid_range[0]:valid_range[1]]], dim=0)
if qkv_format == "thd" and pad_between_seqs:
tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
if layout in ["t_h_d", "t_3_h_d", "t_h_3_d"]:
for i in range(1, config.batch_size + 1):
valid_range = (
cu_seqlens_q_after_pad[i - 1],
cu_seqlens_q_after_pad[i] - pad_len[i - 1],
)
pad_range = (
cu_seqlens_q_after_pad[i] - pad_len[i - 1],
cu_seqlens_q_after_pad[i],
)
tensor[pad_range[0] : pad_range[1]] = 0.0
tensor_orig = torch.cat(
[tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
)
if layout in ["tg_hg_d", "tg_2_hg_d", "tg_hg_2_d"]:
for i in range(1, config.batch_size + 1):
valid_range = (
cu_seqlens_kv_after_pad[i - 1],
cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
)
pad_range = (
cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
cu_seqlens_kv_after_pad[i],
)
tensor[pad_range[0] : pad_range[1]] = 0.0
tensor_orig = torch.cat(
[tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
)
tensor_count = 1
split_dim = 0
for dim, l in enumerate(layout.split('_')):
for dim, l in enumerate(layout.split("_")):
if l.isdigit():
tensor_count = int(l)
split_dim = dim
break
tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
tensors_orig = torch.split(tensor_orig, 1, dim=split_dim) if split_dim != 0 else [tensor_orig]
tensors_orig = (
torch.split(tensor_orig, 1, dim=split_dim) if split_dim != 0 else [tensor_orig]
)
for j in range(tensor_count):
if split_dim != 0:
inp.append(tensors[j].squeeze(split_dim))
@@ -692,73 +834,77 @@ def _run_dot_product_attention(
# Create ragged offsets for q/k/v
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o = None, None, None, None
qkv_group = ''.join([x for x in qkv_layout if x not in 'bst'])
if qkv_format == 'thd':
qkv_group = "".join([x for x in qkv_layout if x not in "bst"])
if qkv_format == "thd":
seq_offsets_o = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
if qkv_group == 'hd_hd_hd':
if qkv_group == "hd_hd_hd":
seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
seq_offsets_k = config.num_gqa_groups * config.head_dim * cu_seqlens_kv_after_pad
seq_offsets_v = config.num_gqa_groups * config.head_dim * cu_seqlens_kv_after_pad
if qkv_group in ['3hd', 'h3d']:
if qkv_group in ["3hd", "h3d"]:
seq_offsets_q = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
seq_offsets_k = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
seq_offsets_v = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
if qkv_group in ['hd_2hd', 'hd_h2d']:
if qkv_group in ["hd_2hd", "hd_h2d"]:
seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
seq_offsets_k = config.num_gqa_groups * config.head_dim * 2 * cu_seqlens_kv_after_pad
seq_offsets_v = config.num_gqa_groups * config.head_dim * 2 * cu_seqlens_kv_after_pad
# Create output gradient
qkv_format_kv = '_'.join(qkv_format)
qkv_format_kv = qkv_format_kv.replace('s', 'sq')
out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split('_')]
qkv_format_kv = "_".join(qkv_format)
qkv_format_kv = qkv_format_kv.replace("s", "sq")
out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda")
out_grad_orig = out_grad
if qkv_format == 'thd' and pad_between_seqs:
out_grad_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
if qkv_format_kv == 't_h_d':
for i in range(1, config.batch_size+1):
valid_range = (cu_seqlens_q_after_pad[i-1], cu_seqlens_q_after_pad[i] - pad_len[i-1])
pad_range = (cu_seqlens_q_after_pad[i] - pad_len[i-1], cu_seqlens_q_after_pad[i])
out_grad[pad_range[0]:pad_range[1]] = 0.0
out_grad_orig = torch.cat([out_grad_orig, out_grad[valid_range[0]:valid_range[1]]], dim=0)
if qkv_format == "thd" and pad_between_seqs:
out_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
if qkv_format_kv == "t_h_d":
for i in range(1, config.batch_size + 1):
valid_range = (
cu_seqlens_q_after_pad[i - 1],
cu_seqlens_q_after_pad[i] - pad_len[i - 1],
)
pad_range = (cu_seqlens_q_after_pad[i] - pad_len[i - 1], cu_seqlens_q_after_pad[i])
out_grad[pad_range[0] : pad_range[1]] = 0.0
out_grad_orig = torch.cat(
[out_grad_orig, out_grad[valid_range[0] : valid_range[1]]], dim=0
)
# Create bias
if config.attn_bias_type in ['no_bias', 'alibi']:
if config.attn_bias_type in ["no_bias", "alibi"]:
bias = None
if config.attn_bias_type == 'post_scale_bias':
shape = '_'.join(config.bias_shape)
shape = shape.replace('_s_s', '_sq_skv')
tensor_shape = [dim_to_num[j] for j in shape.split('_')]
if config.attn_bias_type == "post_scale_bias":
shape = "_".join(config.bias_shape)
shape = shape.replace("_s_s", "_sq_skv")
tensor_shape = [dim_to_num[j] for j in shape.split("_")]
bias = torch.randn(tensor_shape, dtype=dtype, device="cuda")
if config.bias_shape != '1hss':
if config.bias_shape != "1hss":
bias.requires_grad = False
# Create RNG
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
# Set up model
block = (
DotProductAttention(
config.num_heads,
config.head_dim,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=config.attn_mask_type,
sequence_parallel=False,
tp_size=1,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
tp_group=None,
layer_number=1,
attention_type=config.attn_type,
).to(dtype=dtype, device="cuda")
)
block = DotProductAttention(
config.num_heads,
config.head_dim,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=config.attn_mask_type,
sequence_parallel=False,
tp_size=1,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
tp_group=None,
layer_number=1,
attention_type=config.attn_type,
).to(dtype=dtype, device="cuda")
# Run a forward and backward pass
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
@@ -771,24 +917,28 @@ def _run_dot_product_attention(
k = inp[1]
v = inp[2]
d_out = out_grad
out = block(q, k, v,
window_size=window_size,
attention_mask=attention_mask,
qkv_format=qkv_format,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
seq_offsets_q=seq_offsets_q,
seq_offsets_k=seq_offsets_k,
seq_offsets_v=seq_offsets_v,
seq_offsets_o=seq_offsets_o,
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=ckpt_attn,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=True)
out = block(
q,
k,
v,
window_size=window_size,
attention_mask=attention_mask,
qkv_format=qkv_format,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
seq_offsets_q=seq_offsets_q,
seq_offsets_k=seq_offsets_k,
seq_offsets_v=seq_offsets_v,
seq_offsets_o=seq_offsets_o,
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=ckpt_attn,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=True,
)
if is_training:
out.backward(d_out)
@@ -798,18 +948,30 @@ def _run_dot_product_attention(
else:
return out, (None, None, None)
if backend == "FusedAttention":
if qkv_format == 'thd' and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
q_grad_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
k_grad_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
v_grad_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
for i in range(1, config.batch_size+1):
valid_range_q = (cu_seqlens_q_after_pad[i-1], cu_seqlens_q_after_pad[i] - pad_len[i-1])
valid_range_kv = (cu_seqlens_kv_after_pad[i-1], cu_seqlens_kv_after_pad[i] - pad_len[i-1])
out_orig = torch.cat([out_orig, out[valid_range_q[0]:valid_range_q[1]]], dim=0)
q_grad_orig = torch.cat([q_grad_orig, q.grad[valid_range_q[0]:valid_range_q[1]]], dim=0)
k_grad_orig = torch.cat([k_grad_orig, k.grad[valid_range_kv[0]:valid_range_kv[1]]], dim=0)
v_grad_orig = torch.cat([v_grad_orig, v.grad[valid_range_kv[0]:valid_range_kv[1]]], dim=0)
if qkv_format == "thd" and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
for i in range(1, config.batch_size + 1):
valid_range_q = (
cu_seqlens_q_after_pad[i - 1],
cu_seqlens_q_after_pad[i] - pad_len[i - 1],
)
valid_range_kv = (
cu_seqlens_kv_after_pad[i - 1],
cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
)
out_orig = torch.cat([out_orig, out[valid_range_q[0] : valid_range_q[1]]], dim=0)
q_grad_orig = torch.cat(
[q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0
)
k_grad_orig = torch.cat(
[k_grad_orig, k.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
)
v_grad_orig = torch.cat(
[v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
)
if is_training:
return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig)
else:
@@ -823,18 +985,18 @@ def _run_dot_product_attention(
model_configs_te_layer = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
"te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
"te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"),
"te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
"te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"),
"te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"),
"te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
"te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
"te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"),
"te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
"te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"),
"te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"),
}
@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", model_configs_te_layer.keys())
@@ -842,7 +1004,9 @@ model_configs_te_layer = {
@pytest.mark.parametrize("qkv_format", ["sbhd"])
@pytest.mark.parametrize("fused_qkv_params", [False])
@pytest.mark.parametrize("RoPE", [False])
def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE):
def test_transformer_layer(
dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
):
"""Test TransformerLayer module"""
# Get configs
@@ -916,7 +1080,7 @@ def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, f
torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)
@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_1_2", "te_2_0"])
@@ -926,22 +1090,24 @@ def test_te_layer_misc(dtype, model_configs, model, qkv_format):
ckpt_attn = True
fused_qkv_params = True
RoPE = True
test_transformer_layer(dtype, model_configs, model,
ckpt_attn, qkv_format, fused_qkv_params, RoPE)
test_transformer_layer(
dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
)
@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_2_0", "te_2_1", "te_2_2"])
def test_te_layer_mqa_gqa(dtype, model_configs, model):
"""Test TransformerLayer module with MQA/GQA"""
def find_factors(x):
f = []
for i in range(2, x + 1):
if x % i == 0:
f.append(i)
return f
f = []
for i in range(2, x + 1):
if x % i == 0:
f.append(i)
return f
ckpt_attn = True
qkv_format = "bshd"
@@ -951,21 +1117,22 @@ def test_te_layer_mqa_gqa(dtype, model_configs, model):
num_querys_per_gqa_group = find_factors(config.num_heads)
for num_q_per_gqa_group in num_querys_per_gqa_group:
config.num_gqa_groups=config.num_heads // num_q_per_gqa_group
test_transformer_layer(dtype, model_configs, model,
ckpt_attn, qkv_format, fused_qkv_params, RoPE)
config.num_gqa_groups = config.num_heads // num_q_per_gqa_group
test_transformer_layer(
dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
)
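# Illustrative sketch (assumed values, mirroring the te_2_* configs): find_factors
# enumerates every divisor of num_heads greater than 1, so 16 heads are swept over
# progressively smaller GQA group counts, from grouped-query attention down to
# multi-query attention.
num_heads_example = 16
factors_example = [i for i in range(2, num_heads_example + 1) if num_heads_example % i == 0]
assert factors_example == [2, 4, 8, 16]
assert [num_heads_example // f for f in factors_example] == [8, 4, 2, 1]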
def _run_transformer_layer(
dtype: torch.dtype,
config: ModelConfig,
backend: str,
ckpt_attn: bool,
qkv_format: str,
workspace_opt: bool,
fused_qkv_params: bool,
RoPE: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
dtype: torch.dtype,
config: ModelConfig,
backend: str,
ckpt_attn: bool,
qkv_format: str,
workspace_opt: bool,
fused_qkv_params: bool,
RoPE: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run TransformerLayer module with one forward pass and one backward pass"""
# Set RNG and environment variables
......@@ -978,29 +1145,47 @@ def _run_transformer_layer(
os.environ["NVTE_FUSED_ATTN"] = "1"
# Create input tensor
inp = torch.randn(config.max_seqlen_q, config.batch_size, config.hidden_size,
dtype=dtype, device="cuda", requires_grad = True)
inp = torch.randn(
config.max_seqlen_q,
config.batch_size,
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)
# In case the format to be tested is batch-first, need to transpose the
# input tensor.
if qkv_format == "bshd":
inp = inp.transpose(0,1)
inp = inp.transpose(0, 1)
# Create seqlens
if "padding" in config.attn_mask_type:
seqlens_q = torch.randint(1, config.max_seqlen_q, [config.batch_size],
dtype=torch.int32, device="cuda")
seqlens_q = torch.randint(
1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
)
else:
seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
dtype=torch.int32, device="cuda")
seqlens_q = torch.full(
[config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
)
# Create attention mask if padding
attention_mask = None
if "padding" in config.attn_mask_type:
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
for i in range(config.batch_size):
attention_mask_q = torch.cat([attention_mask_q,
torch.Tensor([False]*seqlens_q[i] + [True]*(config.max_seqlen_q-seqlens_q[i]))
.to(torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
attention_mask_q = torch.cat(
[
attention_mask_q,
torch.Tensor(
[False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
)
.to(torch.bool)
.unsqueeze(0)
.unsqueeze(0)
.unsqueeze(0),
],
dim=0,
)
attention_mask = attention_mask_q.to(device="cuda")
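    # Illustration (assumed values): with max_seqlen_q = 5 and seqlens_q[i] = 3, the
    # row built above is [False, False, False, True, True]; True marks padded positions,
    # and the stacked mask has shape [batch_size, 1, 1, max_seqlen_q].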
sigma = 0.02
......@@ -1009,14 +1194,19 @@ def _run_transformer_layer(
layer_number = 1
drop_path_rate = 0.0
drop_path_rates = [
rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
drop_path_rates = [rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
# Create bias
bias = None
if config.attn_bias_type == 'post_scale_bias':
bias = torch.randn(1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv,
dtype=dtype, device="cuda")
if config.attn_bias_type == "post_scale_bias":
bias = torch.randn(
1,
config.num_heads,
config.max_seqlen_q,
config.max_seqlen_kv,
dtype=dtype,
device="cuda",
)
# Create RoPE
rotary_pos_emb = None
......@@ -1025,58 +1215,56 @@ def _run_transformer_layer(
rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")
# Set up model
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_heads,
num_gqa_groups=config.num_gqa_groups,
layernorm_epsilon=1e-5,
hidden_dropout=0.0,
attention_dropout=config.dropout_p,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
layer_number=layer_number,
kv_channels=config.head_dim,
self_attn_mask_type=config.attn_mask_type,
tp_group=None,
tp_size=1,
params_dtype=dtype,
get_rng_state_tracker=None,
fuse_wgrad_accumulation=False,
seq_length=config.max_seqlen_q,
micro_batch_size=config.batch_size,
sequence_parallel=False,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
layer_type="encoder",
drop_path_rate=drop_path_rates[layer_number - 1],
set_parallel_mode=True,
fuse_qkv_params=fused_qkv_params,
zero_centered_gamma=False,
qkv_weight_interleaved=False,
ub_tp_comm_overlap=False,
bias=True,
attn_input_format=qkv_format,
)
.to(dtype=dtype, device="cuda")
)
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_heads,
num_gqa_groups=config.num_gqa_groups,
layernorm_epsilon=1e-5,
hidden_dropout=0.0,
attention_dropout=config.dropout_p,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
layer_number=layer_number,
kv_channels=config.head_dim,
self_attn_mask_type=config.attn_mask_type,
tp_group=None,
tp_size=1,
params_dtype=dtype,
get_rng_state_tracker=None,
fuse_wgrad_accumulation=False,
seq_length=config.max_seqlen_q,
micro_batch_size=config.batch_size,
sequence_parallel=False,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
layer_type="encoder",
drop_path_rate=drop_path_rates[layer_number - 1],
set_parallel_mode=True,
fuse_qkv_params=fused_qkv_params,
zero_centered_gamma=False,
qkv_weight_interleaved=False,
ub_tp_comm_overlap=False,
bias=True,
attn_input_format=qkv_format,
).to(dtype=dtype, device="cuda")
# Create ALiBi slopes
alibi_slopes = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
alibi_slopes = torch.randn(
config.num_heads).abs().to(dtype=torch.float32, device="cuda")
alibi_slopes = torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda")
# Run a forward and backward pass
out = block(inp,
out = block(
inp,
attention_mask=attention_mask,
self_attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=False,
rotary_pos_emb=rotary_pos_emb,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
alibi_slopes=alibi_slopes)
alibi_slopes=alibi_slopes,
)
loss = out.sum()
loss.backward()
......@@ -1085,23 +1273,24 @@ def _run_transformer_layer(
model_configs_fp8_vs_f16 = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"fp8_9" : ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_10": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_9": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_10": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_11": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_12": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_13": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"),
"fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),
"fp8_12": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_13": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"),
"fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),
}
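# Reading one entry against the column comment above: "fp8_14" is batch_size=1,
# num_heads=32 with num_gqa_groups=4 (8 query heads per KV head), head_dim=128,
# max_seqlen_q=max_seqlen_kv=8192, dropout_p=0.0, a causal mask and no bias.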
param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
qkv_layout_fp8_vs_f16 = ['sbh3d', 'bshd_bshd_bshd', 'sbhd_sbhd_sbhd']
qkv_format_fp8_vs_f16 = ['bshd', 'sbhd']
qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"]
qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]
def _rmse(a, b):
return math.sqrt((torch.pow((a-b), 2)/a.numel()).sum())
return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum())
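# Minimal usage sketch (illustrative tensors): the RMSE is checked against a fraction
# of the observed value range rather than an absolute threshold, which is the same
# form as the FWD/BWD assertions below.
_a_example = torch.tensor([0.0, 1.0, 2.0, 3.0])
_b_example = torch.tensor([0.0, 1.0, 2.0, 3.5])
_rmse_example = _rmse(_a_example, _b_example)  # sqrt(0.25 / 4) = 0.25
_range_example = max(_a_example.max().item(), _b_example.max().item()) - min(
    _a_example.min().item(), _b_example.min().item()
)  # 3.5
assert _rmse_example < 0.1 * _range_example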
@pytest.mark.skipif(get_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
......@@ -1118,58 +1307,78 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd):
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm)
dtype, config, True, qkv_format, input_layernorm
)
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
dtype, config, False, qkv_format, input_layernorm)
dtype, config, False, qkv_format, input_layernorm
)
tols = dict(atol=5e-1, rtol=5e-1)
rmse_tol = 0.1
fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16)
fwd_range = max(fused_attn_fwd_fp8.max().item(),
fused_attn_fwd_f16.max().item()) - min(fused_attn_fwd_fp8.min().item(),
fused_attn_fwd_f16.min().item())
logging.debug('========== {:^25s} =========='.format('forward output'))
logging.debug('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format(
fused_attn_fwd_fp8.min().item(),fused_attn_fwd_fp8.max().item()))
logging.debug('fused_attn_fwd_f16 min {:.6f} max {:.6f}'.format(
fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()))
logging.debug('fused_attn_fwd RMSE: {:.6f}'.format(fwd_rmse))
fwd_range = max(fused_attn_fwd_fp8.max().item(), fused_attn_fwd_f16.max().item()) - min(
fused_attn_fwd_fp8.min().item(), fused_attn_fwd_f16.min().item()
)
logging.debug("========== {:^25s} ==========".format("forward output"))
logging.debug(
"fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format(
fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()
)
)
logging.debug(
"fused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()
)
)
logging.debug("fused_attn_fwd RMSE: {:.6f}".format(fwd_rmse))
try:
torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols)
except Exception as e:
logging.debug(e)
assert(fwd_rmse < rmse_tol * fwd_range
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range)
assert (
fwd_rmse < rmse_tol * fwd_range
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
)
for i in range(len(param_names[:1])):
bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i])
bwd_range = max(fused_attn_bwd_fp8[i].max().item(),
fused_attn_bwd_f16[i].max().item()) - min(fused_attn_bwd_fp8[i].min().item(),
fused_attn_bwd_f16[i].min().item())
logging.debug('========== {:^25s} =========='.format(param_names[i]))
logging.debug('fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}'.format(i,
fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()))
logging.debug('fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}'.format(i,
fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()))
logging.debug('fused_attn_bwd RMSE[{}]: {:.6f}'.format(i, bwd_rmse))
bwd_range = max(
fused_attn_bwd_fp8[i].max().item(), fused_attn_bwd_f16[i].max().item()
) - min(fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_f16[i].min().item())
logging.debug("========== {:^25s} ==========".format(param_names[i]))
logging.debug(
"fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}".format(
i, fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()
)
)
logging.debug(
"fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}".format(
i, fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()
)
)
logging.debug("fused_attn_bwd RMSE[{}]: {:.6f}".format(i, bwd_rmse))
try:
torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols)
except Exception as e:
logging.debug(e)
assert(bwd_rmse < rmse_tol * bwd_range
), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range)
assert (
bwd_rmse < rmse_tol * bwd_range
), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
)
def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
......@@ -1184,7 +1393,7 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
)
with fp8_model_init(enabled=fp8_mha):
mha = (MultiheadAttention(
mha = MultiheadAttention(
hidden_size=config.hidden_size,
num_attention_heads=config.num_heads,
kv_channels=config.head_dim,
......@@ -1199,34 +1408,35 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
attention_type="self",
qkv_weight_interleaved=True,
qkv_format=qkv_format,
).to(dtype=dtype, device="cuda")
)
).to(dtype=dtype, device="cuda")
seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
dtype=torch.int32, device="cuda")
seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv,
dtype=torch.int32, device="cuda")
seqlens_q = torch.full(
[config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
)
seqlens_kv = torch.full(
[config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
)
cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
dim_to_num = {
'b' : config.batch_size,
'sq' : config.max_seqlen_q,
'skv': config.max_seqlen_kv,
'h' : config.num_heads,
'hg' : config.num_gqa_groups,
'd' : config.head_dim,
't' : cu_seqlens_q[-1],
'tg' : cu_seqlens_kv[-1],
'3' : 3,
'2' : 2,
'1' : 1,
}
layout = '_'.join(qkv_format)
layout = layout.replace('s', 'sq')
tensor_shape = [dim_to_num[j] for j in layout.split('_')]
"b": config.batch_size,
"sq": config.max_seqlen_q,
"skv": config.max_seqlen_kv,
"h": config.num_heads,
"hg": config.num_gqa_groups,
"d": config.head_dim,
"t": cu_seqlens_q[-1],
"tg": cu_seqlens_kv[-1],
"3": 3,
"2": 2,
"1": 1,
}
layout = "_".join(qkv_format)
layout = layout.replace("s", "sq")
tensor_shape = [dim_to_num[j] for j in layout.split("_")]
tensor = 0.01 * torch.randint(-100, 100, tensor_shape, dtype=dtype, device="cuda")
hidden_states = tensor.view(*tensor.shape[:-2], -1)
hidden_states.requires_grad = True
......@@ -1234,27 +1444,28 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
out_grad = tensor.view(*tensor.shape[:-2], -1)
with fp8_autocast(enabled=fp8_mha, fp8_recipe=fp8_recipe):
out = mha(hidden_states,
out = mha(
hidden_states,
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=False,
core_attention_bias_type=config.attn_bias_type,
is_first_microbatch=None,
)
)
out.backward(out_grad)
param_names = []
param_names.append('hidden_states.grad')
param_names.append("hidden_states.grad")
params = []
params.append(hidden_states)
for name, param in mha.named_parameters():
if param.requires_grad:
param_names.append(name+'.grad')
param_names.append(name + ".grad")
params.append(param)
return out, param_names, tuple(x.grad for x in params)
@pytest.mark.skipif(get_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
......@@ -1264,62 +1475,75 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd):
config = model_configs_fp8_vs_f16[model]
if (config.num_heads != config.num_gqa_groups and '3' in qkv_layout):
pytest.skip("qkv_layout not applicable for MQA/GQA");
if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
pytest.skip("qkv_layout not applicable for MQA/GQA")
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
dtype, config, True, qkv_layout)
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(dtype, config, True, qkv_layout)
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False")
fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(
dtype, config, False, qkv_layout)
fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(dtype, config, False, qkv_layout)
tols = dict(atol=5e-1, rtol=5e-2)
rmse_tol = 0.1
bwd_names = ['dq', 'dk', 'dv']
bwd_names = ["dq", "dk", "dv"]
fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16)
fwd_range = max(fused_attn_fwd_fp8.max().item(),
fused_attn_fwd_f16.max().item()) - min(fused_attn_fwd_fp8.min().item(),
fused_attn_fwd_f16.min().item())
logging.debug('========== {:^25s} =========='.format('forward output'))
logging.debug('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format(
fused_attn_fwd_fp8.min().item(),fused_attn_fwd_fp8.max().item()))
logging.debug('fused_attn_fwd_f16 min {:.6f} max {:.6f}'.format(
fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()))
logging.debug('fused_attn_fwd RMSE: {:.6f}'.format(fwd_rmse))
fwd_range = max(fused_attn_fwd_fp8.max().item(), fused_attn_fwd_f16.max().item()) - min(
fused_attn_fwd_fp8.min().item(), fused_attn_fwd_f16.min().item()
)
logging.debug("========== {:^25s} ==========".format("forward output"))
logging.debug(
"fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format(
fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()
)
)
logging.debug(
"fused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()
)
)
logging.debug("fused_attn_fwd RMSE: {:.6f}".format(fwd_rmse))
try:
torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols)
except Exception as e:
logging.debug(e)
assert(fwd_rmse < rmse_tol * fwd_range
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range)
for i,_ in enumerate(fused_attn_bwd_f16):
assert (
fwd_rmse < rmse_tol * fwd_range
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
)
for i, _ in enumerate(fused_attn_bwd_f16):
bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i])
bwd_range = max(fused_attn_bwd_fp8[i].max().item(),
fused_attn_bwd_f16[i].max().item()) - min(fused_attn_bwd_fp8[i].min().item(),
fused_attn_bwd_f16[i].min().item())
logging.debug('========== {:^25s} =========='.format(bwd_names[i]))
logging.debug('fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}'.format(i,
fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()))
logging.debug('fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}'.format(i,
fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()))
logging.debug('fused_attn_bwd RMSE[{}]: {:.6f}'.format(i, bwd_rmse))
bwd_range = max(
fused_attn_bwd_fp8[i].max().item(), fused_attn_bwd_f16[i].max().item()
) - min(fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_f16[i].min().item())
logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
logging.debug(
"fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}".format(
i, fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()
)
)
logging.debug(
"fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}".format(
i, fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()
)
)
logging.debug("fused_attn_bwd RMSE[{}]: {:.6f}".format(i, bwd_rmse))
try:
torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols)
except Exception as e:
logging.debug(e)
assert(bwd_rmse < rmse_tol * bwd_range
), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range)
assert (
bwd_rmse < rmse_tol * bwd_range
), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
)
def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
......@@ -1327,6 +1551,7 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
......@@ -1339,60 +1564,60 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
fp8_dpa=fp8_dpa,
)
qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
with fp8_model_init(enabled=fp8_dpa):
dpa = (
DotProductAttention(
config.num_heads,
config.head_dim,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
sequence_parallel=False,
tp_size=1,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
tp_group=None,
layer_number=1,
attention_type="self",
qkv_format=qkv_format,
).to(dtype=dtype, device="cuda")
)
dpa = DotProductAttention(
config.num_heads,
config.head_dim,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
sequence_parallel=False,
tp_size=1,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
tp_group=None,
layer_number=1,
attention_type="self",
qkv_format=qkv_format,
).to(dtype=dtype, device="cuda")
seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
dtype=torch.int32, device="cuda")
seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv,
dtype=torch.int32, device="cuda")
seqlens_q = torch.full(
[config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
)
seqlens_kv = torch.full(
[config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
)
cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
dim_to_num = {
'b' : config.batch_size,
'sq' : config.max_seqlen_q,
'skv': config.max_seqlen_kv,
'h' : config.num_heads,
'hg' : config.num_gqa_groups,
'd' : config.head_dim,
't' : cu_seqlens_q[-1],
'tg' : cu_seqlens_kv[-1],
'3' : 3,
'2' : 2,
'1' : 1,
}
"b": config.batch_size,
"sq": config.max_seqlen_q,
"skv": config.max_seqlen_kv,
"h": config.num_heads,
"hg": config.num_gqa_groups,
"d": config.head_dim,
"t": cu_seqlens_q[-1],
"tg": cu_seqlens_kv[-1],
"3": 3,
"2": 2,
"1": 1,
}
inp = []
for i,layout in enumerate(qkv_layout.split('_')):
layout = '_'.join(layout)
for i, layout in enumerate(qkv_layout.split("_")):
layout = "_".join(layout)
if i == 0:
layout = layout.replace('s', 'sq')
layout = layout.replace("s", "sq")
else:
layout = layout.replace('s', 'skv')
layout = layout.replace('h', 'hg')
layout = layout.replace('t', 'tg')
tensor_shape = [dim_to_num[j] for j in layout.split('_')]
layout = layout.replace("s", "skv")
layout = layout.replace("h", "hg")
layout = layout.replace("t", "tg")
tensor_shape = [dim_to_num[j] for j in layout.split("_")]
tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda")
tensor_count = 1
split_dim = 0
for dim, l in enumerate(layout.split('_')):
for dim, l in enumerate(layout.split("_")):
if l.isdigit():
tensor_count = int(l)
split_dim = dim
......@@ -1406,14 +1631,17 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
for i in range(3):
inp[i].requires_grad = True
qkv_format_kv = '_'.join(qkv_format)
qkv_format_kv = qkv_format_kv.replace('s', 'sq')
out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split('_')]
qkv_format_kv = "_".join(qkv_format)
qkv_format_kv = qkv_format_kv.replace("s", "sq")
out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda")
with fp8_autocast(enabled=fp8_dpa, fp8_recipe=fp8_recipe):
out = dpa(inp[0], inp[1], inp[2],
out = dpa(
inp[0],
inp[1],
inp[2],
qkv_format=qkv_format,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
......@@ -1423,7 +1651,7 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
checkpoint_core_attention=False,
core_attention_bias_type=config.attn_bias_type,
is_first_microbatch=True,
)
)
out.backward(out_grad)
return out, (inp[0].grad, inp[1].grad, inp[2].grad)
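# Sketch of the layout expansion used above (assumed sizes, e.g. the "fp8_11" config:
# b=2, h=24, hg=12, d=128, sq=skv=2048): each underscore-separated segment of
# qkv_layout is expanded letter by letter and mapped through dim_to_num, with
# key/value segments rewritten to use the KV sequence length and the GQA group count.
def _shape_for_segment(segment, is_query, dims):
    layout = "_".join(segment)                      # "bshd" -> "b_s_h_d"
    layout = layout.replace("s", "sq" if is_query else "skv")
    if not is_query:
        layout = layout.replace("h", "hg")          # KV tensors use GQA groups
    return [dims[tok] for tok in layout.split("_")]

_dims_example = {"b": 2, "sq": 2048, "skv": 2048, "h": 24, "hg": 12, "d": 128, "3": 3}
assert _shape_for_segment("bshd", True, _dims_example) == [2, 2048, 24, 128]
assert _shape_for_segment("bshd", False, _dims_example) == [2, 2048, 12, 128]
assert _shape_for_segment("sbh3d", True, _dims_example) == [2048, 2, 24, 3, 128]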
......@@ -1431,22 +1659,22 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
model_configs_fp8 = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"fp8_1": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "no_mask", "no_bias"),
"fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
"fp8_3": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_1": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "no_mask", "no_bias"),
"fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
"fp8_3": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_4": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_5": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "causal", "no_bias"),
"fp8_6": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "causal", "no_bias"),
"fp8_7": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_8": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_5": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "causal", "no_bias"),
"fp8_6": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "causal", "no_bias"),
"fp8_7": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_8": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
}
param_types_fp8 = [torch.float16, torch.bfloat16]
cudnn_frontend_version = int(os.getenv('NVTE_FUSED_ATTN_FE_VER','1'))
models_v0 = ['fp8_1', 'fp8_2', 'fp8_5', 'fp8_6']
models_v1 = ['fp8_3', 'fp8_4', 'fp8_7', 'fp8_8']
cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1"))
models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"]
models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]
@pytest.mark.skipif(get_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8)
......@@ -1460,50 +1688,62 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
config = model_configs_fp8[model]
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(
dtype, config, "FusedAttention")
unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(
dtype, config, "UnfusedAttention")
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")
tols = dict(atol=5e-1, rtol=5e-1)
rmse_tol = 0.1
fwd_rmse = _rmse(fused_attn_fwd_fp8, unfused_attn_fwd_f16)
fwd_range = max(fused_attn_fwd_fp8.max().item(),
unfused_attn_fwd_f16.max().item()) - min(fused_attn_fwd_fp8.min().item(),
unfused_attn_fwd_f16.min().item())
fwd_range = max(fused_attn_fwd_fp8.max().item(), unfused_attn_fwd_f16.max().item()) - min(
fused_attn_fwd_fp8.min().item(), unfused_attn_fwd_f16.min().item()
)
bwd_rmse = _rmse(fused_attn_bwd_fp8, unfused_attn_bwd_f16)
bwd_range = max(fused_attn_bwd_fp8.max().item(),
unfused_attn_bwd_f16.max().item()) - min(fused_attn_bwd_fp8.min().item(),
unfused_attn_bwd_f16.min().item())
logging.debug('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format(
fused_attn_fwd_fp8.min().item(),fused_attn_fwd_fp8.max().item()))
logging.debug('unfused_attn_fwd_f16 min {:.6f} max {:.6f}'.format(
unfused_attn_fwd_f16.min().item(), unfused_attn_fwd_f16.max().item()))
logging.debug('fused_attn_fwd_fp8 vs unfused_attn_fwd_f16 RMSE: {:.6f}'.format(
fwd_rmse))
bwd_range = max(fused_attn_bwd_fp8.max().item(), unfused_attn_bwd_f16.max().item()) - min(
fused_attn_bwd_fp8.min().item(), unfused_attn_bwd_f16.min().item()
)
logging.debug(
"fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format(
fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()
)
)
logging.debug(
"unfused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
unfused_attn_fwd_f16.min().item(), unfused_attn_fwd_f16.max().item()
)
)
logging.debug("fused_attn_fwd_fp8 vs unfused_attn_fwd_f16 RMSE: {:.6f}".format(fwd_rmse))
try:
torch.testing.assert_close(fused_attn_fwd_fp8, unfused_attn_fwd_f16, **tols)
except Exception as e:
logging.debug(e)
logging.debug('fused_attn_bwd_fp8 min {:.6f} max {:.6f}'.format(
fused_attn_bwd_fp8.min().item(), fused_attn_bwd_fp8.max().item()))
logging.debug('unfused_attn_bwd_f16 min {:.6f} max {:.6f}'.format(
unfused_attn_bwd_f16.min().item(), unfused_attn_bwd_f16.max().item()))
logging.debug('fused_attn_bwd_fp8 vs unfused_attn_bwd_f16 RMSE: {:.6f}'.format(
bwd_rmse))
logging.debug(
"fused_attn_bwd_fp8 min {:.6f} max {:.6f}".format(
fused_attn_bwd_fp8.min().item(), fused_attn_bwd_fp8.max().item()
)
)
logging.debug(
"unfused_attn_bwd_f16 min {:.6f} max {:.6f}".format(
unfused_attn_bwd_f16.min().item(), unfused_attn_bwd_f16.max().item()
)
)
logging.debug("fused_attn_bwd_fp8 vs unfused_attn_bwd_f16 RMSE: {:.6f}".format(bwd_rmse))
try:
torch.testing.assert_close(fused_attn_bwd_fp8, unfused_attn_bwd_f16, **tols)
except Exception as e:
logging.debug(e)
assert(fwd_rmse < rmse_tol * fwd_range
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range)
assert(bwd_rmse < rmse_tol * bwd_range
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range)
assert (
fwd_rmse < rmse_tol * fwd_range
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
)
assert (
bwd_rmse < rmse_tol * bwd_range
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
)
def _run_custom_mha_fp8(dtype, config, backend):
......@@ -1517,18 +1757,25 @@ def _run_custom_mha_fp8(dtype, config, backend):
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
inp = 0.0001 * torch.randint(-100, 100,
(config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim),
dtype=dtype, device="cuda", requires_grad=True)
seqlens = torch.full([config.batch_size], config.max_seqlen_q,
dtype=torch.int32, device="cuda")
inp = 0.0001 * torch.randint(
-100,
100,
(config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim),
dtype=dtype,
device="cuda",
requires_grad=True,
)
seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
out_grad = 0.01 * torch.randn(
config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim,
dtype=dtype, device="cuda")
torch.save(out_grad, 'out_grad.pt')
config.batch_size * config.max_seqlen_q,
config.num_heads * config.head_dim,
dtype=dtype,
device="cuda",
)
torch.save(out_grad, "out_grad.pt")
fp8_recipe = recipe.DelayedScaling(
margin=0,
......@@ -1543,10 +1790,13 @@ def _run_custom_mha_fp8(dtype, config, backend):
out.backward(out_grad)
out = torch.load("out.pt")
dqkv = torch.load('dqkv.pt')
return (out.view(config.batch_size, config.max_seqlen_q, -1),
dqkv.view(config.batch_size, config.max_seqlen_q, 3,
config.num_heads, config.head_dim).contiguous())
dqkv = torch.load("dqkv.pt")
return (
out.view(config.batch_size, config.max_seqlen_q, -1),
dqkv.view(
config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim
).contiguous(),
)
def _run_ref_mha_f16(dtype, config, backend):
......@@ -1560,13 +1810,14 @@ def _run_ref_mha_f16(dtype, config, backend):
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
inp = torch.load('qkv.pt').to(device="cuda")
inp = torch.load("qkv.pt").to(device="cuda")
inp.requires_grad = True
seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
out_grad = torch.load('out_grad.pt').to(device="cuda").view(
config.batch_size, config.max_seqlen_q, -1)
out_grad = (
torch.load("out_grad.pt").to(device="cuda").view(config.batch_size, config.max_seqlen_q, -1)
)
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
......@@ -1582,24 +1833,22 @@ def _run_ref_mha_f16(dtype, config, backend):
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
block = (
DotProductAttention(
config.num_heads,
config.head_dim,
attention_dropout=config.dropout_p,
sequence_parallel=False,
tp_size=1,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
tp_group=None,
layer_number=1,
attention_type="self",
qkv_format="bshd",
).to(dtype=dtype, device="cuda")
)
q = inp[:,:,0,:,:]
k = inp[:,:,1,:,:]
v = inp[:,:,2,:,:]
block = DotProductAttention(
config.num_heads,
config.head_dim,
attention_dropout=config.dropout_p,
sequence_parallel=False,
tp_size=1,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
tp_group=None,
layer_number=1,
attention_type="self",
qkv_format="bshd",
).to(dtype=dtype, device="cuda")
q = inp[:, :, 0, :, :]
k = inp[:, :, 1, :, :]
v = inp[:, :, 2, :, :]
out = block(q, k, v, attn_mask_type=config.attn_mask_type)
out.backward(out_grad)
......@@ -1611,12 +1860,12 @@ _2X_ACC_FPROP = False
_2X_ACC_DGRAD = False
_2X_ACC_WGRAD = False
META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3
META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3
class _custom_mha_fp8(torch.autograd.Function):
......@@ -1682,46 +1931,57 @@ class _custom_mha_fp8(torch.autograd.Function):
D_dtype=fp8_dtype_forward,
)
qkv = qkv.view(-1, 3, h, d)
qkv_fp16 = ext.cast_from_fp8(qkv, fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward,
tex.DType.kFloat16).view(b, max_s, 3, h, d).contiguous()
torch.save(qkv_fp16, 'qkv.pt')
qkv_fp16 = (
ext.cast_from_fp8(
qkv, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward, tex.DType.kFloat16
)
.view(b, max_s, 3, h, d)
.contiguous()
)
torch.save(qkv_fp16, "qkv.pt")
if cudnn_frontend_version == 1:
qkv = qkv.view(b, max_s, 3, h, d) # bs3hd
qkv = qkv.view(b, max_s, 3, h, d) # bs3hd
# FMHA
out, aux_ctx_tensors, *rest = fused_attn_fwd(
is_training,
max_s,
max_s,
cu_seqlens,
cu_seqlens,
qkv[:,:,0,:,:] if cudnn_frontend_version == 1 else qkv[:,0,:,:],
qkv[:,:,1,:,:] if cudnn_frontend_version == 1 else qkv[:,1,:,:],
qkv[:,:,2,:,:] if cudnn_frontend_version == 1 else qkv[:,2,:,:],
fp8_dtype_forward,
FusedAttnBackend["FP8"],
None, None, None, None, None,
fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S],
fp8_meta["scaling_fwd"].scale[META_O],
fp8_meta["scaling_fwd"].amax_history[0][META_S],
fp8_meta["scaling_fwd"].amax_history[0][META_O],
attn_scale=None,
dropout=p_dropout,
fast_zero_fill=fast_zero_fill,
qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
attn_bias_type="no_bias",
attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding",
rng_gen=None,
)
is_training,
max_s,
max_s,
cu_seqlens,
cu_seqlens,
qkv[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv[:, 0, :, :],
qkv[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv[:, 1, :, :],
qkv[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv[:, 2, :, :],
fp8_dtype_forward,
FusedAttnBackend["FP8"],
None,
None,
None,
None,
None,
fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S],
fp8_meta["scaling_fwd"].scale[META_O],
fp8_meta["scaling_fwd"].amax_history[0][META_S],
fp8_meta["scaling_fwd"].amax_history[0][META_O],
attn_scale=None,
dropout=p_dropout,
fast_zero_fill=fast_zero_fill,
qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
attn_bias_type="no_bias",
attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding",
rng_gen=None,
)
M, ZInv, philox_unpacked = aux_ctx_tensors
ctx.save_for_backward(
inp_t_fp8, qkv_weight_t_fp8, workspace,
qkv, out,
inp_t_fp8,
qkv_weight_t_fp8,
workspace,
qkv,
out,
fp8_meta["scaling_fwd"].scale,
fp8_meta["scaling_fwd"].scale_inv,
)
......@@ -1736,82 +1996,84 @@ class _custom_mha_fp8(torch.autograd.Function):
ctx.mask_type = mask_type
ctx.dtype = inp.dtype
out = out.view(-1, in_features) # (bs)(hd)
out_fp16 = ext.cast_from_fp8(out, fp8_meta["scaling_fwd"],
META_O, fp8_dtype_forward, tex.DType.kFloat16)
torch.save(out_fp16, 'out.pt') # (bs)(hd)
out = out.view(-1, in_features) # (bs)(hd)
out_fp16 = ext.cast_from_fp8(
out, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward, tex.DType.kFloat16
)
torch.save(out_fp16, "out.pt") # (bs)(hd)
return out_fp16
@staticmethod
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
with torch.cuda.nvtx.range("_DPA"):
(
inp_t_fp8,
qkv_weight_t_fp8,
workspace,
qkv, out,
qkv,
out,
fwd_scales,
fwd_scale_inverses,
) = ctx.saved_tensors
fp8_dtype_forward = fp8.get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True
)
fp8_dtype_backward = fp8.get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
)
fp8_dtype_forward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
proj_dgrad = ext.cast_to_fp8(
grad_output, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
) # (bs)(hd)
) # (bs)(hd)
dq, dk, dv, *rest = fused_attn_bwd(
ctx.max_s,
ctx.max_s,
ctx.cu_seqlens,
ctx.cu_seqlens,
qkv[:,:,0,:,:] if cudnn_frontend_version == 1 else qkv[:,0,:,:],
qkv[:,:,1,:,:] if cudnn_frontend_version == 1 else qkv[:,1,:,:],
qkv[:,:,2,:,:] if cudnn_frontend_version == 1 else qkv[:,2,:,:],
out,
proj_dgrad.view_as(out),
fp8_dtype_forward,
fp8_dtype_backward,
ctx.aux_ctx_tensors,
FusedAttnBackend["FP8"],
None, None, None, None,
fwd_scale_inverses[META_QKV], # d_scale_qkv,
fwd_scale_inverses[META_S], # d_scale_s,
fwd_scale_inverses[META_O], # d_scale_o,
ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do
ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp
fwd_scales[META_S], # q_scale_s
ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp
ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv
ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp
ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv
attn_scale=None,
dropout=ctx.p_dropout,
fast_zero_fill=ctx.fast_zero_fill,
qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
attn_bias_type="no_bias",
attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding",
)
ctx.max_s,
ctx.max_s,
ctx.cu_seqlens,
ctx.cu_seqlens,
qkv[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv[:, 0, :, :],
qkv[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv[:, 1, :, :],
qkv[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv[:, 2, :, :],
out,
proj_dgrad.view_as(out),
fp8_dtype_forward,
fp8_dtype_backward,
ctx.aux_ctx_tensors,
FusedAttnBackend["FP8"],
None,
None,
None,
None,
fwd_scale_inverses[META_QKV], # d_scale_qkv,
fwd_scale_inverses[META_S], # d_scale_s,
fwd_scale_inverses[META_O], # d_scale_o,
ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO], # d_scale_do
ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP], # d_scale_dp
fwd_scales[META_S], # q_scale_s
ctx.fp8_meta["scaling_bwd"].scale[META_DP], # q_scale_dp
ctx.fp8_meta["scaling_bwd"].scale[META_DQKV], # q_scale_dqkv
ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP], # amax_dp
ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV], # amax_dqkv
attn_scale=None,
dropout=ctx.p_dropout,
fast_zero_fill=ctx.fast_zero_fill,
qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
attn_bias_type="no_bias",
attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding",
)
dim = 2 if cudnn_frontend_version == 1 else 1
dqkv = torch.Tensor().to(device=dq.device, dtype=dq.dtype)
dqkv_shape = list(dq.shape)
dqkv_shape.insert(dim, 3)
dqkv_stride = list(dq.stride())
dqkv_stride.insert(dim, int(dqkv_stride[-3]/3))
dqkv.set_(dq.untyped_storage(), dq.storage_offset(), dqkv_shape, dqkv_stride) # bs3hd
dqkv_stride.insert(dim, int(dqkv_stride[-3] / 3))
dqkv.set_(dq.untyped_storage(), dq.storage_offset(), dqkv_shape, dqkv_stride) # bs3hd
dqkv_c = dqkv.view(-1, 3*ctx.hidden_size)
dqkv_c_fp16 = ext.cast_from_fp8(dqkv_c,
ctx.fp8_meta["scaling_bwd"], META_DQKV,
fp8_dtype_backward, tex.DType.kFloat16)
torch.save(dqkv_c_fp16, 'dqkv.pt')
dqkv_c = dqkv.view(-1, 3 * ctx.hidden_size)
dqkv_c_fp16 = ext.cast_from_fp8(
dqkv_c,
ctx.fp8_meta["scaling_bwd"],
META_DQKV,
fp8_dtype_backward,
tex.DType.kFloat16,
)
torch.save(dqkv_c_fp16, "dqkv.pt")
qkv_bgrad, dqkv_t = ext.fp8_transpose_bgrad_fused(
dqkv_c,
......@@ -1850,7 +2112,8 @@ class _custom_mha_fp8(torch.autograd.Function):
use_split_accumulator=_2X_ACC_WGRAD,
)
return (qkv_dgrad,
return (
qkv_dgrad,
qkv_wgrad,
qkv_bgrad,
None,
......@@ -1862,14 +2125,12 @@ class _custom_mha_fp8(torch.autograd.Function):
None,
None,
None,
None)
None,
)
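# Sketch of the stride trick used in backward above (assumed small sizes): dq, dk and
# dv returned by the fused kernel are views into one interleaved (b, s, 3, h, d)
# buffer, so inserting a size-3 dimension with stride h*d into dq's shape and stride
# re-exposes the whole gradient buffer as a single dqkv tensor without a copy.
_b, _s, _h, _d = 2, 4, 5, 8
_buf = torch.randn(_b, _s, 3, _h, _d)
_dq = _buf[:, :, 0]                                   # shape (b, s, h, d)
_shape = list(_dq.shape)
_shape.insert(2, 3)
_stride = list(_dq.stride())
_stride.insert(2, _dq.stride()[-3] // 3)              # = h * d
_dqkv = torch.as_strided(_dq, _shape, _stride, _dq.storage_offset())
assert torch.equal(_dqkv, _buf)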
class Custom_MHA_FP8(TransformerEngineBaseModule):
def __init__(
self,
config,
params_dtype: torch.dtype = torch.float32):
def __init__(self, config, params_dtype: torch.dtype = torch.float32):
super().__init__()
self.p_dropout = config.dropout_p
self.h = config.num_heads
......@@ -1901,8 +2162,10 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
)
def forward(
self, inp: torch.Tensor,
cu_seqlens, max_s,
self,
inp: torch.Tensor,
cu_seqlens,
max_s,
) -> torch.Tensor:
with self.prepare_forward(inp, None, num_gemms=3) as inp:
out = _custom_mha_fp8.apply(
......@@ -1917,5 +2180,6 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
self.fp8_meta,
self.workspace,
self.training,
self.mask_type)
self.mask_type,
)
return out
......@@ -14,12 +14,13 @@ from transformer_engine.pytorch.utils import get_device_compute_capability
model_configs_flash_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
}
def get_bash_arguments(**kwargs):
args = ["python", "-m", "torch.distributed.launch", "--nproc-per-node=2"]
te_path = os.getenv("TE_PATH", "/opt/transformerengine")
......@@ -29,46 +30,43 @@ def get_bash_arguments(**kwargs):
args.append(f"{k}={v}")
return args
@pytest.mark.skipif(not _is_flash_attention_2_available(), reason="Flash-attn 2.0+ is required.")
@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ['bf16', 'fp16'])
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd', 'thd'])
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
def test_cp_with_flash_attention(dtype, model, qkv_format):
subprocess.run(
get_bash_arguments(
dtype=dtype,
model=model,
qkv_format=qkv_format,
kernel_backend='FlashAttention'
dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FlashAttention"
),
check=True
check=True,
)
model_configs_fused_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
"cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA
"cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
"cp_2_3": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
"cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA
"cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
"cp_2_3": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
}
@pytest.mark.skipif(_cudnn_version() < (8,9,7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.skipif(_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ['bf16', 'fp16'])
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd', 'thd'])
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
def test_cp_with_fused_attention(dtype, model, qkv_format):
subprocess.run(
get_bash_arguments(
dtype=dtype,
model=model,
qkv_format=qkv_format,
kernel_backend='FusedAttention'
dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FusedAttention"
),
check=True
check=True,
)
......@@ -8,8 +8,15 @@ import pytest
import torch
from transformer_engine.pytorch import (
DotProductAttention, LayerNormLinear, LayerNormMLP, Linear, make_graphed_callables,
MultiheadAttention, TransformerLayer, fp8_autocast, fp8_model_init,
DotProductAttention,
LayerNormLinear,
LayerNormMLP,
Linear,
make_graphed_callables,
MultiheadAttention,
TransformerLayer,
fp8_autocast,
fp8_model_init,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
......@@ -26,15 +33,18 @@ torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
@dataclass
class ModelConfig:
"""Data tensor dimensions within Transformer model"""
sequence_length: int
batch_size: int
hidden_size: int
num_heads: int
kv_channels: int
model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}
modules = ["transformer", "layernorm_mlp", "layernorm_linear", "linear", "mha", "dpa"]
......@@ -66,7 +76,9 @@ def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None)
for i, (t1, t2) in enumerate(zip(l1, l2)):
if not torch.equal(t1, t2):
failed = True
failed_tensors += f" {names[i]}\n" if names is not None else f" tensor at idx={i}\n"
failed_tensors += (
f" {names[i]}\n" if names is not None else f" tensor at idx={i}\n"
)
assert not failed, "Output mismatches in:\n" + failed_tensors
......@@ -157,41 +169,51 @@ def _test_cuda_graphs(
with fp8_model_init(enabled=fp8_params):
# Create modules.
if module == "transformer":
modules = [TransformerLayer(
config.hidden_size,
config.hidden_size,
config.num_heads,
hidden_dropout=0.0,
attention_dropout=0.0,
fuse_qkv_params=True,
params_dtype=dtype,
) for _ in range(num_layers)]
modules = [
TransformerLayer(
config.hidden_size,
config.hidden_size,
config.num_heads,
hidden_dropout=0.0,
attention_dropout=0.0,
fuse_qkv_params=True,
params_dtype=dtype,
)
for _ in range(num_layers)
]
elif module == "layernorm_mlp":
modules = [LayerNormMLP(
config.hidden_size, config.hidden_size, params_dtype=dtype
) for _ in range(num_layers)]
modules = [
LayerNormMLP(config.hidden_size, config.hidden_size, params_dtype=dtype)
for _ in range(num_layers)
]
elif module == "layernorm_linear":
modules = [LayerNormLinear(
config.hidden_size, config.hidden_size, params_dtype=dtype
) for _ in range(num_layers)]
modules = [
LayerNormLinear(config.hidden_size, config.hidden_size, params_dtype=dtype)
for _ in range(num_layers)
]
elif module == "mha":
modules = [MultiheadAttention(
config.hidden_size,
config.num_heads,
attention_dropout=0.0,
params_dtype=dtype,
fuse_qkv_params=True,
) for _ in range(num_layers)]
modules = [
MultiheadAttention(
config.hidden_size,
config.num_heads,
attention_dropout=0.0,
params_dtype=dtype,
fuse_qkv_params=True,
)
for _ in range(num_layers)
]
elif dpa:
assert config.hidden_size % config.num_heads == 0, "hidden_size must be divisible by num_heads."
assert num_layers == 1, "DPA test expects a single layer."
modules = [DotProductAttention(
config.num_heads, config.kv_channels, attention_dropout=0.0
) for _ in range(num_layers)]
modules = [
DotProductAttention(config.num_heads, config.kv_channels, attention_dropout=0.0)
for _ in range(num_layers)
]
else:
modules = [Linear(
config.hidden_size, config.hidden_size, device="cuda", params_dtype=dtype
) for _ in range(num_layers)]
modules = [
Linear(config.hidden_size, config.hidden_size, device="cuda", params_dtype=dtype)
for _ in range(num_layers)
]
# Initialize gradient buffers.
for module in modules:
......@@ -238,7 +260,7 @@ def _test_cuda_graphs(
with fp8_autocast(enabled=fp8):
kwargs = {}
if fp8_weight_caching:
kwargs["is_first_microbatch"] = (grad_accumulation_step == 0)
kwargs["is_first_microbatch"] = grad_accumulation_step == 0
output = model(*inputs, **kwargs)
output.backward(grad_output)
if not dpa:
......
......@@ -27,33 +27,31 @@ num_heads = 16
head_dim = 64
dtype = torch.bfloat16
class TestDeferredInit:
@staticmethod
def get_module_args(module):
hidden_size = num_heads * head_dim
args = (hidden_size,)
kwargs = {
'params_dtype': dtype,
'device': 'meta'
}
kwargs = {"params_dtype": dtype, "device": "meta"}
if module in [te.Linear, te.LayerNormLinear, te.LayerNormMLP]:
ffn_hidden_size = 2 * hidden_size
args += (ffn_hidden_size, )
kwargs['bias'] = True
args += (ffn_hidden_size,)
kwargs["bias"] = True
if module == te.LayerNormMLP:
kwargs['seq_length'] = seq_length
kwargs["seq_length"] = seq_length
elif module == te.MultiheadAttention:
args += (num_heads, )
kwargs['fuse_qkv_params'] = True
args += (num_heads,)
kwargs["fuse_qkv_params"] = True
elif module == te.TransformerLayer:
args += (3 * hidden_size, num_heads)
kwargs['fuse_qkv_params'] = True
kwargs['seq_length'] = seq_length
kwargs["fuse_qkv_params"] = True
kwargs["seq_length"] = seq_length
return args, kwargs
@pytest.mark.parametrize("module_type", _core_modules+_composed_modules)
@pytest.mark.parametrize("module_type", _core_modules + _composed_modules)
def test_zero_memory_init(
self,
module_type: torch.nn.Module,
......
......@@ -26,6 +26,7 @@ _tols: Dict[tex.DType, Dict[str, float]] = {
tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125), # epsilon = 0.125
}
def _to_list(x: Union[Iterable, Any]) -> List:
"""Convert to list if iterable, otherwise put in singleton list"""
if isinstance(x, Iterable):
......@@ -33,12 +34,14 @@ def _to_list(x: Union[Iterable, Any]) -> List:
else:
return [x]
# Types that can be interpreted as tensor dims
DimsType = Union[Iterable[int], int]
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFloat8Tensor:
......@@ -108,7 +111,7 @@ class TestFloat8Tensor:
def test_quantize_dequantize_scales(self, scale: float) -> None:
self._test_quantize_dequantize(scale=scale)
@pytest.mark.parametrize("dims", [[], 1, 311, [7,11], [7,5,3], [2,3,5,3]])
@pytest.mark.parametrize("dims", [[], 1, 311, [7, 11], [7, 5, 3], [2, 3, 5, 3]])
def test_quantize_dequantize_dims(self, dims: DimsType) -> None:
self._test_quantize_dequantize(dims=dims)
......@@ -310,7 +313,7 @@ class TestFloat8Tensor:
def test_serialization(
self,
dims: DimsType = [2,3,5],
dims: DimsType = [2, 3, 5],
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 0.5,
dtype: torch.dtype = torch.float32,
......
......@@ -117,9 +117,7 @@ class TestFusedAdam(TestFusedOptimizer):
tensors = []
for size in sizes:
tensors.append(torch.rand(size, dtype=torch.float, device="cuda"))
ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
tensors, self.options
)
ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(tensors, self.options)
for i in range(self.iters):
self.gen_grad(ref_param, tst_param)
......@@ -139,9 +137,7 @@ class TestFusedAdam(TestFusedOptimizer):
}
tensor = torch.rand(nelem, dtype=torch.float, device="cuda")
ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
[tensor], adam_option
)
ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim([tensor], adam_option)
for i in range(self.iters):
self.gen_grad(ref_param, tst_param)
......@@ -161,9 +157,7 @@ class TestFusedAdam(TestFusedOptimizer):
}
tensor = torch.rand(nelem, dtype=torch.float, device="cuda")
ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
[tensor], adam_option
)
ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim([tensor], adam_option)
# Add an empty param group which may occur for pipeline parallel p-tuning
tst_optim.add_param_group({"params": []})
......@@ -175,10 +169,11 @@ class TestFusedAdam(TestFusedOptimizer):
torch.testing.assert_close(ref_param, tst_param)
class TestFusedSGD(TestFusedOptimizer):
def __init__(self, *args, **kwargs):
super(TestFusedSGD, self).__init__(*args, **kwargs)
self.options = {"lr": .25, "momentum": .125}
self.options = {"lr": 0.25, "momentum": 0.125}
self.ref_optim = torch.optim.SGD
self.fused_optim = te.optimizers.FusedSGD
......@@ -188,7 +183,7 @@ class TestFusedSGD(TestFusedOptimizer):
def test_half(self):
self.gen_single_type_test(param_type=torch.float16)
@unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required")
@unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required")
def test_multi_device(self):
devices = ("cuda:0", "cuda:1")
for current_dev, tensor_dev in product(devices, devices):
......@@ -452,8 +447,8 @@ class AdamTest(unittest.TestCase):
@largeTensorTest("60GB", "cuda")
def testLargeTensor(self):
t = torch.zeros(2359332864, dtype=torch.half, device='cuda')
t2 = torch.zeros(2359332864, dtype=torch.half, device='cuda')
t = torch.zeros(2359332864, dtype=torch.half, device="cuda")
t2 = torch.zeros(2359332864, dtype=torch.half, device="cuda")
grad = torch.randn_like(t)
t.grad = grad
t2.grad = grad
......
......@@ -26,10 +26,7 @@ def apply_rotary_pos_emb_thd(
"""
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return torch.cat(
[
apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)])
for x in torch.split(t, seqlens)
]
[apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)]) for x in torch.split(t, seqlens)]
).squeeze(1)
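# Small sketch of the cu_seqlens bookkeeping above (illustrative values): cumulative
# offsets are turned into per-sequence lengths, which torch.split then uses to cut the
# packed THD tensor back into one chunk per sequence.
_cu_seqlens_example = torch.tensor([0, 3, 7, 12], dtype=torch.int32)
_seqlens_example = (_cu_seqlens_example[1:] - _cu_seqlens_example[:-1]).tolist()
assert _seqlens_example == [3, 4, 5]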
......@@ -45,6 +42,7 @@ def get_tol(dtype: torch.dtype) -> Dict:
def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
return output.sum() * 2
# Gradient is a full tensor
def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
t = torch.ones_like(output)
......@@ -86,9 +84,7 @@ def test_fused_rope(
emb = rotary_pos_emb(seq_length)
# unfused
output_unfused = apply_rotary_pos_emb(
t, emb, tensor_format=tensor_format, fused=False
)
output_unfused = apply_rotary_pos_emb(t, emb, tensor_format=tensor_format, fused=False)
loss_unfused = loss_func(output_unfused)
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
......
......@@ -13,23 +13,16 @@ num_heads = 16
head_dim = 64
dtype = torch.bfloat16
num_attn_head = 16
ffn_hidden_size=1024
ffn_hidden_size = 1024
@pytest.mark.parametrize("kv_channels", [128, 256])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("num_gqa_groups", [1, 2, 4, 8, 16])
def test_gqa(
kv_channels,
hidden_size,
num_gqa_groups
) -> None:
def test_gqa(kv_channels, hidden_size, num_gqa_groups) -> None:
model = te.TransformerLayer(
hidden_size,
ffn_hidden_size,
num_attn_head,
num_gqa_groups,
kv_channels=kv_channels
hidden_size, ffn_hidden_size, num_attn_head, num_gqa_groups, kv_channels=kv_channels
)
# Run forward pass
......@@ -42,10 +35,9 @@ def test_gqa(
assert model.self_attention.layernorm_qkv.query_weight.shape[0] == kv_channels * num_attn_head
assert model.self_attention.layernorm_qkv.query_weight.shape[1] == hidden_size
assert model.self_attention.layernorm_qkv.value_weight.shape[0] == kv_channels * num_gqa_groups
assert model.self_attention.layernorm_qkv.value_weight.shape[1] == hidden_size
assert model.self_attention.proj.weight.shape[0] == hidden_size
assert model.self_attention.proj.weight.shape[1] == kv_channels * num_attn_head
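# A minimal grouped-query-attention shape sketch (plain PyTorch, assumptions only) behind
# the weight-shape asserts above: the query projection produces num_heads * head_dim
# features while the key/value projections produce only num_gqa_groups * head_dim features,
# and each KV head is shared by num_heads // num_gqa_groups query heads.
import torch

num_heads, num_groups, head_dim, hidden = 16, 4, 64, 256
q_proj = torch.nn.Linear(hidden, num_heads * head_dim)
kv_proj = torch.nn.Linear(hidden, 2 * num_groups * head_dim)
x = torch.randn(2, 8, hidden)  # (batch, seq, hidden)
q = q_proj(x).view(2, 8, num_heads, head_dim)
k, v = kv_proj(x).view(2, 8, 2, num_groups, head_dim).unbind(dim=2)
k = k.repeat_interleave(num_heads // num_groups, dim=2)  # expand KV heads to match Q heads
assert q.shape == k.shape == (2, 8, num_heads, head_dim)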
......@@ -11,11 +11,11 @@ import transformer_engine.pytorch as te
# Model names for test_torch_dynamo
_model_factory = {
"Linear": [(lambda: te.Linear(16, 16)), [16, 16]],
"LayerNorm": [(lambda: te.LayerNorm(16)), [16, 16]],
"LayerNormLinear": [(lambda: te.LayerNormLinear(16, 16)), [16, 16]],
"LayerNormMLP": [(lambda: te.LayerNormMLP(16, 16)), [16, 16]],
"TransformerLayer": [(lambda: te.TransformerLayer(128, 128, 2)), [4, 1, 128]],
"Linear": [(lambda: te.Linear(16, 16)), [16, 16]],
"LayerNorm": [(lambda: te.LayerNorm(16)), [16, 16]],
"LayerNormLinear": [(lambda: te.LayerNormLinear(16, 16)), [16, 16]],
"LayerNormMLP": [(lambda: te.LayerNormMLP(16, 16)), [16, 16]],
"TransformerLayer": [(lambda: te.TransformerLayer(128, 128, 2)), [4, 1, 128]],
}
......@@ -31,11 +31,11 @@ def test_torch_dynamo(model_name: str):
# Helper function to construct tensor with default options
def make_tensor(
dims: Tuple[int],
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
requires_grad: bool = True,
**kwargs,
dims: Tuple[int],
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
requires_grad: bool = True,
**kwargs,
):
return torch.zeros(
dims,
......
......@@ -28,9 +28,7 @@ appliers = [MultiTensorApply(2048 * 32), MultiTensorApply(333), MultiTensorApply
@pytest.mark.parametrize("in_type", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("out_type", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("inplace", [False, True])
def test_multi_tensor_scale(
input_size_pair, applier, repeat, in_type, out_type, inplace
):
def test_multi_tensor_scale(input_size_pair, applier, repeat, in_type, out_type, inplace):
if inplace is True and (out_type is not in_type):
pytest.skip("inplace=True and out_type != in_type is not supported.")
elif (in_type == torch.float16 and out_type == torch.bfloat16) or (
......@@ -154,9 +152,7 @@ def test_multi_tensor_l2norm(input_size_pair, applier, repeat, in_type, per_tens
in_list += [a.clone().to(in_type), b.clone().to(in_type)]
if per_tensor:
norm, norm_per_tensor = applier(
tex.multi_tensor_l2norm, overflow_buf, [in_list], True
)
norm, norm_per_tensor = applier(tex.multi_tensor_l2norm, overflow_buf, [in_list], True)
normab = torch.cat((a.norm().view(1), b.norm().view(1)))
norm_per_tensor = norm_per_tensor.view(-1, 2)
else:
......@@ -168,9 +164,7 @@ def test_multi_tensor_l2norm(input_size_pair, applier, repeat, in_type, per_tens
torch.testing.assert_close(norm, reference.broadcast_to(norm.shape))
if per_tensor:
torch.testing.assert_close(
norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape)
)
torch.testing.assert_close(norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape))
assert overflow_buf.item() == 0
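# A minimal reference sketch (no TE kernels, assumptions only) of the identity the fused
# multi-tensor L2-norm test relies on: the global norm over a list of tensors equals the
# L2 norm of the vector of per-tensor norms.
import torch

tensors = [torch.full((4096,), 4.0), torch.full((17,), 4.0)]
per_tensor = torch.stack([t.norm() for t in tensors])
global_norm = torch.cat([t.flatten() for t in tensors]).norm()
torch.testing.assert_close(global_norm, per_tensor.norm())  # sqrt of summed squared norms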
......@@ -179,9 +173,7 @@ def test_multi_tensor_l2norm(input_size_pair, applier, repeat, in_type, per_tens
@pytest.mark.parametrize("repeat", [1, 55])
@pytest.mark.parametrize("in_type", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("per_tensor", [False, True])
def test_multi_tensor_unscale_l2norm(
input_size_pair, applier, repeat, in_type, per_tensor
):
def test_multi_tensor_unscale_l2norm(input_size_pair, applier, repeat, in_type, per_tensor):
sizea, sizeb = input_size_pair
device = torch.device("cuda")
val = 4.0
......@@ -205,9 +197,7 @@ def test_multi_tensor_unscale_l2norm(
inv_scale_cuda,
True,
)
normab = torch.cat(
((a * inv_scale).norm().view(1), (b * inv_scale).norm().view(1))
)
normab = torch.cat(((a * inv_scale).norm().view(1), (b * inv_scale).norm().view(1)))
norm_per_tensor = norm_per_tensor.view(-1, 2)
else:
norm, _ = applier(
......@@ -224,7 +214,5 @@ def test_multi_tensor_unscale_l2norm(
torch.testing.assert_close(norm, reference.broadcast_to(norm.shape))
if per_tensor:
torch.testing.assert_close(
norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape)
)
torch.testing.assert_close(norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape))
assert overflow_buf.item() == 0
......@@ -20,8 +20,15 @@ from transformer_engine.pytorch.utils import (
is_bf16_compatible,
)
from transformer_engine.pytorch import (
DotProductAttention, LayerNormLinear, LayerNormMLP, Linear,
MultiheadAttention, RMSNorm, TransformerLayer, LayerNorm, InferenceParams
DotProductAttention,
LayerNormLinear,
LayerNormMLP,
Linear,
MultiheadAttention,
RMSNorm,
TransformerLayer,
LayerNorm,
InferenceParams,
)
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
......@@ -106,10 +113,11 @@ def assert_allclose(
if not result:
diff = torch.abs(t1 - t2).flatten()
m = torch.argmax(diff)
msg = (f"Outputs not close enough in tensor at idx={i}. "
f"Location of the maximum difference: {m.item()} "
f"with {t1.flatten()[m].item()} vs {t2.flatten()[m].item()} "
f"(diff {diff[m].item()})."
msg = (
f"Outputs not close enough in tensor at idx={i}. "
f"Location of the maximum difference: {m.item()} "
f"with {t1.flatten()[m].item()} vs {t2.flatten()[m].item()} "
f"(diff {diff[m].item()})."
)
raise AssertionError(msg)
......@@ -175,9 +183,7 @@ class TorchDotProductAttention(torch.nn.Module):
)
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.reshape(
output_size[2], output_size[0] * output_size[1], -1
)
query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)
......@@ -216,14 +222,10 @@ class TorchDotProductAttention(torch.nn.Module):
)
# change view [sk, b * np, hn]
value_layer = value_layer.reshape(
value_layer.size(0), output_size[0] * output_size[1], -1
)
value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(
output_size[0] * output_size[1], output_size[2], -1
)
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
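# A minimal shape sketch (assumptions only) of the batched matmul commented above:
# probabilities of shape (b * np, sq, sk) against values of shape (b * np, sk, hn)
# produce a context of shape (b * np, sq, hn).
import torch

b, np_, sq, sk, hn = 2, 4, 5, 7, 16
probs = torch.softmax(torch.randn(b * np_, sq, sk), dim=-1)
values = torch.randn(b * np_, sk, hn)
context = torch.bmm(probs, values)
assert context.shape == (b * np_, sq, hn)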
......@@ -241,9 +243,7 @@ class TorchDotProductAttention(torch.nn.Module):
class TorchLayerNorm(nn.Module):
def __init__(self, in_features: int,
eps: float,
zero_centered_gamma: bool):
def __init__(self, in_features: int, eps: float, zero_centered_gamma: bool):
super().__init__()
self.eps = eps
self.in_features = in_features
......@@ -260,10 +260,12 @@ class TorchLayerNorm(nn.Module):
w = w.to(torch.float32)
b = self.bias.to(torch.float32)
inp = x.to(torch.float32)
out = torch.nn.functional.layer_norm(inp, (self.in_features,), weight=w,
bias=b, eps=self.eps)
out = torch.nn.functional.layer_norm(
inp, (self.in_features,), weight=w, bias=b, eps=self.eps
)
return out.to(x.dtype)
# Adapted from https://github.com/bzhangGo/rmsnorm/blob/c6691f20ec0af4128c8159c903071f7575404295/rmsnorm_torch.py
class TorchRMSNorm(nn.Module):
def __init__(self, in_features, zero_centered_gamma, eps=1e-5):
......@@ -278,11 +280,11 @@ class TorchRMSNorm(nn.Module):
self.register_parameter("weight", self.weight)
def forward(self, x):
norm_x2 = torch.sum(x.float()**2, dim=-1, keepdim=True)
norm_x2 = torch.sum(x.float() ** 2, dim=-1, keepdim=True)
d_x = self.in_features
rms_x2 = norm_x2 / d_x + self.eps
r_rms_x = rms_x2 ** (-1. / 2)
r_rms_x = rms_x2 ** (-1.0 / 2)
x_normed = x * r_rms_x
w = self.weight.float()
......@@ -292,17 +294,24 @@ class TorchRMSNorm(nn.Module):
class TorchLayerNormLinear(nn.Module):
def __init__(self, in_features: int, out_features: int,
eps: float, bias: bool = True,
normalization: str = "LayerNorm",
zero_centered_gamma: bool = False):
def __init__(
self,
in_features: int,
out_features: int,
eps: float,
bias: bool = True,
normalization: str = "LayerNorm",
zero_centered_gamma: bool = False,
):
super().__init__()
if normalization == "LayerNorm":
self.layernorm = TorchLayerNorm(in_features, eps=eps,
zero_centered_gamma=zero_centered_gamma)
self.layernorm = TorchLayerNorm(
in_features, eps=eps, zero_centered_gamma=zero_centered_gamma
)
elif normalization == "RMSNorm":
self.layernorm = TorchRMSNorm(in_features, eps=eps,
zero_centered_gamma=zero_centered_gamma)
self.layernorm = TorchRMSNorm(
in_features, eps=eps, zero_centered_gamma=zero_centered_gamma
)
else:
raise RuntimeError("Unsupported normalization")
......@@ -329,21 +338,26 @@ class TorchMHA(nn.Module):
output = output[0]
return output
class TorchQuickGELU(nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return input * torch.sigmoid(1.702 * input)
class TorchSquaredRELU(nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return (input > 0) * input * input
_supported_act = {'geglu' : nn.GELU(approximate="tanh"),
'gelu' : nn.GELU(approximate="tanh"),
'reglu' : nn.ReLU(),
'relu' : nn.ReLU(),
'swiglu' : nn.SiLU(),
'qgelu' : TorchQuickGELU(),
'srelu' : TorchSquaredRELU()}
_supported_act = {
"geglu": nn.GELU(approximate="tanh"),
"gelu": nn.GELU(approximate="tanh"),
"reglu": nn.ReLU(),
"relu": nn.ReLU(),
"swiglu": nn.SiLU(),
"qgelu": TorchQuickGELU(),
"srelu": TorchSquaredRELU(),
}
class TorchGLU(nn.Module):
......@@ -353,26 +367,29 @@ class TorchGLU(nn.Module):
def forward(self, x):
shape = x.size(-1)
a = x[..., :shape // 2]
b = x[..., (shape // 2):]
a = x[..., : shape // 2]
b = x[..., (shape // 2) :]
a = self.act(a)
return a * b
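# A minimal gated-linear-unit sketch (plain PyTorch, not from this PR) of the split above:
# the last dimension is halved, an activation gates the first half, and the two halves are
# multiplied elementwise, so the output feature size is half of the input feature size.
import torch

x = torch.randn(4, 8)
a, b = x[..., :4], x[..., 4:]
out = torch.nn.functional.silu(a) * b  # "swiglu"-style gating
assert out.shape == (4, 4)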
class TorchLayerNormMLP(nn.Module):
def __init__(self, hidden_size: int, ffn_hidden_size: int,
eps: float = 1e-5, activation = 'gelu',
normalization: str = "LayerNorm"):
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
eps: float = 1e-5,
activation="gelu",
normalization: str = "LayerNorm",
):
super().__init__()
if normalization == "LayerNorm":
self.ln = TorchLayerNorm(hidden_size, eps=eps,
zero_centered_gamma=False)
self.ln = TorchLayerNorm(hidden_size, eps=eps, zero_centered_gamma=False)
elif normalization == "RMSNorm":
self.ln = TorchRMSNorm(hidden_size, eps=eps,
zero_centered_gamma=False)
self.ln = TorchRMSNorm(hidden_size, eps=eps, zero_centered_gamma=False)
else:
raise RuntimeError("Unsupported normalization")
if 'glu' in activation:
if "glu" in activation:
fc1_output_features = 2 * ffn_hidden_size
self.gelu = TorchGLU(activation)
else:
......@@ -387,7 +404,9 @@ class TorchLayerNormMLP(nn.Module):
class TorchGPT(nn.Module):
def __init__(self, hidden_size: int, eps: float, num_attention_heads: int, parallel_attention_mlp: bool):
def __init__(
self, hidden_size: int, eps: float, num_attention_heads: int, parallel_attention_mlp: bool
):
super().__init__()
self.ln = nn.LayerNorm(hidden_size, eps=eps)
self.causal_attn = TorchMHA(hidden_size, num_attention_heads)
......@@ -411,7 +430,6 @@ class TorchGPT(nn.Module):
return x
def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False):
reset_rng_states()
FP8GlobalStateManager.reset()
......@@ -421,23 +439,21 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
with fp8_model_init(enabled=fp8 and fp8_model_params):
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
params_dtype=dtype,
fuse_qkv_params=True,
device="cuda",
)
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
params_dtype=dtype,
fuse_qkv_params=True,
device="cuda",
)
te_inp_hidden_states = torch.randn(
......@@ -477,8 +493,12 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_par
config = model_configs[model]
outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False)
outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True)
outputs = _test_e2e_selective_recompute(
bs, dtype, config, fp8, fp8_model_params, recompute=False
)
outputs_recompute = _test_e2e_selective_recompute(
bs, dtype, config, fp8, fp8_model_params, recompute=True
)
# Check that results match
tols = dtype_tols(dtype)
......@@ -496,10 +516,7 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_par
def _test_e2e_full_recompute(
bs, dtype, config, fp8,
fp8_model_params=False,
recompute=False,
use_reentrant=True
bs, dtype, config, fp8, fp8_model_params=False, recompute=False, use_reentrant=True
):
reset_rng_states()
FP8GlobalStateManager.reset()
......@@ -586,10 +603,12 @@ def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params,
# Non-reentrant checkpoint becomes non-deterministic with bias+GELU fusion
os.environ["NVTE_BIAS_GELU_NVFUSION"] = "0"
outputs, names = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params,
recompute=False, use_reentrant=use_reentrant)
outputs_recompute, _ = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params,
recompute=True, use_reentrant=use_reentrant)
outputs, names = _test_e2e_full_recompute(
bs, dtype, config, fp8, fp8_model_params, recompute=False, use_reentrant=use_reentrant
)
outputs_recompute, _ = _test_e2e_full_recompute(
bs, dtype, config, fp8, fp8_model_params, recompute=True, use_reentrant=use_reentrant
)
if not use_reentrant:
# Reset bias+GELU fusion flag to avoid contaminating other tests
......@@ -753,22 +772,19 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
config = model_configs[model]
te_gpt = (
TransformerLayer(
hidden_size=config.hidden_size,
ffn_hidden_size=4 * config.hidden_size,
num_attention_heads=config.num_attention_heads,
layernorm_epsilon=config.eps,
attention_dropout=0.1,
hidden_dropout=0.1,
params_dtype=dtype,
fuse_qkv_params=True,
qkv_weight_interleaved=False,
parallel_attention_mlp=parallel_attention_mlp,
device="cuda",
)
.eval()
)
te_gpt = TransformerLayer(
hidden_size=config.hidden_size,
ffn_hidden_size=4 * config.hidden_size,
num_attention_heads=config.num_attention_heads,
layernorm_epsilon=config.eps,
attention_dropout=0.1,
hidden_dropout=0.1,
params_dtype=dtype,
fuse_qkv_params=True,
qkv_weight_interleaved=False,
parallel_attention_mlp=parallel_attention_mlp,
device="cuda",
).eval()
torch_gpt = (
TorchGPT(
......@@ -853,18 +869,15 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
def test_mha_accuracy(dtype, bs, model, mask_type):
config = model_configs[model]
te_mha = (
MultiheadAttention(
config.hidden_size,
config.num_attention_heads,
fuse_qkv_params=True,
params_dtype=dtype,
qkv_weight_interleaved=False,
input_layernorm=False,
device="cuda",
)
.eval()
)
te_mha = MultiheadAttention(
config.hidden_size,
config.num_attention_heads,
fuse_qkv_params=True,
params_dtype=dtype,
qkv_weight_interleaved=False,
input_layernorm=False,
device="cuda",
).eval()
torch_mha = (
TorchMHA(
......@@ -919,7 +932,9 @@ def _test_granular_accuracy(block, bs, dtype, config):
def _test_dpa_accuracy(block, bs, dtype, config):
reset_rng_states()
mask = torch.triu(torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda"), diagonal=1)
mask = torch.triu(
torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda"), diagonal=1
)
query, key, value = [
torch.randn(
(config.seq_len, bs, config.num_attention_heads, config.embed),
......@@ -953,7 +968,7 @@ def test_dpa_accuracy(dtype, bs, model):
DotProductAttention(
config.num_attention_heads,
config.embed,
attention_dropout=0.0, # disable dropout, FU uses rng differently
attention_dropout=0.0, # disable dropout, FU uses rng differently
)
.to(dtype=dtype)
.cuda()
......@@ -962,7 +977,7 @@ def test_dpa_accuracy(dtype, bs, model):
torch_dpa = (
TorchDotProductAttention(
config.embed,
0.0, # dropout
0.0, # dropout
)
.to(dtype=dtype)
.cuda()
......@@ -984,27 +999,21 @@ def test_dpa_accuracy(dtype, bs, model):
def test_linear_accuracy(dtype, bs, model):
config = model_configs[model]
te_linear = (
Linear(
config.hidden_size,
4 * config.hidden_size,
bias=True,
params_dtype=dtype,
device="cuda",
)
.eval()
)
te_linear = Linear(
config.hidden_size,
4 * config.hidden_size,
bias=True,
params_dtype=dtype,
device="cuda",
).eval()
torch_linear = (
torch.nn.Linear(
config.hidden_size,
4 * config.hidden_size,
bias=True,
device="cuda",
dtype=dtype,
)
.eval()
)
torch_linear = torch.nn.Linear(
config.hidden_size,
4 * config.hidden_size,
bias=True,
device="cuda",
dtype=dtype,
).eval()
# Share params
with torch.no_grad():
......@@ -1029,23 +1038,16 @@ def test_linear_accuracy(dtype, bs, model):
def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
config = model_configs[model]
te_rmsnorm = (
RMSNorm(
config.hidden_size,
eps=eps,
params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
device="cuda",
)
.eval()
)
te_rmsnorm = RMSNorm(
config.hidden_size,
eps=eps,
params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
device="cuda",
).eval()
torch_rmsnorm = (
TorchRMSNorm(
config.hidden_size,
eps=eps,
zero_centered_gamma=zero_centered_gamma
)
TorchRMSNorm(config.hidden_size, eps=eps, zero_centered_gamma=zero_centered_gamma)
.to(dtype=dtype)
.cuda()
.eval()
......@@ -1059,12 +1061,14 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
torch_outputs = _test_granular_accuracy(torch_rmsnorm, bs, dtype, config)
# Check output.
atol = {torch.float32 : 1e-7,
torch.half : 2e-3,
torch.bfloat16: 2e-2,
atol = {
torch.float32: 1e-7,
torch.half: 2e-3,
torch.bfloat16: 2e-2,
}
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
......@@ -1073,23 +1077,16 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
config = model_configs[model]
te_layernorm = (
LayerNorm(
config.hidden_size,
eps=eps,
params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
device="cuda",
)
.eval()
)
te_layernorm = LayerNorm(
config.hidden_size,
eps=eps,
params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
device="cuda",
).eval()
torch_layernorm = (
TorchLayerNorm(
config.hidden_size,
eps=eps,
zero_centered_gamma=zero_centered_gamma
)
TorchLayerNorm(config.hidden_size, eps=eps, zero_centered_gamma=zero_centered_gamma)
.to(dtype=dtype)
.cuda()
.eval()
......@@ -1104,9 +1101,10 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
torch_outputs = _test_granular_accuracy(torch_layernorm, bs, dtype, config)
# Check output.
atol = {torch.float32 : 1e-7,
torch.half : 2e-3,
torch.bfloat16: 2e-2,
atol = {
torch.float32: 1e-7,
torch.half: 2e-3,
torch.bfloat16: 2e-2,
}
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
......@@ -1119,19 +1117,16 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centered_gamma):
config = model_configs[model]
te_ln_linear = (
LayerNormLinear(
config.hidden_size,
4 * config.hidden_size,
config.eps,
bias=True,
normalization=normalization,
params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
device="cuda",
)
.eval()
)
te_ln_linear = LayerNormLinear(
config.hidden_size,
4 * config.hidden_size,
config.eps,
bias=True,
normalization=normalization,
params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
device="cuda",
).eval()
torch_ln_linear = (
TorchLayerNormLinear(
......@@ -1159,9 +1154,10 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere
torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)
# Check output.
atol = {torch.float32 : 2.5e-4,
torch.half : 2e-3,
torch.bfloat16: 2e-2,
atol = {
torch.float32: 2.5e-4,
torch.half: 2e-3,
torch.bfloat16: 2e-2,
}
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
......@@ -1174,17 +1170,14 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere
def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
config = model_configs[model]
te_ln_mlp = (
LayerNormMLP(
config.hidden_size,
4 * config.hidden_size,
activation=activation,
normalization=normalization,
params_dtype=dtype,
device="cuda",
)
.eval()
)
te_ln_mlp = LayerNormMLP(
config.hidden_size,
4 * config.hidden_size,
activation=activation,
normalization=normalization,
params_dtype=dtype,
device="cuda",
).eval()
torch_ln_mlp = (
TorchLayerNormMLP(
......@@ -1226,8 +1219,10 @@ def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
optimizer = torch.optim.SGD(block.parameters(), lr=0.1)
# Placeholders used for graph capture.
static_input = torch.randn(config.seq_len, bs, config.hidden_size, device='cuda', dtype=dtype, requires_grad=True)
static_target = torch.randn(config.seq_len, bs, config.hidden_size, device='cuda', dtype=dtype)
static_input = torch.randn(
config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True
)
static_target = torch.randn(config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype)
real_input = torch.rand_like(static_input)
real_target = torch.rand_like(static_target)
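# A minimal CUDA-graph sketch (plain PyTorch, assumptions only) of why the static
# placeholders above exist: graph capture records fixed tensor addresses, so fresh data
# must be copied into the placeholders before each replay instead of being passed in as
# new tensors.
import torch

if torch.cuda.is_available():
    static_x = torch.randn(32, 32, device="cuda")
    static_y = torch.empty_like(static_x)
    side_stream = torch.cuda.Stream()
    with torch.cuda.stream(side_stream):
        static_y.copy_(static_x * 2.0)  # warmup outside capture
    torch.cuda.current_stream().wait_stream(side_stream)
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_y.copy_(static_x * 2.0)
    static_x.copy_(torch.randn(32, 32, device="cuda"))  # new data, same buffer
    graph.replay()  # recomputes static_y in place
    torch.testing.assert_close(static_y, static_x * 2.0)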
......@@ -1286,22 +1281,20 @@ def test_gpt_cuda_graph(dtype, bs, model):
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
device="cuda",
)
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
device="cuda",
)
graphed_block = copy.deepcopy(block)
......@@ -1388,7 +1381,6 @@ def test_gpt_fp8_parameters(dtype, bs, model):
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
......@@ -1451,7 +1443,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
requires_grad=True,
)
x_bshd = x_sbhd.transpose(0,1).contiguous()
x_bshd = x_sbhd.transpose(0, 1).contiguous()
# To make sure forward is also identical (just in case some module decides
# to act fancy)
......@@ -1466,7 +1458,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
# Check that results match
torch.testing.assert_close(
y_bshd,
y_sbhd.transpose(0,1).contiguous(),
y_sbhd.transpose(0, 1).contiguous(),
)
......@@ -1500,19 +1492,16 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
S_max = S + 2
if module == "TransformerLayer":
model = (
TransformerLayer(
hidden_size=D,
ffn_hidden_size= 4 * D,
num_attention_heads=H,
attn_input_format=input_format,
layer_number=layer_number,
attention_dropout = 0.0,
params_dtype=dtype,
device="cuda",
)
.eval()
)
model = TransformerLayer(
hidden_size=D,
ffn_hidden_size=4 * D,
num_attention_heads=H,
attn_input_format=input_format,
layer_number=layer_number,
attention_dropout=0.0,
params_dtype=dtype,
device="cuda",
).eval()
else:
model = (
MultiheadAttention(
......@@ -1520,7 +1509,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
num_attention_heads=H,
qkv_format=input_format,
layer_number=layer_number,
attention_dropout = 0.0,
attention_dropout=0.0,
params_dtype=dtype,
)
.cuda()
......@@ -1537,39 +1526,38 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
incremental_output = torch.zeros_like(input)
# Generate output for the entire sequence
full_output = model(
hidden_states=input,
rotary_pos_emb=rotary_freqs if use_RoPE else None)
full_output = model(hidden_states=input, rotary_pos_emb=rotary_freqs if use_RoPE else None)
# Incrementally generate outputs using KV-cache
for i in range(S):
if input_format == "sbhd":
incremental_input = input[i].view(1,B,D)
incremental_input = input[i].view(1, B, D)
else:
incremental_input = input[:, i, :].view(B,1,D)
incremental_input = input[:, i, :].view(B, 1, D)
line_output = model(
hidden_states=incremental_input,
inference_params=inference_params,
rotary_pos_emb=rotary_freqs if use_RoPE else None)
rotary_pos_emb=rotary_freqs if use_RoPE else None,
)
inference_params.sequence_len_offset += 1
if input_format == "sbhd":
incremental_output[i] = line_output.view(B,D)
incremental_output[i] = line_output.view(B, D)
else:
incremental_output[:, i, :] = line_output.view(B,D)
incremental_output[:, i, :] = line_output.view(B, D)
if module == "TransformerLayer":
atol = {
torch.float32 : 5e-3,
torch.half : 5e-3,
torch.float32: 5e-3,
torch.half: 5e-3,
torch.bfloat16: 5e-2,
}
else:
atol = {
torch.float32 : 1e-3,
torch.half : 1e-3,
torch.float32: 1e-3,
torch.half: 1e-3,
torch.bfloat16: 1e-2,
}
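# A minimal cache sketch (plain PyTorch, not the TE InferenceParams API) of the property
# checked above: with causal attention, feeding tokens one at a time while appending to a
# key/value cache must reproduce the full-sequence output within the listed tolerances.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 1, 6, 16)  # (batch, heads, seq, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)
full = F.scaled_dot_product_attention(q, k, v, is_causal=True)

steps, k_cache, v_cache = [], [], []
for i in range(6):
    k_cache.append(k[:, :, i : i + 1])
    v_cache.append(v[:, :, i : i + 1])
    steps.append(
        F.scaled_dot_product_attention(
            q[:, :, i : i + 1], torch.cat(k_cache, dim=2), torch.cat(v_cache, dim=2)
        )  # no mask needed: the cache only ever holds past tokens
    )
torch.testing.assert_close(torch.cat(steps, dim=2), full)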
......
......@@ -32,7 +32,13 @@ from typing import Optional, Union, Tuple, List
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, gelu, cast_to_fp8, cast_from_fp8
from transformer_engine.pytorch.cpp_extensions import (
gemm,
fp8_gemm,
gelu,
cast_to_fp8,
cast_from_fp8,
)
from transformer_engine.pytorch.module.base import get_workspace
import transformer_engine.pytorch.cpp_extensions as texcpp
import transformer_engine.pytorch.softmax as softmax_defs
......@@ -50,8 +56,10 @@ if SAVE_TEST_IO:
from polygraphy.comparator import RunResults
# The directory where generated ONNX test models are stored.
NVTE_TEST_ARTIFACTS_DIR = os.environ.get('NVTE_TEST_ARTIFACTS_DIR')
NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join(tempfile.gettempdir(), "./gen_onnx_models")
NVTE_TEST_ARTIFACTS_DIR = os.environ.get("NVTE_TEST_ARTIFACTS_DIR")
NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join(
tempfile.gettempdir(), "./gen_onnx_models"
)
# The directory where this file is stored.
......@@ -100,23 +108,21 @@ def do_export(
model: torch.nn.Module,
inp: torch.Tensor,
fname: str,
use_fp8: bool=True,
opset: int=OPSET,
input_names: List[str]=None,
output_names: List[str]=None,
dynamic_axes: List[str]=None
use_fp8: bool = True,
opset: int = OPSET,
input_names: List[str] = None,
output_names: List[str] = None,
dynamic_axes: List[str] = None,
):
"""Export to ONNX"""
fp8_recipe = create_fp8_recipe()
input_names = input_names or ["input"]
output_names = output_names or ["output"]
with torch.inference_mode(), te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), warnings.catch_warnings():
warnings.filterwarnings(
action='ignore',
category=torch.jit.TracerWarning,
module=r'.*'
)
with torch.inference_mode(), te.fp8_autocast(
enabled=use_fp8, fp8_recipe=fp8_recipe
), warnings.catch_warnings():
warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning, module=r".*")
model.cuda().eval()
os.makedirs(NVTE_TEST_ARTIFACTS_DIR, exist_ok=True)
......@@ -138,7 +144,8 @@ def do_export(
input_names=input_names,
output_names=output_names,
do_constant_folding=True,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH)
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
)
def to_numpy(tensor):
......@@ -154,24 +161,30 @@ def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int):
NB_SCALES_PER_GEMM = 3 # One scale per: input, weights, and output GEMM tensors.
nb_total_scales = num_gemms * NB_SCALES_PER_GEMM
module.init_fp8_metadata(num_gemms)
module.fp8_meta["scaling_fwd"].scale = torch.ones(
nb_total_scales, dtype=torch.float32, device="cuda") / scale
module.fp8_meta["scaling_fwd"].scale_inv = torch.ones(
nb_total_scales, dtype=torch.float32, device="cuda") * scale
module.fp8_meta["scaling_fwd"].scale = (
torch.ones(nb_total_scales, dtype=torch.float32, device="cuda") / scale
)
module.fp8_meta["scaling_fwd"].scale_inv = (
torch.ones(nb_total_scales, dtype=torch.float32, device="cuda") * scale
)
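# A minimal quantize/dequantize sketch (not the TE FP8 kernels, assumptions only) of the
# reciprocal pair set above: the quantizing cast is assumed to multiply by scale and the
# dequantizing cast by scale_inv, which is why the two tensors must stay in sync.
import torch

scale_factor = 224.0
x = torch.randn(8, 8)
scale, scale_inv = 1.0 / scale_factor, scale_factor
quantized = (x * scale).to(torch.float16)  # float16 stands in for the FP8 cast
dequantized = quantized.to(torch.float32) * scale_inv
torch.testing.assert_close(dequantized, x, rtol=1e-2, atol=1e-2)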
def te_infer(model: torch.nn.Module, inps: Union[Tuple[torch.tensor], torch.tensor], is_fp8: bool):
"""Transformer Engine forward propagation."""
fp8_recipe = create_fp8_recipe()
with torch.inference_mode(), te.fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe), warnings.catch_warnings():
with torch.inference_mode(), te.fp8_autocast(
enabled=is_fp8, fp8_recipe=fp8_recipe
), warnings.catch_warnings():
te_outputs = model(*inps if isinstance(inps, tuple) else (inps,))
if not isinstance(te_outputs, tuple):
te_outputs = (te_outputs,)
return te_outputs
def compare_outputs(onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname):
""" Compare ORT and TE outputs."""
def compare_outputs(
onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname
):
"""Compare ORT and TE outputs."""
assert len(onnx_outputs) == len(te_outputs)
# Compare ORT and PyTorch outputs.
for onnx_output, te_output in zip(onnx_outputs, te_outputs):
......@@ -192,11 +205,15 @@ def compare_outputs(onnx_outputs, te_outputs, atol, rtol, max_errors_printed, al
errors = abs_err[mismatches]
for loc in mismatched_ids[:nb_vals]:
ref = te_output[loc]
print(f"{onnx_output[loc]} -- {te_output[loc]} err={abs_err[loc]} > {atol + rtol * abs(ref)}")
print(
f"{onnx_output[loc]} -- {te_output[loc]} err={abs_err[loc]} >"
f" {atol + rtol * abs(ref)}"
)
print(f"Max error: {np.max(errors)}")
if nb_errors > allow_cnt_errors:
raise ValueError(f"Output validation of {fname} failed with {nb_errors} errors")
def serialize_inputs_outputs(
fname: str,
inputs: Union[Tuple[torch.Tensor], torch.Tensor],
......@@ -214,10 +231,10 @@ def serialize_inputs_outputs(
inputs = inputs if isinstance(inputs, list) or isinstance(inputs, tuple) else (inputs,)
named_inputs = zip(input_names, inputs)
input_data = [{k: v.cpu() for k, v in named_inputs if v is not None}]
json_fname = fname[:-len(".onnx")] + "_inputs.json"
json_fname = fname[: -len(".onnx")] + "_inputs.json"
save_json(input_data, json_fname, description="custom input data")
json_fname = fname[:-len(".onnx")] + "_output.json"
json_fname = fname[: -len(".onnx")] + "_output.json"
named_outputs = zip(output_names, te_outputs)
output_data = {k: v.detach().cpu() for k, v in named_outputs if v is not None}
custom_outputs = RunResults()
......@@ -229,14 +246,14 @@ def validate_result(
fname: str,
inps: Union[Tuple[torch.Tensor], torch.Tensor],
model: torch.nn.Module,
atol: float=1.e-8, # np.isclose default atol
rtol: float=1.e-5, # np.isclose default rtol
max_errors_printed: int=10,
is_fp8: bool=False,
allow_cnt_errors: int=0,
input_names: List[str]=None,
output_names: List[str]=None,
te_outputs: List[torch.Tensor]=None,
atol: float = 1.0e-8, # np.isclose default atol
rtol: float = 1.0e-5, # np.isclose default rtol
max_errors_printed: int = 10,
is_fp8: bool = False,
allow_cnt_errors: int = 0,
input_names: List[str] = None,
output_names: List[str] = None,
te_outputs: List[torch.Tensor] = None,
):
"""Compare the outputs of a Transformer Engine (TE) module vs the outputs of its ONNX
representation using ONNX Runtime (ORT) and ensure they are close.
......@@ -262,7 +279,7 @@ def validate_result(
print("registered custom FP8 Q/DQ ops!")
"""Create an ONNX Runtime session for validation."""
kwargs = {"providers": ['CUDAExecutionProvider', 'CPUExecutionProvider']}
kwargs = {"providers": ["CUDAExecutionProvider", "CPUExecutionProvider"]}
if is_fp8:
sess_options = ort.SessionOptions()
load_custom_ops(sess_options)
......@@ -288,10 +305,12 @@ def validate_result(
ort_s = create_ort_session(fname, is_fp8)
input_feed = create_ort_input_dict(ort_s, inps)
onnx_outputs = ort_s.run(None, input_feed=input_feed)
compare_outputs(onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname)
compare_outputs(
onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname
)
def create_meta(scale_factor: float, size: int=1):
def create_meta(scale_factor: float, size: int = 1):
meta = tex.FP8TensorMeta()
meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda")
meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor
......@@ -324,7 +343,9 @@ def get_attn_mask_str(use_mask, attn_mask_type):
return "_mask" if use_mask else "_no-mask"
attn_mask_str = "_arbitrary-no-mask"
attn_mask_str = "_causal-mask" if attn_mask_type == "causal" else attn_mask_str
attn_mask_str = "_arbitrary-mask" if use_mask and attn_mask_type == "arbitrary" else attn_mask_str
attn_mask_str = (
"_arbitrary-mask" if use_mask and attn_mask_type == "arbitrary" else attn_mask_str
)
return attn_mask_str
......@@ -351,17 +372,11 @@ class FP8GemmModule(nn.Module):
self.outp_type = precision
def forward(self, inp, weight):
inp_fp8 = cast_to_fp8(
inp,
self.meta_inp,
self.fp8_tensor_inp,
self.inp_type)
inp_fp8 = cast_to_fp8(inp, self.meta_inp, self.fp8_tensor_inp, self.inp_type)
weight_fp8 = cast_to_fp8(
weight,
self.meta_weight,
self.fp8_tensor_weight,
self.weights_type)
weight, self.meta_weight, self.fp8_tensor_weight, self.weights_type
)
ret, _ = fp8_gemm(
weight_fp8,
......@@ -376,9 +391,11 @@ class FP8GemmModule(nn.Module):
get_workspace(),
bias=self.bias,
use_bias=self.use_bias,
use_split_accumulator=False)
use_split_accumulator=False,
)
return ret
"""
Test cases begin here.
"""
......@@ -387,13 +404,17 @@ Tests cases begin here.
@skip_FP8
@pytest.mark.parametrize("scale_factor", [1, 224])
@pytest.mark.parametrize(
"precision, atol", [
[torch.float32, 1e-7],
[torch.float16, 1e-7],
[torch.bfloat16, 5e-3],
["fake-torch.bfloat16", 5e-3],
])
def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, precision: torch.dtype):
"precision, atol",
[
[torch.float32, 1e-7],
[torch.float16, 1e-7],
[torch.bfloat16, 5e-3],
["fake-torch.bfloat16", 5e-3],
],
)
def test_export_cast_ops(
seed_default_rng, scale_factor: float, atol: float, precision: torch.dtype
):
fake_bf16_io = precision == "fake-torch.bfloat16"
# reset precision to torch.bfloat16 after capturing fake BF16 mode
precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision
......@@ -408,18 +429,9 @@ def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, pre
self.fake_bf16_io = fake_bf16_io
def forward(self, inp):
ret = cast_to_fp8(
inp,
self.meta,
self.fp8_tensor,
self.fp8_type)
ret = cast_to_fp8(inp, self.meta, self.fp8_tensor, self.fp8_type)
ret = cast_from_fp8(
ret,
self.meta,
self.fp8_tensor,
self.fp8_type,
self.highprec_type)
ret = cast_from_fp8(ret, self.meta, self.fp8_tensor, self.fp8_type, self.highprec_type)
if self.fake_bf16_io:
ret = ret.type(torch.float32)
return ret
......@@ -427,8 +439,9 @@ def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, pre
# Set dimensions (these are arbitrary).
in_features = 64
hidden_size = 256
inp = torch.randn(hidden_size, in_features, device="cuda",
dtype=torch.float if fake_bf16_io else precision)
inp = torch.randn(
hidden_size, in_features, device="cuda", dtype=torch.float if fake_bf16_io else precision
)
high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
fname = f"te.cast_fp8_{scale_factor}{high_prec_str}.onnx"
model = TestFP8_QDQ(fake_bf16_io)
......@@ -439,15 +452,18 @@ def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, pre
if fake_bf16_io or precision != torch.bfloat16:
validate_result(fname, inp, model, atol=atol, is_fp8=True, te_outputs=te_outputs)
@skip_FP8
@pytest.mark.parametrize("scale_factor", [448])
@pytest.mark.parametrize(
"precision, atol", [
[torch.float32, 1e-5],
[torch.float16, 1e-5],
[torch.bfloat16, 5e-3],
["fake-torch.bfloat16", 5e-3]
])
"precision, atol",
[
[torch.float32, 1e-5],
[torch.float16, 1e-5],
[torch.bfloat16, 5e-3],
["fake-torch.bfloat16", 5e-3],
],
)
def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: float):
fake_bf16_io = precision == "fake-torch.bfloat16"
# reset precision to torch.bfloat16 after capturing fake BF16 mode
......@@ -463,17 +479,8 @@ def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: floa
self.fake_bf16_io = fake_bf16_io
def forward(self, inp):
ret = gelu(
inp,
self.meta,
self.fp8_tensor,
self.fp8_type)
ret = cast_from_fp8(
ret,
self.meta,
self.fp8_tensor,
self.fp8_type,
self.highprec_type)
ret = gelu(inp, self.meta, self.fp8_tensor, self.fp8_type)
ret = cast_from_fp8(ret, self.meta, self.fp8_tensor, self.fp8_type, self.highprec_type)
if self.fake_bf16_io:
ret = ret.type(torch.float32)
return ret
......@@ -481,8 +488,9 @@ def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: floa
# Set dimensions (these are arbitrary).
in_features = 64
hidden_size = 256
inp = torch.randn(hidden_size, in_features, device="cuda",
dtype=torch.float if fake_bf16_io else precision)
inp = torch.randn(
hidden_size, in_features, device="cuda", dtype=torch.float if fake_bf16_io else precision
)
high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
fname = f"te.gelu_fp8_{scale_factor}{high_prec_str}.onnx"
model = TestFP8_Gelu(fake_bf16_io)
......@@ -490,39 +498,55 @@ def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: floa
te_outputs = te_infer(model, inp, is_fp8=True)
serialize_inputs_outputs(fname, inp, te_outputs)
if fake_bf16_io or precision != torch.bfloat16:
validate_result(fname, inp, model, rtol=0, atol=atol, is_fp8=True, allow_cnt_errors=2, te_outputs=te_outputs)
validate_result(
fname,
inp,
model,
rtol=0,
atol=atol,
is_fp8=True,
allow_cnt_errors=2,
te_outputs=te_outputs,
)
@pytest.mark.parametrize("scale_factors",
[(224, 224,),
])
@pytest.mark.parametrize(
"precision, use_fp8, use_bias, use_gelu", [
(torch.float32, False, False, False),
(torch.float16, False, False, False),
(torch.bfloat16, False, False, False),
(torch.float32, False, True, False),
(torch.float16, False, True, False),
(torch.bfloat16, False, True, False),
(torch.float32, False, True, True),
(torch.float16, False, True, True),
(torch.bfloat16, False, True, True),
# For FP8 GEMM GeLU is not used.
(torch.float32, True, False, False),
(torch.float16, True, False, False),
(torch.bfloat16, True, False, False),
# When enabling bias we must use float16 or bfloat16 (because of kernel limitations)
(torch.float16, True, True, False),
(torch.bfloat16, True, True, False),
])
"scale_factors",
[
(
224,
224,
),
],
)
@pytest.mark.parametrize(
"precision, use_fp8, use_bias, use_gelu",
[
(torch.float32, False, False, False),
(torch.float16, False, False, False),
(torch.bfloat16, False, False, False),
(torch.float32, False, True, False),
(torch.float16, False, True, False),
(torch.bfloat16, False, True, False),
(torch.float32, False, True, True),
(torch.float16, False, True, True),
(torch.bfloat16, False, True, True),
# For FP8 GEMM GeLU is not used.
(torch.float32, True, False, False),
(torch.float16, True, False, False),
(torch.bfloat16, True, False, False),
# When enabling bias we must use float16 or bfloat16 (because of kernel limitations)
(torch.float16, True, True, False),
(torch.bfloat16, True, True, False),
],
)
def test_export_gemm(
seed_default_rng,
precision, # Precision of inputs, weights, output and bias
precision, # Precision of inputs, weights, output and bias
use_fp8,
use_bias,
use_gelu,
scale_factors
scale_factors,
):
# Skip FP8 tests on non-hopper devices
if use_fp8 and not fp8_available:
......@@ -548,21 +572,20 @@ def test_export_gemm(
inp,
outp_type,
get_workspace(),
# test bias
bias=self.bias,
use_bias=self.use_bias,
# test gelu
gelu=self.gelu,
gelu_input=self.gelu_input,
grad=False, # only True for backward pass
grad=False, # only True for backward pass
accumulate=False,
)
return ret
# If gelu is applied then bias must be added, as defined by TE kernel.
if use_gelu: assert use_bias
if use_gelu:
assert use_bias
# Set dimensions (these are arbitrary).
out_features = 128
hidden_size = 256
......@@ -574,45 +597,64 @@ def test_export_gemm(
gelu_str = "_gelu" if use_gelu else ""
high_prec_str = dtype2str(precision)
fname = f"te.gemm{fp8_str}{bias_str}{gelu_str}{high_prec_str}.onnx"
input_names = ['input', 'weight']
input_names = ["input", "weight"]
if use_fp8:
model = FP8GemmModule(precision, use_bias, use_gelu, scale_factors, hidden_size, out_features)
model = FP8GemmModule(
precision, use_bias, use_gelu, scale_factors, hidden_size, out_features
)
do_export(model, (inp, weight), fname, use_fp8, input_names=input_names)
te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8)
serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names)
if precision != torch.bfloat16:
validate_result(fname, (inp, weight), model, rtol=1e-2, atol=2e-2,
is_fp8=True, input_names=input_names, te_outputs=te_outputs)
validate_result(
fname,
(inp, weight),
model,
rtol=1e-2,
atol=2e-2,
is_fp8=True,
input_names=input_names,
te_outputs=te_outputs,
)
else:
model = Test_GEMM(precision, use_bias, use_gelu)
do_export(model, (inp, weight), fname, use_fp8, input_names=input_names)
te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8)
serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names)
if precision != torch.bfloat16:
validate_result(fname, (inp, weight), model, rtol=1e-2, atol=2e-2,
input_names=input_names, te_outputs=te_outputs)
validate_result(
fname,
(inp, weight),
model,
rtol=1e-2,
atol=2e-2,
input_names=input_names,
te_outputs=te_outputs,
)
@pytest.mark.parametrize("scale_factor", [448, 112])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize(
"use_fp8, precision, atol", [
[False, torch.float32, 1e-7],
[False, torch.float16, 1e-7],
[False, torch.bfloat16, 1e-7],
[False, "fake-torch.bfloat16", 1e-7],
[True, torch.float32, 1e-7],
[True, torch.float16, 1e-7],
[True, torch.bfloat16, 1e-2],
[True, "fake-torch.bfloat16", 1e-2]
])
"use_fp8, precision, atol",
[
[False, torch.float32, 1e-7],
[False, torch.float16, 1e-7],
[False, torch.bfloat16, 1e-7],
[False, "fake-torch.bfloat16", 1e-7],
[True, torch.float32, 1e-7],
[True, torch.float16, 1e-7],
[True, torch.bfloat16, 1e-2],
[True, "fake-torch.bfloat16", 1e-2],
],
)
def test_export_layernorm(
seed_default_rng,
use_fp8: bool,
scale_factor: float,
precision: torch.dtype,
zero_centered_gamma: bool,
atol: float
atol: float,
):
fake_bf16_io = precision == "fake-torch.bfloat16"
# reset precision to torch.bfloat16 after capturing fake BF16 mode
......@@ -628,10 +670,15 @@ def test_export_layernorm(
class Test_Layernorm(nn.Module):
def __init__(self) -> None:
super().__init__()
eps = 1e-6 # An arbitrary small value
eps = 1e-6 # An arbitrary small value
dtype = torch.float if fake_bf16_io else precision
self.ln = te.LayerNorm(inp_shape[1], eps, params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma).eval().cuda()
self.ln = (
te.LayerNorm(
inp_shape[1], eps, params_dtype=dtype, zero_centered_gamma=zero_centered_gamma
)
.eval()
.cuda()
)
def forward(self, inp):
ret = self.ln(inp)
......@@ -641,11 +688,13 @@ def test_export_layernorm(
def __init__(self) -> None:
super().__init__()
normalized_shape = torch.Size(inp.shape[1:])
self.weight = torch.randn(*normalized_shape, device="cuda",
dtype=torch.float32 if fake_bf16_io else precision)
self.bias = torch.zeros(*normalized_shape, device="cuda",
dtype=torch.float32 if fake_bf16_io else precision)
self.eps = 1e-6 # An arbitrary small value
self.weight = torch.randn(
*normalized_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision
)
self.bias = torch.zeros(
*normalized_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision
)
self.eps = 1e-6 # An arbitrary small value
self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
self.meta = create_meta(scale_factor)
......@@ -661,14 +710,12 @@ def test_export_layernorm(
self.fp8_tensor,
self.fp8_type,
0,
zero_centered_gamma)
zero_centered_gamma,
)
ret = cast_from_fp8(
ret,
self.meta,
self.fp8_tensor,
self.fp8_type,
as_te_type(precision))
ret, self.meta, self.fp8_tensor, self.fp8_type, as_te_type(precision)
)
if fake_bf16_io:
ret = ret.type(torch.float32)
return ret
......@@ -683,28 +730,32 @@ def test_export_layernorm(
serialize_inputs_outputs(fname, inp, te_outputs)
if fake_bf16_io or precision != torch.bfloat16:
validate_result(
fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs)
fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs
)
@pytest.mark.parametrize("scale_factor", [448, 112])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize(
"use_fp8, precision, atol", [
[False, torch.float32, 1e-7],
[False, torch.float16, 1e-7],
[False, torch.bfloat16, 1e-7],
[False, "fake-torch.bfloat16", 1e-7],
[True, torch.float32, 1e-7],
[True, torch.float16, 1e-7],
[True, torch.bfloat16, 1e-2],
[True, "fake-torch.bfloat16", 1e-2]
])
"use_fp8, precision, atol",
[
[False, torch.float32, 1e-7],
[False, torch.float16, 1e-7],
[False, torch.bfloat16, 1e-7],
[False, "fake-torch.bfloat16", 1e-7],
[True, torch.float32, 1e-7],
[True, torch.float16, 1e-7],
[True, torch.bfloat16, 1e-2],
[True, "fake-torch.bfloat16", 1e-2],
],
)
def test_export_rmsnorm(
seed_default_rng,
use_fp8: bool,
scale_factor: float,
precision: torch.dtype,
zero_centered_gamma: bool,
atol: float
atol: float,
):
fake_bf16_io = precision == "fake-torch.bfloat16"
# reset precision to torch.bfloat16 after capturing fake BF16 mode
......@@ -720,10 +771,15 @@ def test_export_rmsnorm(
class Test_RMSnorm(nn.Module):
def __init__(self) -> None:
super().__init__()
eps = 1e-6 # An arbitrary small value
eps = 1e-6 # An arbitrary small value
dtype = torch.float if fake_bf16_io else precision
self.ln = te.RMSNorm(inp_shape[1], eps, params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma).eval().cuda()
self.ln = (
te.RMSNorm(
inp_shape[1], eps, params_dtype=dtype, zero_centered_gamma=zero_centered_gamma
)
.eval()
.cuda()
)
def forward(self, inp):
ret = self.ln(inp)
......@@ -733,9 +789,10 @@ def test_export_rmsnorm(
def __init__(self) -> None:
super().__init__()
normalized_shape = torch.Size(inp.shape[1:])
self.weight = torch.randn(*normalized_shape, device="cuda",
dtype=torch.float32 if fake_bf16_io else precision)
self.eps = 1e-6 # An arbitrary small value
self.weight = torch.randn(
*normalized_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision
)
self.eps = 1e-6 # An arbitrary small value
self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
self.meta = create_meta(scale_factor)
......@@ -750,14 +807,12 @@ def test_export_rmsnorm(
self.fp8_tensor,
self.fp8_type,
0,
zero_centered_gamma)
zero_centered_gamma,
)
ret = cast_from_fp8(
ret,
self.meta,
self.fp8_tensor,
self.fp8_type,
as_te_type(precision))
ret, self.meta, self.fp8_tensor, self.fp8_type, as_te_type(precision)
)
if fake_bf16_io:
ret = ret.type(torch.float32)
return ret
......@@ -772,7 +827,8 @@ def test_export_rmsnorm(
serialize_inputs_outputs(fname, inp, te_outputs)
if fake_bf16_io or precision != torch.bfloat16:
validate_result(
fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs)
fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs
)
@pytest.mark.parametrize("scale_factor", [1])
......@@ -780,23 +836,25 @@ def test_export_rmsnorm(
# Returning the bias is a TE fusion optimization we don't care about.
@pytest.mark.parametrize("return_bias", [False])
@pytest.mark.parametrize(
"precision, use_bias",[
(torch.float32, False),
(torch.float32, True),
(torch.float16, False),
(torch.float16, True),
# Todo: cannot configure BF16 when bias is disabled (ORT issue?)
(torch.bfloat16, False),
# Todo: cannot configure BF16 when bias is enabled (ORT issue?)
(torch.bfloat16, True),
])
"precision, use_bias",
[
(torch.float32, False),
(torch.float32, True),
(torch.float16, False),
(torch.float16, True),
# Todo: cannot configure BF16 when bias is disabled (ORT issue?)
(torch.bfloat16, False),
# Todo: cannot configure BF16 when bias is enabled (ORT issue?)
(torch.bfloat16, True),
],
)
def test_export_linear(
seed_default_rng,
scale_factor: float,
use_fp8: bool,
use_bias: bool,
return_bias: bool,
precision: torch.dtype
precision: torch.dtype,
):
# Skip FP8 tests on non-hopper devices
if use_fp8 and not fp8_available:
......@@ -808,20 +866,14 @@ def test_export_linear(
hidden_size = 256
class Test_Linear(nn.Module):
def __init__(self,
in_features,
out_features,
use_bias,
return_bias,
precision
):
def __init__(self, in_features, out_features, use_bias, return_bias, precision):
super().__init__()
self.linear = te.Linear(
in_features,
out_features,
bias=use_bias,
return_bias=return_bias,
params_dtype=precision
params_dtype=precision,
)
def forward(self, inp):
......@@ -834,20 +886,16 @@ def test_export_linear(
high_prec_str = dtype2str(precision)
fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx"
with te.fp8_autocast(enabled=use_fp8):
model = Test_Linear(
in_features,
out_features,
use_bias,
return_bias,
precision
).to(device='cuda')
model = Test_Linear(in_features, out_features, use_bias, return_bias, precision).to(
device="cuda"
)
if use_fp8:
set_layer_scale(model.linear, scale_factor, num_gemms=1)
do_export(model, inp, fname, use_fp8)
te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs)
if precision in (torch.bfloat16, ):
if precision in (torch.bfloat16,):
return
if not use_fp8:
validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs)
......@@ -861,14 +909,16 @@ def test_export_linear(
@pytest.mark.parametrize("return_bias", [False])
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize(
"precision, use_bias",[
(torch.float32, False),
(torch.float32, True),
(torch.float16, True),
(torch.float16, False),
(torch.bfloat16, True),
(torch.bfloat16, False),
])
"precision, use_bias",
[
(torch.float32, False),
(torch.float32, True),
(torch.float16, True),
(torch.float16, False),
(torch.bfloat16, True),
(torch.bfloat16, False),
],
)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("normalization", all_normalizations)
def test_export_layernorm_linear(
......@@ -907,13 +957,13 @@ def test_export_layernorm_linear(
params_dtype=precision,
zero_centered_gamma=zero_centered_gamma,
normalization=normalization,
).to(device='cuda')
).to(device="cuda")
if use_fp8:
set_layer_scale(model, scale_factor, num_gemms=1)
do_export(model, inp, fname, use_fp8)
te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs)
if precision in (torch.bfloat16, ):
if precision in (torch.bfloat16,):
return
if not use_fp8:
validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs)
......@@ -927,14 +977,16 @@ def test_export_layernorm_linear(
@pytest.mark.parametrize("return_bias", [False])
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize(
"precision, use_bias",[
(torch.float32, False),
(torch.float32, True),
(torch.float16, True),
(torch.float16, False),
(torch.bfloat16, True),
(torch.bfloat16, False),
])
"precision, use_bias",
[
(torch.float32, False),
(torch.float32, True),
(torch.float16, True),
(torch.float16, False),
(torch.bfloat16, True),
(torch.bfloat16, False),
],
)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("activation", supported_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
......@@ -954,7 +1006,6 @@ def test_export_layernorm_mlp(
if use_fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
# Set dimensions (these are arbitrary).
in_features = 64
out_features = 256
......@@ -977,30 +1028,32 @@ def test_export_layernorm_mlp(
zero_centered_gamma=zero_centered_gamma,
activation=activation,
normalization=normalization,
).to(device='cuda')
).to(device="cuda")
if use_fp8:
set_layer_scale(model, scale_factor, num_gemms=2)
do_export(model, inp, fname, use_fp8)
te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs)
if precision in (torch.bfloat16, ):
if precision in (torch.bfloat16,):
return
atol = 1e-6 if use_fp8 else (5e-1 if activation=="swiglu" else 1e-3)
atol = 1e-6 if use_fp8 else (5e-1 if activation == "swiglu" else 1e-3)
validate_result(fname, inp, model, atol=atol, is_fp8=use_fp8, te_outputs=te_outputs)
@skip_FP8
@pytest.mark.parametrize(
"precision, use_mask, attn_mask_type", [
(torch.float32, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
(torch.float32, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
(torch.float16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
(torch.float16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
(torch.float16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
(torch.bfloat16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
(torch.bfloat16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
(torch.bfloat16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
])
"precision, use_mask, attn_mask_type",
[
(torch.float32, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
(torch.float32, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
(torch.float16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
(torch.float16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
(torch.float16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
(torch.bfloat16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
(torch.bfloat16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
(torch.bfloat16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
],
)
def test_export_core_attention(
seed_default_rng,
set_max_seq_len,
......@@ -1034,40 +1087,42 @@ def test_export_core_attention(
attention_dropout=0.5,
qkv_format=qkv_format,
attn_mask_type=attn_mask_type,
).to(device='cuda')
do_export(model,
inp,
fname,
input_names=input_names,
use_fp8=True)
).to(device="cuda")
do_export(model, inp, fname, input_names=input_names, use_fp8=True)
te_outputs = te_infer(model, inp, is_fp8=True)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
if precision in (torch.bfloat16, ):
if precision in (torch.bfloat16,):
return
validate_result(fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs)
validate_result(
fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs
)
test_configs_multihead_attention = [
#"use_mask, attn_mask_type"
(False, "no_mask"), # calls ScaledSoftmax
(True, "arbitrary"), # calls ScaledMaskedSoftmax
# "use_mask, attn_mask_type"
(False, "no_mask"), # calls ScaledSoftmax
(True, "arbitrary"), # calls ScaledMaskedSoftmax
]
test_configs_attention_type = [
#"input_layernorm, attention_type, fuse_qkv_params"
(True, "self", True),
(False, "self", True),
(True, "self", False),
(False, "self", False),
(True, "cross", True),
(False, "cross", True),
(True, "cross", False),
(False, "cross", False),
# "input_layernorm, attention_type, fuse_qkv_params"
(True, "self", True),
(False, "self", True),
(True, "self", False),
(False, "self", False),
(True, "cross", True),
(False, "cross", True),
(True, "cross", False),
(False, "cross", False),
]
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize("input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type)
@pytest.mark.parametrize(
"input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type
)
def test_export_multihead_attention(
seed_default_rng,
set_max_seq_len,
......@@ -1078,7 +1133,7 @@ def test_export_multihead_attention(
return_layernorm_output: bool,
input_layernorm: bool,
attention_type: str,
fuse_qkv_params: bool
fuse_qkv_params: bool,
):
# Skip FP8 tests on non-hopper devices
if use_fp8 and not fp8_available:
......@@ -1102,17 +1157,23 @@ def test_export_multihead_attention(
output_layer_init_method,
)
hidden_states_context = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
hidden_states_context = torch.randn(
sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
)
attention_mask = None
if use_mask and attn_mask_type != "causal":
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision)
probs = 0.5 * torch.ones(
batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision
)
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
encoder_output = None
if attention_type == "cross":
encoder_output = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
encoder_output = torch.randn(
sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
)
fp8_str = "_fp8" if use_fp8 else ""
dtype_str = dtype2str(precision)
......@@ -1131,49 +1192,98 @@ def test_export_multihead_attention(
attention_type=attention_type,
fuse_qkv_params=fuse_qkv_params,
return_bias=True,
).to(device='cuda')
).to(device="cuda")
inp_context = (hidden_states_context, attention_mask, encoder_output)
input_names = ["hidden_states", "attention_mask", "encoder_output"]
output_names=["attention_output", "attention_bias"]
do_export(model, inp_context, fname, use_fp8, input_names=input_names, output_names=output_names,
dynamic_axes={"hidden_states": {0: "seq", 1:"bs"},
"attention_output": {0: "seq", 1:"bs"}})
output_names = ["attention_output", "attention_bias"]
do_export(
model,
inp_context,
fname,
use_fp8,
input_names=input_names,
output_names=output_names,
dynamic_axes={
"hidden_states": {0: "seq", 1: "bs"},
"attention_output": {0: "seq", 1: "bs"},
},
)
te_outputs = te_infer(model, inp_context, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp_context, te_outputs, input_names=input_names, output_names=output_names)
if precision in (torch.bfloat16, ):
serialize_inputs_outputs(
fname, inp_context, te_outputs, input_names=input_names, output_names=output_names
)
if precision in (torch.bfloat16,):
return
if not use_fp8:
validate_result(fname, inp_context, model, atol=1e-3, input_names=input_names,
output_names=output_names, te_outputs=te_outputs)
validate_result(
fname,
inp_context,
model,
atol=1e-3,
input_names=input_names,
output_names=output_names,
te_outputs=te_outputs,
)
else:
validate_result(fname, inp_context, model, atol=1e-2, is_fp8=use_fp8,
input_names=input_names, output_names=output_names, allow_cnt_errors=3,
te_outputs=te_outputs)
validate_result(
fname,
inp_context,
model,
atol=1e-2,
is_fp8=use_fp8,
input_names=input_names,
output_names=output_names,
allow_cnt_errors=3,
te_outputs=te_outputs,
)
# In GPT generative phase (inference) the input sequence is smaller than the maximum
# allowed sequence length and we want to test this condition.
# Pretend that we're in generative phase when it makes sense (causal mask and self-attention).
is_generative_phase = (attn_mask_type == "causal" and attention_type == "self")
is_generative_phase = attn_mask_type == "causal" and attention_type == "self"
if is_generative_phase:
seq_len_offset = 8
hidden_states_generative = torch.randn(sequence_length-seq_len_offset, batch_size, hidden_size, dtype=precision, device="cuda")
hidden_states_generative = torch.randn(
sequence_length - seq_len_offset,
batch_size,
hidden_size,
dtype=precision,
device="cuda",
)
inp_generative = (hidden_states_generative, attention_mask, encoder_output)
if not use_fp8:
validate_result(fname, inp_generative, model, atol=1e-3, input_names=input_names, output_names=output_names)
validate_result(
fname,
inp_generative,
model,
atol=1e-3,
input_names=input_names,
output_names=output_names,
)
else:
validate_result(fname, inp_generative, model, atol=1e-2, is_fp8=use_fp8,
input_names=input_names, output_names=output_names, allow_cnt_errors=3)
validate_result(
fname,
inp_generative,
model,
atol=1e-2,
is_fp8=use_fp8,
input_names=input_names,
output_names=output_names,
allow_cnt_errors=3,
)
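
The generative-phase check above re-validates the already-exported ONNX file with a shorter input, which only works because the export declared the sequence dimension as a dynamic axis. A minimal sketch of that idea with plain onnxruntime; the file name, input name, and shapes are illustrative, not the test's actual helpers:

# Sketch only: assumes an ONNX file exported with dynamic_axes={"hidden_states": {0: "seq"}}.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

full = np.random.rand(64, 4, 256).astype(np.float32)   # (seq, batch, hidden) used at export
short = np.random.rand(56, 4, 256).astype(np.float32)  # shorter "generative phase" sequence

# Both calls succeed because the seq dimension was exported as dynamic.
out_full = sess.run(None, {"hidden_states": full})
out_short = sess.run(None, {"hidden_states": short})
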
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("output_layernorm", [
#True, # TO DO: handle this
False
])
@pytest.mark.parametrize(
"output_layernorm",
[
# True, # TO DO: handle this
False
],
)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("fuse_qkv_params", [False, True])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
......@@ -1201,12 +1311,16 @@ def test_export_transformer_layer(
ffn_hidden_size = 256
num_attention_heads = 4
input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
input_tensor = torch.rand(
sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
)
input_names = ["input", "attention_mask"]
attention_mask = None
if use_mask and attn_mask_type != "causal":
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision)
probs = 0.5 * torch.ones(
batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision
)
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
inp = (input_tensor, attention_mask)
......@@ -1225,19 +1339,30 @@ def test_export_transformer_layer(
params_dtype=precision,
fuse_qkv_params=fuse_qkv_params,
zero_centered_gamma=zero_centered_gamma,
activation=activation).to(device='cuda')
activation=activation,
).to(device="cuda")
do_export(model, inp, fname, use_fp8, input_names=input_names)
te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
if precision in (torch.bfloat16, ):
if precision in (torch.bfloat16,):
return
atol = 5e-1 if use_fp8 else (5e-1 if activation=="swiglu" else 1e-3)
validate_result(fname, inp, model, atol=atol, is_fp8=use_fp8, input_names=input_names, te_outputs=te_outputs)
atol = 5e-1 if use_fp8 else (5e-1 if activation == "swiglu" else 1e-3)
validate_result(
fname, inp, model, atol=atol, is_fp8=use_fp8, input_names=input_names, te_outputs=te_outputs
)
@pytest.mark.parametrize("use_fp8", [True])
@pytest.mark.parametrize("ln_scale_factor", [448*2])
@pytest.mark.parametrize("gemm_scale_factors", [(224, 224,),])
@pytest.mark.parametrize("ln_scale_factor", [448 * 2])
@pytest.mark.parametrize(
"gemm_scale_factors",
[
(
224,
224,
),
],
)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_gemm_layernorm(
......@@ -1246,7 +1371,7 @@ def test_export_gemm_layernorm(
ln_scale_factor: float,
gemm_scale_factors: Tuple[float, float],
precision: torch.dtype,
zero_centered_gamma: bool
zero_centered_gamma: bool,
):
"""This is a regression test for testing that all LN inputs have the same type.
......@@ -1260,20 +1385,26 @@ def test_export_gemm_layernorm(
# Skip FP8 tests on non-hopper devices
if use_fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
class TestFP8_GemmLayernorm(nn.Module):
def __init__(self) -> None:
super().__init__()
normalized_shape = torch.Size(inp.shape[1:])
self.weight = torch.randn(*normalized_shape, dtype=precision, device="cuda")
self.bias = torch.zeros(*normalized_shape, dtype=precision, device="cuda")
self.eps = 1e-6 # An arbitrary small value
self.eps = 1e-6 # An arbitrary small value
self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
self.meta = create_meta(ln_scale_factor)
self.fp8_type = tex.DType.kFloat8E4M3
self.gemm = FP8GemmModule(
precision, use_bias=False, gelu=False, scale_factors=gemm_scale_factors,
hidden_size=hidden_size, out_features=out_features)
precision,
use_bias=False,
gelu=False,
scale_factors=gemm_scale_factors,
hidden_size=hidden_size,
out_features=out_features,
)
def forward(self, inp, weight):
x = self.gemm(inp, weight)
......@@ -1286,14 +1417,16 @@ def test_export_gemm_layernorm(
self.fp8_tensor,
self.fp8_type,
0,
zero_centered_gamma)
zero_centered_gamma,
)
x = cast_from_fp8(
x,
self.meta,
self.fp8_tensor,
self.fp8_type,
tex.DType.kFloat32 if precision == torch.float32 else tex.DType.kFloat16)
tex.DType.kFloat32 if precision == torch.float32 else tex.DType.kFloat16,
)
return x
inp = torch.randn(hidden_size, in_features, dtype=precision, device="cuda")
......@@ -1302,14 +1435,21 @@ def test_export_gemm_layernorm(
high_prec_str = dtype2str(precision)
fp8_str = f"_fp8" if use_fp8 else ""
fname = f"te.gemm_layernorm{fp8_str}{high_prec_str}.onnx"
input_names = ['input', 'weight']
input_names = ["input", "weight"]
do_export(model, (inp, weight), fname, use_fp8=use_fp8, input_names=input_names)
te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8)
serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names)
if precision not in (torch.bfloat16, ):
if precision not in (torch.bfloat16,):
validate_result(
fname, (inp, weight), model, atol=5e-2, is_fp8=use_fp8, allow_cnt_errors=2,
input_names=input_names, te_outputs=te_outputs)
fname,
(inp, weight),
model,
atol=5e-2,
is_fp8=use_fp8,
allow_cnt_errors=2,
input_names=input_names,
te_outputs=te_outputs,
)
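
The GEMM+LayerNorm test above round-trips activations through cast_to_fp8/cast_from_fp8 using a scaling meta built with create_meta. A hedged sketch of that round trip, assuming the helper signatures shown elsewhere in this diff (exact signatures may differ between Transformer Engine versions):

# Sketch only: quantize to E4M3 and dequantize back; x_rt approximates x up to FP8 quantization error.
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import cast_to_fp8, cast_from_fp8
from test_onnx_export import create_meta

x = torch.randn(64, 64, dtype=torch.float16, device="cuda")
meta = create_meta(1.0)                      # single scale factor of 1.0
idx = tex.FP8FwdTensors.GEMM1_INPUT

x_fp8 = cast_to_fp8(x, meta, idx, tex.DType.kFloat8E4M3)
x_rt = cast_from_fp8(x_fp8, meta, idx, tex.DType.kFloat8E4M3, tex.DType.kFloat16)
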
@skip_FP8
......@@ -1357,32 +1497,61 @@ def test_export_gpt_generation(
output_layernorm=output_layernorm,
params_dtype=precision,
fuse_qkv_params=fuse_qkv_params,
zero_centered_gamma=zero_centered_gamma).to(device='cuda')
zero_centered_gamma=zero_centered_gamma,
).to(device="cuda")
# "Context phase": use full input sequence length
input_names = ["input"]
output_names = ["output"]
input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
input_tensor = torch.rand(
sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
)
inp = (input_tensor,)
do_export(model, inp, fname, use_fp8,
input_names=input_names, output_names=output_names,
dynamic_axes={"input": {0: "seq", 1:"bs"},
"output": {0: "seq", 1:"bs"}, })
do_export(
model,
inp,
fname,
use_fp8,
input_names=input_names,
output_names=output_names,
dynamic_axes={
"input": {0: "seq", 1: "bs"},
"output": {0: "seq", 1: "bs"},
},
)
te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names, output_names=output_names)
if precision not in (torch.bfloat16, ):
validate_result(fname, inp, model, atol=6e-3, is_fp8=use_fp8, input_names=input_names,
te_outputs=te_outputs)
serialize_inputs_outputs(
fname, inp, te_outputs, input_names=input_names, output_names=output_names
)
if precision not in (torch.bfloat16,):
validate_result(
fname,
inp,
model,
atol=6e-3,
is_fp8=use_fp8,
input_names=input_names,
te_outputs=te_outputs,
)
# "Generative phase": use a single input (sequence len=1). For FP8 we need to pad the sequence to mult of 8.
sequence_length = 1 if not use_fp8 else 8
input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
input_tensor = torch.rand(
sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
)
inp = (input_tensor, attention_mask)
te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
if precision not in (torch.bfloat16, ):
validate_result(fname, inp, model, atol=6e-3, is_fp8=use_fp8, input_names=input_names,
te_outputs=te_outputs)
if precision not in (torch.bfloat16,):
validate_result(
fname,
inp,
model,
atol=6e-3,
is_fp8=use_fp8,
input_names=input_names,
te_outputs=te_outputs,
)
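
The FP8 branch above uses sequence length 8 instead of 1 because, as the comment notes, the sequence has to be padded to a multiple of 8 for the FP8 path. A small sketch of padding an arbitrary (seq, batch, hidden) input up to the next multiple of 8 along the sequence dimension (the helper name is illustrative):

import torch
import torch.nn.functional as F

def pad_seq_to_multiple_of_8(x: torch.Tensor) -> torch.Tensor:
    # x has shape (seq, batch, hidden); pad the seq dimension (dim 0) on the right.
    pad = (-x.size(0)) % 8
    if pad == 0:
        return x
    # F.pad pads from the last dimension backwards, so dim 0 is the final pair.
    return F.pad(x, (0, 0, 0, 0, 0, pad))

x = torch.rand(1, 4, 256)
assert pad_seq_to_multiple_of_8(x).shape[0] == 8
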
@pytest.mark.parametrize("enabled", [True, False])
......
......@@ -19,6 +19,7 @@ from transformer_engine.pytorch.fp8 import (
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFP8Recipe:
......@@ -95,8 +96,8 @@ class TestFP8Recipe:
ref_amax_backward = amax_history_backward[-1]
else:
raise ValueError(f"{amax_compute_algo=} is not supported")
ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2 ** margin)
ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2 ** margin)
ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2**margin)
ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2**margin)
ref_scale_inv_forward = torch.reciprocal(ref_scale_forward)
update_weight_amax = is_first_microbatch is None or is_first_microbatch
if not update_weight_amax:
......@@ -128,8 +129,8 @@ class TestFP8Recipe:
ref_amax_backward = amax_history_backward[-1]
else:
raise ValueError(f"{amax_compute_algo=} is not supported")
ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2 ** margin)
ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2 ** margin)
ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2**margin)
ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2**margin)
ref_scale_inv_forward = torch.reciprocal(ref_scale_forward)
ref_scale_inv_backward = torch.reciprocal(ref_scale_backward)
......@@ -180,8 +181,9 @@ class TestFP8Recipe:
scaling_factor_compute_algo = None
if fused_update:
scaling_factor_compute_algo = (
lambda amax, scale, fp8_max, recipe:
te.fp8._default_sf_compute(amax, scale, fp8_max, recipe.margin)
lambda amax, scale, fp8_max, recipe: te.fp8._default_sf_compute(
amax, scale, fp8_max, recipe.margin
)
)
recipe = transformer_engine.common.recipe.DelayedScaling(
fp8_format=fp8_format, scaling_factor_compute_algo=scaling_factor_compute_algo
......@@ -205,7 +207,9 @@ class TestFP8Recipe:
# test different scenarios
if amax_case == "zero":
fp8_meta[forward_key].amax_history = torch.tensor([[0]], dtype=torch.float32, device="cuda")
fp8_meta[forward_key].amax_history = torch.tensor(
[[0]], dtype=torch.float32, device="cuda"
)
expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
elif amax_case == "tiny":
# calculate the minimum amax value that results in a FP32 maximum scale
......@@ -254,4 +258,6 @@ class TestFP8Recipe:
)
torch.testing.assert_close(fp8_meta[forward_key].scale, expected_scale)
torch.testing.assert_close(fp8_meta[forward_key].scale_inv, torch.reciprocal(expected_scale))
torch.testing.assert_close(
fp8_meta[forward_key].scale_inv, torch.reciprocal(expected_scale)
)
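
The reference values in these recipe tests follow the delayed-scaling rule scale = (fp8_max / amax) / 2**margin, with scale_inv as its reciprocal. A small numeric sketch (448 is the E4M3 forward-pass maximum used by Format.E4M3; the amax value is only illustrative):

fp8_max_fwd = 448.0   # E4M3 representable maximum
amax = 2.0            # observed absolute maximum (illustrative)
margin = 0

scale = (fp8_max_fwd / amax) / (2 ** margin)  # 224.0
scale_inv = 1.0 / scale                       # ~0.00446
# Values are quantized roughly as clamp(x * scale) into FP8 and recovered as x_fp8 * scale_inv.
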
......@@ -30,7 +30,13 @@ from transformer_engine.pytorch import (
)
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, gelu, cast_to_fp8, cast_from_fp8
from transformer_engine.pytorch.cpp_extensions import (
gemm,
fp8_gemm,
gelu,
cast_to_fp8,
cast_from_fp8,
)
from transformer_engine.pytorch.module.base import get_workspace
from test_onnx_export import create_meta
......@@ -75,6 +81,7 @@ class ModelConfig:
return False
return True
model_configs = {
"126m": ModelConfig(12, 2048, 2, 768, 12),
"small": ModelConfig(2, 32, 2, 64, 2),
......@@ -82,7 +89,7 @@ model_configs = {
}
fp8_recipes = [
None, # Handles non-FP8 case
None, # Handles non-FP8 case
recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3),
recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID),
recipe.DelayedScaling(
......@@ -126,6 +133,7 @@ batch_sizes_with_zero = [0, 1, 2]
all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu"]
all_normalizations = ["LayerNorm", "RMSNorm"]
def _disable_wgrads(block):
for p in block.parameters():
p.requires_grad = False
......@@ -143,8 +151,17 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
optimizer = torch.optim.SGD(block.parameters(), lr=0.1)
# Placeholders used for capture.
static_input = torch.randn(config.seq_len, config.batch_size, config.hidden_size, device='cuda', dtype=dtype, requires_grad=True)
static_target = torch.randn(config.seq_len, config.batch_size, config.hidden_size, device='cuda', dtype=dtype)
static_input = torch.randn(
config.seq_len,
config.batch_size,
config.hidden_size,
device="cuda",
dtype=dtype,
requires_grad=True,
)
static_target = torch.randn(
config.seq_len, config.batch_size, config.hidden_size, device="cuda", dtype=dtype
)
real_input = torch.rand_like(static_input)
real_target = torch.rand_like(static_target)
......@@ -403,11 +420,7 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz
config = model_configs[model]
module = RMSNorm if normalization == "RMSNorm" else LayerNorm
block = (
module(config.hidden_size)
.to(dtype=torch.float32)
.cuda()
)
block = module(config.hidden_size).to(dtype=torch.float32).cuda()
_test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)
......@@ -418,9 +431,9 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_linear(dtype, fp8_recipe, model, skip_wgrad,
zero_centered_gamma, skip_dgrad,
normalization):
def test_sanity_layernorm_linear(
dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, normalization
):
config = model_configs[model]
if fp8_recipe is not None:
......@@ -480,7 +493,7 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
config = model_configs[model]
ffn_hidden_size = 4 * config.hidden_size
num_tokens = bs*config.seq_len
num_tokens = bs * config.seq_len
if fp8_recipe is not None:
if not fp8_available:
......@@ -490,15 +503,9 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
use_fp8 = fp8_recipe is not None
with fp8_model_init(enabled=use_fp8 and fp8_model_params):
te_linear = (
Linear(
config.hidden_size,
ffn_hidden_size,
bias=use_bias,
params_dtype=dtype
)
.cuda()
)
te_linear = Linear(
config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
).cuda()
inp_hidden_states = torch.randn(
num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
......@@ -518,9 +525,9 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad,
zero_centered_gamma, skip_dgrad, activation,
normalization):
def test_sanity_layernorm_mlp(
dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, activation, normalization
):
config = model_configs[model]
if fp8_recipe is not None:
......@@ -557,10 +564,18 @@ def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad,
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
@pytest.mark.parametrize("cpu_offload", all_boolean)
def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad,
zero_centered_gamma, bias, activation,
normalization, parallel_attention_mlp,
cpu_offload):
def test_sanity_gpt(
dtype,
fp8_recipe,
model,
skip_wgrad,
zero_centered_gamma,
bias,
activation,
normalization,
parallel_attention_mlp,
cpu_offload,
):
config = model_configs[model]
if fp8_recipe is not None:
......@@ -625,8 +640,7 @@ def test_sanity_gpt_126m():
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
normalization):
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
config = model_configs[model]
if fp8_recipe is not None:
......@@ -683,8 +697,7 @@ def test_sanity_bert_126m():
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
normalization):
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
config = model_configs[model]
if fp8_recipe is not None:
......@@ -845,7 +858,9 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
def test_sanity_gradient_accumulation_fusion(
dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma
):
config = model_configs[model]
if fp8_recipe is not None:
......@@ -885,8 +900,7 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
normalization):
def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
config = model_configs[model]
if fp8_recipe is not None:
......@@ -919,9 +933,10 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm
_test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad)
def test_model_multiple_cast():
a = torch.zeros((16,16), device="cuda")
m = Linear(16,32)
a = torch.zeros((16, 16), device="cuda")
m = Linear(16, 32)
y = m(a)
assert y.dtype == torch.float32
......@@ -937,15 +952,11 @@ def test_model_multiple_cast():
@pytest.mark.parametrize("offset", [1, 3, 5])
@pytest.mark.parametrize("datatype", param_types)
def test_sanity_gemm_with_unalignment(N, offset, datatype):
scratchpad = torch.randn(N*N + 2*offset, device="cuda", dtype=datatype)
scratchpad = torch.randn(N * N + 2 * offset, device="cuda", dtype=datatype)
inp = torch.reshape(scratchpad[offset:-offset], (N, N))
weight = torch.reshape(scratchpad[offset*2:], (N, N))
weight = torch.reshape(scratchpad[offset * 2 :], (N, N))
_, _, _ = gemm(
A=weight,
B=inp,
dtype=datatype,
workspace=get_workspace())
_, _, _ = gemm(A=weight, B=inp, dtype=datatype, workspace=get_workspace())
torch.cuda.synchronize()
......@@ -954,38 +965,35 @@ def test_sanity_gemm_with_unalignment(N, offset, datatype):
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_fp8_gemm_with_unalignment(N, datatype):
offset = 16
scratchpad = torch.randn(N*N + offset, device="cuda", dtype=datatype)
scratchpad = torch.randn(N * N + offset, device="cuda", dtype=datatype)
fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT
fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT
nb_inp_scales, nb_weight_scales = 1, N
scale_factor = 1.
scale_factor = 1.0
meta_inp = create_meta(scale_factor, nb_inp_scales)
meta_weight = create_meta(scale_factor, nb_weight_scales)
inp_type = tex.DType.kFloat8E4M3
weights_type = tex.DType.kFloat8E4M3
outp_type = datatype
scratchpad_fp8 = cast_to_fp8(
scratchpad,
meta_weight,
fp8_tensor_inp,
inp_type)
scratchpad_fp8 = cast_to_fp8(scratchpad, meta_weight, fp8_tensor_inp, inp_type)
inp_fp8 = torch.reshape(scratchpad_fp8[:-offset], (N, N))
weight_fp8 = torch.reshape(scratchpad_fp8[offset:], (N, N))
_, _ = fp8_gemm(
weight_fp8,
meta_weight.scale_inv,
fp8_tensor_weight,
inp_type,
inp_fp8,
meta_inp.scale_inv,
fp8_tensor_inp,
weights_type,
outp_type,
get_workspace(),
bias=None,
use_bias=False,
use_split_accumulator=False)
weight_fp8,
meta_weight.scale_inv,
fp8_tensor_weight,
inp_type,
inp_fp8,
meta_inp.scale_inv,
fp8_tensor_inp,
weights_type,
outp_type,
get_workspace(),
bias=None,
use_bias=False,
use_split_accumulator=False,
)
torch.cuda.synchronize()
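
The unalignment tests above obtain misaligned operands by slicing a flat scratchpad at an odd offset, which moves the storage data pointer off the 16-byte boundaries that FP8/cuBLAS GEMM paths prefer. A small sketch of the effect, using a CPU tensor only to illustrate the pointer arithmetic:

import torch

buf = torch.randn(1024)     # contiguous float32; fresh allocations are well aligned
view = buf[3:]              # same storage, data pointer advanced by 3 * 4 bytes

print(buf.data_ptr() % 16)  # 0 for a freshly allocated tensor
print(view.data_ptr() % 16) # 12: the view is no longer 16-byte aligned
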
......@@ -3,4 +3,5 @@
# See LICENSE for license information.
import transformer_engine.pytorch
print("OK")
......@@ -28,7 +28,7 @@ from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
def init_meta(size: int=1):
def init_meta(size: int = 1):
meta = tex.FP8TensorMeta()
meta.scale = torch.ones(size, dtype=torch.float32, device="cuda")
meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda")
......@@ -65,22 +65,18 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd
self.inp_type = tex.DType.kFloat8E4M3
self.weights_type = tex.DType.kFloat8E4M3
self.outp_type = precision
def get_fp8_weights_scratchpad(self, is_first_microbatch):
raise RuntimeError("Method get_fp8_weights_scratchpad is dummy and should not be invoked.")
raise RuntimeError(
"Method get_fp8_weights_scratchpad is dummy and should not be invoked."
)
def forward(self, inp, weight):
inp_fp8 = cast_to_fp8(
inp,
self.meta_inp,
self.fp8_tensor_inp,
self.inp_type)
inp_fp8 = cast_to_fp8(inp, self.meta_inp, self.fp8_tensor_inp, self.inp_type)
weight_fp8 = cast_to_fp8(
weight,
self.meta_weight,
self.fp8_tensor_weight,
self.weights_type)
weight, self.meta_weight, self.fp8_tensor_weight, self.weights_type
)
ret = fp8_gemm(
weight_fp8,
......@@ -95,20 +91,33 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd
get_workspace(),
bias=self.bias,
use_bias=self.use_bias,
use_split_accumulator=False)
use_split_accumulator=False,
)
return ret
model_in = Test_TE_Export(precision, True)
with te.fp8_autocast(enabled=True):
model_in.init_fp8_metadata()
# scaling fwd
model_in.fp8_meta["scaling_fwd"].scale = torch.ones(3, dtype=torch.float32, device="cuda") * scale_fwd
model_in.fp8_meta["scaling_fwd"].scale_inv = torch.ones(3, dtype=torch.float32, device="cuda") / scale_fwd
model_in.fp8_meta["scaling_fwd"].amax_history = torch.ones(3, dtype=torch.float32, device="cuda") * history_fwd
model_in.fp8_meta["scaling_fwd"].scale = (
torch.ones(3, dtype=torch.float32, device="cuda") * scale_fwd
)
model_in.fp8_meta["scaling_fwd"].scale_inv = (
torch.ones(3, dtype=torch.float32, device="cuda") / scale_fwd
)
model_in.fp8_meta["scaling_fwd"].amax_history = (
torch.ones(3, dtype=torch.float32, device="cuda") * history_fwd
)
# scaling bwd
model_in.fp8_meta["scaling_bwd"].scale = torch.ones(2, dtype=torch.float32, device="cuda") * scale_bwd
model_in.fp8_meta["scaling_bwd"].scale_inv = torch.ones(2, dtype=torch.float32, device="cuda") / scale_bwd
model_in.fp8_meta["scaling_bwd"].amax_history = torch.ones(2, dtype=torch.float32, device="cuda") * history_bwd
model_in.fp8_meta["scaling_bwd"].scale = (
torch.ones(2, dtype=torch.float32, device="cuda") * scale_bwd
)
model_in.fp8_meta["scaling_bwd"].scale_inv = (
torch.ones(2, dtype=torch.float32, device="cuda") / scale_bwd
)
model_in.fp8_meta["scaling_bwd"].amax_history = (
torch.ones(2, dtype=torch.float32, device="cuda") * history_bwd
)
torch.save(model_in.state_dict(), tmp_filename)
......@@ -117,13 +126,27 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd
model_out.eval()
# scaling fwd
assert torch.allclose(model_in.fp8_meta["scaling_fwd"].scale, model_out.fp8_meta["scaling_fwd"].scale)
assert torch.allclose(model_in.fp8_meta["scaling_fwd"].scale_inv, model_out.fp8_meta["scaling_fwd"].scale_inv)
assert torch.allclose(model_in.fp8_meta["scaling_fwd"].amax_history, model_out.fp8_meta["scaling_fwd"].amax_history)
assert torch.allclose(
model_in.fp8_meta["scaling_fwd"].scale, model_out.fp8_meta["scaling_fwd"].scale
)
assert torch.allclose(
model_in.fp8_meta["scaling_fwd"].scale_inv, model_out.fp8_meta["scaling_fwd"].scale_inv
)
assert torch.allclose(
model_in.fp8_meta["scaling_fwd"].amax_history,
model_out.fp8_meta["scaling_fwd"].amax_history,
)
# scaling bwd
assert torch.allclose(model_in.fp8_meta["scaling_bwd"].scale, model_out.fp8_meta["scaling_bwd"].scale)
assert torch.allclose(model_in.fp8_meta["scaling_bwd"].scale_inv, model_out.fp8_meta["scaling_bwd"].scale_inv)
assert torch.allclose(model_in.fp8_meta["scaling_bwd"].amax_history, model_out.fp8_meta["scaling_bwd"].amax_history)
assert torch.allclose(
model_in.fp8_meta["scaling_bwd"].scale, model_out.fp8_meta["scaling_bwd"].scale
)
assert torch.allclose(
model_in.fp8_meta["scaling_bwd"].scale_inv, model_out.fp8_meta["scaling_bwd"].scale_inv
)
assert torch.allclose(
model_in.fp8_meta["scaling_bwd"].amax_history,
model_out.fp8_meta["scaling_bwd"].amax_history,
)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
......@@ -132,7 +155,7 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd
def test_fp8_model_checkpoint(
save_fp8_model: bool,
load_fp8_model: bool,
dims: Iterable[int] = [32,32],
dims: Iterable[int] = [32, 32],
dtype: torch.dtype = torch.float32,
device: Union[torch.device, str] = "cuda",
):
......@@ -153,7 +176,7 @@ def test_fp8_model_checkpoint(
with te.fp8_autocast():
y_ref = model(x.detach().clone()).detach().clone()
fp8_meta_ref = { "scaling_fwd": {}, "scaling_bwd": {} }
fp8_meta_ref = {"scaling_fwd": {}, "scaling_bwd": {}}
with te.fp8_autocast(), torch.no_grad():
fp8_meta_fwd = model.fp8_meta["scaling_fwd"]
fp8_meta_bwd = model.fp8_meta["scaling_bwd"]
......@@ -168,7 +191,7 @@ def test_fp8_model_checkpoint(
fp8_meta_bwd.scale.copy_(fp8_meta_bwd_ref["scale"])
fp8_meta_bwd.scale_inv.copy_(fp8_meta_bwd_ref["scale_inv"])
del fp8_meta_fwd, fp8_meta_bwd
# [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ]
# This line copies the fp8 scale_inv from the model metadata to the weight fp8 tensor.
# The sole purpose of the following lines is to set the scale_inv of the weight tensor, which is the simplest method.
......@@ -226,15 +249,14 @@ def test_fp8_model_checkpoint(
with pytest.raises(AssertionError):
torch.testing.assert_close(y, y_ref, **tols)
# [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ]
# When save_fp8_model=True, we load a model with weights in high precision,
# When save_fp8_model=True, we load a model with weights in high precision,
# which does not include _scale_inv,
# but has the fp8 scaling factor in the meta data. This scenario can occur
# but has the fp8 scaling factor in the meta data. This scenario can occur
# when using te.fp8_autocast(enabled=False, calibrating=True).
#
# In such cases, the default behavior of load_state_dict is incorrect - it loads tensors first,
# followed by the fp8 metadata. This results in an incorrect _scale_inv for the tensor. This behavior
# followed by the fp8 metadata. This results in an incorrect _scale_inv for the tensor. This behavior
# is corrected by overriding the _load_state_dict method from PyTorch in TransformerEngineBaseModule,
# to load the fp8 metadata before loading tensors.
#
......@@ -262,4 +284,6 @@ def test_fp8_model_checkpoint(
# We need to ensure that the tensor's scale_inv parameter matches its meta data.
# This is crucial to avoid confusion about which value is correct.
meta_index = model.weight._fp8_meta_index
torch.testing.assert_close(model.weight._scale_inv.item(), fp8_meta_fwd_ref["scale_inv"][meta_index].item())
\ No newline at end of file
torch.testing.assert_close(
model.weight._scale_inv.item(), fp8_meta_fwd_ref["scale_inv"][meta_index].item()
)
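
The comments above describe why checkpoint-loading order matters: the FP8 metadata must be restored before the FP8 weight tensors so that each tensor's _scale_inv is derived from the right scaling factors. A generic, hedged sketch of the ordering idea using PyTorch's _load_from_state_dict hook; this is not Transformer Engine's actual implementation, just an illustration of restoring metadata before tensor data:

import torch

class ModuleWithExtraMeta(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))
        self.register_buffer("scale_inv", torch.ones(1))

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Restore the metadata first, so anything derived from it (e.g. a
        # per-tensor scale_inv) is already consistent when parameters are copied in.
        key = prefix + "scale_inv"
        if key in state_dict:
            self.scale_inv.copy_(state_dict[key])
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
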
......@@ -4,74 +4,59 @@
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/activation.h>
#include <cuda_runtime.h>
#include "../util/vectorized_pointwise.h"
#include "../common.h"
#include <transformer_engine/activation.h>
#include "../common.h"
#include "../util/vectorized_pointwise.h"
namespace transformer_engine {
template <typename ComputeType, typename Param,
ComputeType (*OP)(ComputeType, const Param&)>
void act_fn(const Tensor &input,
Tensor *output,
cudaStream_t stream) {
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
void act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(input, "act_lu_input");
CheckOutputTensor(*output, "act_lu_output");
NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
const size_t tot_elts = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryKernelLauncher<nvec, Param, OP>(
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const ComputeType*>(output->scale.dptr),
reinterpret_cast<ComputeType*>(output->amax.dptr),
tot_elts,
{},
stream);
); // NOLINT(*)
); // NOLINT(*)
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryKernelLauncher<nvec, Param, OP>(
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const ComputeType *>(output->scale.dptr),
reinterpret_cast<ComputeType *>(output->amax.dptr), tot_elts, {},
stream);); // NOLINT(*)
); // NOLINT(*)
}
template <typename ComputeType, typename Param,
ComputeType (*OP)(ComputeType, const Param&)>
void dact_fn(const Tensor &grad,
const Tensor &input,
Tensor *output,
cudaStream_t stream) {
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
void dact_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(input, "dact_lu_input");
CheckInputTensor(grad, "dact_lu_input_grad");
CheckOutputTensor(*output, "dact_lu_output");
NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
NVTE_CHECK(input.data.dtype == grad.data.dtype,
"Input and incoming gradient types must match.");
NVTE_CHECK(input.data.dtype == grad.data.dtype, "Input and incoming gradient types must match.");
const size_t tot_elts = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryGradKernelLauncher<nvec, Param, OP>(
reinterpret_cast<const IType*>(grad.data.dptr),
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const ComputeType*>(output->scale.dptr),
reinterpret_cast<ComputeType*>(output->amax.dptr),
tot_elts,
{},
stream);
); // NOLINT(*)
); // NOLINT(*)
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryGradKernelLauncher<nvec, Param, OP>(
reinterpret_cast<const IType *>(grad.data.dptr),
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const ComputeType *>(output->scale.dptr),
reinterpret_cast<ComputeType *>(output->amax.dptr), tot_elts, {},
stream);); // NOLINT(*)
); // NOLINT(*)
}
template <typename ComputeType, typename Param,
ComputeType (*OP)(ComputeType, const Param&)>
void gated_act_fn(const Tensor &input,
Tensor *output,
cudaStream_t stream) {
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
void gated_act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(input, "gated_act_input");
CheckOutputTensor(*output, "gated_act_output");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
......@@ -81,29 +66,23 @@ void gated_act_fn(const Tensor &input,
NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2,
"Input shape[1] must be 2x larger than output shape[1].");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
GatedActivationKernelLauncher<nvec, ComputeType, Param, OP>(
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const ComputeType*>(output->scale.dptr),
reinterpret_cast<ComputeType*>(output->amax.dptr),
output->data.shape[0],
output->data.shape[1],
{},
stream);
); // NOLINT(*)
); // NOLINT(*)
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
GatedActivationKernelLauncher<nvec, ComputeType, Param, OP>(
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const ComputeType *>(output->scale.dptr),
reinterpret_cast<ComputeType *>(output->amax.dptr), output->data.shape[0],
output->data.shape[1], {},
stream);); // NOLINT(*)
); // NOLINT(*)
}
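
gated_act_fn above checks that the input's second dimension is twice the output's: the input carries the activation half and the gate half side by side. A PyTorch reference sketch of that layout for the GEGLU case (one common GLU-variant convention; the exact half ordering and FP8 scaling done by the CUDA kernel are not shown here):

import torch
import torch.nn.functional as F

def geglu_reference(x: torch.Tensor) -> torch.Tensor:
    # x: (N, 2 * H) -> output: (N, H)
    a, b = x.chunk(2, dim=-1)   # activation half and gate half
    return F.gelu(a) * b

x = torch.randn(8, 64)
assert geglu_reference(x).shape == (8, 32)
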
template <typename ComputeType, typename Param,
ComputeType (*OP1)(ComputeType, const Param&),
ComputeType (*OP2)(ComputeType, const Param&)>
void dgated_act_fn(const Tensor &grad,
const Tensor &input,
Tensor *output,
cudaStream_t stream) {
template <typename ComputeType, typename Param, ComputeType (*OP1)(ComputeType, const Param &),
ComputeType (*OP2)(ComputeType, const Param &)>
void dgated_act_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(grad, "dgated_act_grad");
CheckInputTensor(input, "dgated_act_input");
CheckOutputTensor(*output, "dgated_act_output");
......@@ -114,23 +93,19 @@ void dgated_act_fn(const Tensor &grad,
"Output shape[0] must be equal to grad shape[0].");
NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2,
"Output shape[1] must be 2x larger than grad shape[1].");
NVTE_CHECK(input.data.shape == output->data.shape,
"Input and output shapes must match.");
NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
DGatedActivationKernelLauncher<nvec, ComputeType, Param, OP1, OP2>(
reinterpret_cast<const IType*>(grad.data.dptr),
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
grad.data.shape[0],
grad.data.shape[1],
{},
stream);
); // NOLINT(*)
); // NOLINT(*)
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
DGatedActivationKernelLauncher<nvec, ComputeType, Param, OP1, OP2>(
reinterpret_cast<const IType *>(grad.data.dptr),
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr), grad.data.shape[0], grad.data.shape[1],
{},
stream);); // NOLINT(*)
); // NOLINT(*)
}
} // namespace transformer_engine
......@@ -3,96 +3,69 @@
*
* See LICENSE for license information.
************************************************************************/
#include "./activation_template.h"
#include "../util/math.h"
#include "./activation_template.h"
void nvte_gelu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_gelu);
using namespace transformer_engine;
act_fn<fp32, Empty, gelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
reinterpret_cast<Tensor*>(output), stream);
}
void nvte_dgelu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dgelu);
using namespace transformer_engine;
dact_fn<fp32, Empty, dgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
reinterpret_cast<Tensor*>(output), stream);
}
void nvte_geglu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_geglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, gelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
reinterpret_cast<Tensor*>(output), stream);
}
void nvte_dgeglu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dgeglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
*reinterpret_cast<const Tensor*>(grad), *reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
}
void nvte_qgelu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_qgelu);
using namespace transformer_engine;
act_fn<fp32, Empty, qgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
reinterpret_cast<Tensor*>(output), stream);
}
void nvte_dqgelu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dqgelu);
using namespace transformer_engine;
dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
}
void nvte_qgeglu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_qgeglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, qgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
reinterpret_cast<Tensor*>(output), stream);
}
void nvte_dqgeglu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dqgeglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, qgelu<fp32, fp32>, dqgelu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
*reinterpret_cast<const Tensor*>(grad), *reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
}
......@@ -4,96 +4,69 @@
* See LICENSE for license information.
************************************************************************/
#include "./activation_template.h"
#include "../util/math.h"
#include "./activation_template.h"
void nvte_relu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_relu);
using namespace transformer_engine;
act_fn<fp32, Empty, relu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
reinterpret_cast<Tensor*>(output), stream);
}
void nvte_drelu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_drelu);
using namespace transformer_engine;
dact_fn<fp32, Empty, drelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
}
void nvte_reglu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_reglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, relu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
reinterpret_cast<Tensor*>(output), stream);
}
void nvte_dreglu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dreglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, relu<fp32, fp32>, drelu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
*reinterpret_cast<const Tensor*>(grad), *reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
}
void nvte_srelu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_srelu);
using namespace transformer_engine;
act_fn<fp32, Empty, srelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
reinterpret_cast<Tensor*>(output), stream);
}
void nvte_dsrelu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dsrelu);
using namespace transformer_engine;
dact_fn<fp32, Empty, dsrelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
}
void nvte_sreglu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_sreglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, srelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
reinterpret_cast<Tensor*>(output), stream);
}
void nvte_dsreglu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dsreglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, srelu<fp32, fp32>, dsrelu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
*reinterpret_cast<const Tensor*>(grad), *reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
}