Unverified commit 9416519d authored by Kirthi Shankar Sivamani, committed by GitHub

Apply formatting (#929)



* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d99142a0
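The hunks below are consistent with running an opinionated autoformatter over the test files: quote normalization, operator spacing, and re-wrapping of long calls, with no functional change. As a hedged sketch (the exact command and the 100-character line length are assumptions inferred from where the new code wraps, not something this page states), the same style of rewrite can be reproduced with black's Python API:

    # Hypothetical sketch: reproducing this style of reformatting with black.
    # line_length=100 is an assumption inferred from where the diff wraps lines.
    import black

    src = "dtypes={'fp16' : torch.float16, 'bf16' : torch.bfloat16}\n"
    formatted = black.format_str(src, mode=black.Mode(line_length=100))
    print(formatted, end="")  # dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16}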
@@ -9,9 +9,10 @@ from transformer_engine.pytorch.attention import DotProductAttention
 import transformer_engine_torch as tex
 from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn

-dtypes={'fp16' : torch.float16, 'bf16' : torch.bfloat16}
-def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend='FlashAttention'):
+dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16}
+def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention"):
     """Test DotProductAttention module with context parallelism"""
     os.environ["NVTE_FLASH_ATTN"] = "0"
@@ -22,11 +23,13 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
     if kernel_backend == "FusedAttention":
         os.environ["NVTE_FUSED_ATTN"] = "1"
         config = model_configs_fused_attn[model]
-        if qkv_format == 'thd' and (config.num_heads != config.num_gqa_groups or config.attn_bias_type == "post_scale_bias"):
+        if qkv_format == "thd" and (
+            config.num_heads != config.num_gqa_groups or config.attn_bias_type == "post_scale_bias"
+        ):
             return
-    rank = int(os.getenv('RANK', '0'))
-    world_size = int(os.getenv('WORLD_SIZE', '1'))
+    rank = int(os.getenv("RANK", "0"))
+    world_size = int(os.getenv("WORLD_SIZE", "1"))
     if dist.is_initialized():
         world_size = dist.get_world_size()
@@ -38,51 +41,76 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
     print(f"[INFO] world_size:{world_size}, rank:{rank}")
-    dist.init_process_group(backend='nccl', world_size=world_size, rank=rank)
+    dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
     # create flash attn comm group for CP
     cp_comm_ranks = range(world_size)
-    assert(rank in cp_comm_ranks)
-    cp_comm_group = dist.new_group(cp_comm_ranks, backend='nccl')
+    assert rank in cp_comm_ranks
+    cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")

-    assert config.attn_mask_type in ['causal', 'no_mask'], f"{config.attn_mask_type} is an unsupported attention mask type!"
+    assert config.attn_mask_type in [
+        "causal",
+        "no_mask",
+    ], f"{config.attn_mask_type} is an unsupported attention mask type!"
-    if kernel_backend == 'FusedAttention' and qkv_format == 'thd':
-        if 'causal' in config.attn_mask_type:
-            config.attn_mask_type = 'padding_causal'
+    if kernel_backend == "FusedAttention" and qkv_format == "thd":
+        if "causal" in config.attn_mask_type:
+            config.attn_mask_type = "padding_causal"
         else:
-            config.attn_mask_type = 'padding'
+            config.attn_mask_type = "padding"

     # instantiate core attn module
-    core_attn = DotProductAttention(config.num_heads,
-                                    config.head_dim,
-                                    num_gqa_groups=config.num_gqa_groups,
-                                    attention_dropout=config.dropout_p,
-                                    qkv_format=qkv_format,
-                                    attn_mask_type=config.attn_mask_type)
+    core_attn = DotProductAttention(
+        config.num_heads,
+        config.head_dim,
+        num_gqa_groups=config.num_gqa_groups,
+        attention_dropout=config.dropout_p,
+        qkv_format=qkv_format,
+        attn_mask_type=config.attn_mask_type,
+    )
     core_attn = core_attn.cuda()

     # create flash attn inputs
     if qkv_format == "bshd":
         q_input_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim)
-        kv_input_shape = (config.batch_size, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim)
-        attn_output_shape = (config.batch_size, config.max_seqlen_q, config.num_heads*config.head_dim)
+        kv_input_shape = (
+            config.batch_size,
+            config.max_seqlen_kv,
+            config.num_gqa_groups,
+            config.head_dim,
+        )
+        attn_output_shape = (
+            config.batch_size,
+            config.max_seqlen_q,
+            config.num_heads * config.head_dim,
+        )
         cu_seqlens_q = None
         cu_seqlens_kv = None
     elif qkv_format == "sbhd":
         q_input_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim)
-        kv_input_shape = (config.max_seqlen_kv, config.batch_size, config.num_gqa_groups, config.head_dim)
-        attn_output_shape = (config.max_seqlen_q, config.batch_size, config.num_heads*config.head_dim)
+        kv_input_shape = (
+            config.max_seqlen_kv,
+            config.batch_size,
+            config.num_gqa_groups,
+            config.head_dim,
+        )
+        attn_output_shape = (
+            config.max_seqlen_q,
+            config.batch_size,
+            config.num_heads * config.head_dim,
+        )
         cu_seqlens_q = None
         cu_seqlens_kv = None
     elif qkv_format == "thd":
-        seqlens_q = torch.randint(world_size * 2, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
+        seqlens_q = torch.randint(world_size * 2, config.max_seqlen_q + 1, [config.batch_size]).to(
+            torch.int32
+        )
         seqlens_q = seqlens_q - seqlens_q % (world_size * 2)
         cu_seqlens_q = torch.cat([torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0)])
         cu_seqlens_kv = cu_seqlens_q
         q_input_shape = (cu_seqlens_q[-1], config.num_heads, config.head_dim)
         kv_input_shape = (cu_seqlens_kv[-1], config.num_gqa_groups, config.head_dim)
-        attn_output_shape = (cu_seqlens_q[-1], config.num_heads*config.head_dim)
+        attn_output_shape = (cu_seqlens_q[-1], config.num_heads * config.head_dim)
         cu_seqlens_q = cu_seqlens_q.to(torch.int32).cuda()
         cu_seqlens_kv = cu_seqlens_kv.to(torch.int32).cuda()
     else:
@@ -111,7 +139,9 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
     for x in [q, k, v]:
         x.requires_grad = True
     out = core_attn(
-        q, k, v,
+        q,
+        k,
+        v,
         core_attention_bias_type=config.attn_bias_type,
         core_attention_bias=bias,
         cu_seqlens_q=cu_seqlens_q,
@@ -120,17 +150,28 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
     out.backward(dout)

     # run core_attn wit CP
-    q_, k_, v_, dout_, *rest = [x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])]
+    q_, k_, v_, dout_, *rest = [
+        x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])
+    ]
     bias_ = rest[0] if len(rest) else None
     if qkv_format == "bshd" or qkv_format == "sbhd":
-        seq_dim = qkv_format.index('s')
-        q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
-            for x in [q_, k_, v_, dout_]]
-        seq_idx = torch.tensor([rank, 2*world_size-rank-1], device=q_.device)
+        seq_dim = qkv_format.index("s")
+        q_, k_, v_, dout_ = [
+            x.view(
+                *x.shape[:seq_dim],
+                2 * world_size,
+                x.shape[seq_dim] // (2 * world_size),
+                *x.shape[(seq_dim + 1) :],
+            )
+            for x in [q_, k_, v_, dout_]
+        ]
+        seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=q_.device)
         q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
-        q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim+2):]) for x in [q_, k_, v_, dout_]]
+        q_, k_, v_, dout_ = [
+            x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) for x in [q_, k_, v_, dout_]
+        ]
     elif qkv_format == "thd":
         seq_idx_q = tex.thd_get_partitioned_indices(cu_seqlens_q, q_.size(0), world_size, rank)
         seq_idx_kv = tex.thd_get_partitioned_indices(cu_seqlens_kv, k_.size(0), world_size, rank)
         q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]]
         k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
@@ -140,14 +181,18 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
         assert False, f"{qkv_format} is an unsupported qkv_format!"
     q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
     if bias_ is not None:
-        bias_ = bias_.view(*bias_.shape[:-2], 2*world_size, bias_.shape[-2]//(2*world_size), bias_.shape[-1])
+        bias_ = bias_.view(
+            *bias_.shape[:-2], 2 * world_size, bias_.shape[-2] // (2 * world_size), bias_.shape[-1]
+        )
         bias_ = bias_.index_select(2, seq_idx)
         bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
     core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream())
     max_seqlen_q = config.max_seqlen_q
     max_seqlen_kv = config.max_seqlen_kv
     out_ = core_attn(
-        q_, k_, v_,
+        q_,
+        k_,
+        v_,
         core_attention_bias_type=config.attn_bias_type,
         core_attention_bias=bias_,
         cu_seqlens_q=cu_seqlens_q,
@@ -158,23 +203,32 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
     out_.backward(dout_)
     for x in [out_, q_.grad, k_.grad, v_.grad]:
-        assert(torch.all(~torch.isnan(x)))
-        assert(torch.all(~torch.isinf(x)))
+        assert torch.all(~torch.isnan(x))
+        assert torch.all(~torch.isinf(x))

     # compare results with and without CP
     tols = dict(atol=5e-3, rtol=5e-3)
-    if dtype == 'bf16':
+    if dtype == "bf16":
         if config.num_heads == config.num_gqa_groups:
             tols = dict(atol=2.5e-2, rtol=2.5e-2)
         else:
             tols = dict(atol=3.5e-2, rtol=3.5e-2)
     if qkv_format == "bshd" or qkv_format == "sbhd":
-        dq, dk, dv, out = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
-            for x in [q.grad, k.grad, v.grad, out]]
+        dq, dk, dv, out = [
+            x.view(
+                *x.shape[:seq_dim],
+                2 * world_size,
+                x.shape[seq_dim] // (2 * world_size),
+                *x.shape[(seq_dim + 1) :],
+            )
+            for x in [q.grad, k.grad, v.grad, out]
+        ]
         dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]]
-        dq_, dk_, dv_, out_ = [x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim]//2, *x.shape[(seq_dim+1):]) \
-            for x in [q_.grad, k_.grad, v_.grad, out_]]
+        dq_, dk_, dv_, out_ = [
+            x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :])
+            for x in [q_.grad, k_.grad, v_.grad, out_]
+        ]
     elif qkv_format == "thd":
         dq, out = [x.index_select(0, seq_idx_q).contiguous().view(-1) for x in [q.grad, out]]
         dk, dv = [x.index_select(0, seq_idx_kv).contiguous().view(-1) for x in [k.grad, v.grad]]
@@ -208,9 +262,11 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
     else:
         assert False, f"{qkv_format} is an unsupported qkv_format!"

+
 def main(**kwargs):
     run_dpa_with_cp(**kwargs)

+
 if __name__ == "__main__":
-    kwargs = dict(arg.split('=') for arg in sys.argv[2:])
+    kwargs = dict(arg.split("=") for arg in sys.argv[2:])
     main(**kwargs)
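For context on the __main__ block just above: the script builds run_dpa_with_cp's keyword arguments from key=value tokens, and the argv[2:] slice skips both the script name and argv[1]. A minimal sketch of that parsing follows; the argument values are hypothetical, chosen only to illustrate the shape, not taken from this diff:

    # Hypothetical illustration of the key=value parsing in the script's __main__ block.
    # sys.argv[0] is the script name; argv[1] is skipped by the argv[2:] slice.
    import sys

    sys.argv = ["run_script.py", "unused", "dtype=bf16", "qkv_format=thd", "kernel_backend=FusedAttention"]
    kwargs = dict(arg.split("=") for arg in sys.argv[2:])
    assert kwargs == {"dtype": "bf16", "qkv_format": "thd", "kernel_backend": "FusedAttention"}

The remaining hunks below are from a second test file.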
@@ -91,10 +91,10 @@ class ModelConfig:
         self.max_seqlen_q = max_seqlen_q
         self.max_seqlen_kv = max_seqlen_kv
         self.dropout_p = dropout_p
         self.attn_mask_type = attn_mask_type
         self.attn_bias_type = attn_bias_type
         self.alibi_type = alibi_type
         self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
         self.num_layers = num_layers
         self.bias_shape = bias_shape
@@ -184,28 +184,29 @@ def _is_flash_attention_supported(config: ModelConfig) -> bool:
         return False
     return True

+
 def _is_unfused_attention_supported(
     config: ModelConfig,
     qkv_format: str,
 ) -> bool:
     """Check if UnfusedDotProductAttention supports a model configuration"""
-    if ("padding" in config.attn_mask_type):
+    if "padding" in config.attn_mask_type:
         return False
-    if ("causal" in config.attn_mask_type and config.attn_type == 'cross'):
+    if "causal" in config.attn_mask_type and config.attn_type == "cross":
         return False
-    if qkv_format == 'thd':
+    if qkv_format == "thd":
         return False
     return True

 model_configs_base = {
     # test: b, h, hg, d, sq, skv, p, mask, bias # attn , backend
     "base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0
     "base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), # cross, 0
     "base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), # self , 1
     "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1
     "base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
     "base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
 }
@@ -221,13 +222,13 @@ def get_swa(seq_q, seq_kv, w=None):
     if w is None:
         w = torch.randint(0, seq_kv, [2], dtype=torch.int32, device="cuda")
     m = torch.ones(seq_q, seq_kv, dtype=torch.bool, device="cuda")
-    mu = torch.triu(m, diagonal=seq_kv-seq_q-w[0])
-    ml = torch.tril(mu, diagonal=seq_kv-seq_q+w[1])
-    ml = ~ ml
+    mu = torch.triu(m, diagonal=seq_kv - seq_q - w[0])
+    ml = torch.tril(mu, diagonal=seq_kv - seq_q + w[1])
+    ml = ~ml
     return w, ml

-@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("model_configs", [model_configs_base])
 @pytest.mark.parametrize("model", model_configs_base.keys())
@@ -236,8 +237,9 @@ def get_swa(seq_q, seq_kv, w=None):
 @pytest.mark.parametrize("qkv_layout", [None])
 @pytest.mark.parametrize("swa", [False])
 @pytest.mark.parametrize("pad_between_seqs", [False])
-def test_dot_product_attention(dtype, model_configs, model, ckpt_attn,
-                               workspace_opt, qkv_layout, swa, pad_between_seqs):
+def test_dot_product_attention(
+    dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa, pad_between_seqs
+):
     """Test DotProductAttention module"""
     # Get configs
@@ -251,36 +253,43 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn,
         else:
             qkv_layout = "sbhd_sb2hd"
     if "3" in qkv_layout and config.attn_type == "cross":
-        pytest.skip(
-            "No need to test this layout for cross attention"
-        )
+        pytest.skip("No need to test this layout for cross attention")

     # Skip if only unfused backend is supported
-    qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
+    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
     unfused_attn_supported = _is_unfused_attention_supported(config, qkv_format)
     if config.max_seqlen_q <= 512 and config.max_seqlen_kv <= 512:
         os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
     fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
-        config, dtype, qkv_layout=qkv_layout,
+        config,
+        dtype,
+        qkv_layout=qkv_layout,
     )
     if swa:
         fused_attn_supported = False
     flash_attn_supported = _is_flash_attention_supported(config)
     if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
         pytest.skip("Less than two backends to compare.")
-    if (qkv_format == 'thd' and 'padding' not in config.attn_mask_type):
+    if qkv_format == "thd" and "padding" not in config.attn_mask_type:
         pytest.skip("THD layout requires padding/padding_causal mask type.")
     # d=256 is supported by cuDNN 9.0+ for inference but not training
-    is_training = (config.head_dim <= 128)
+    is_training = config.head_dim <= 128

     # UnfusedDotProductAttention backend
     if unfused_attn_supported:
         if swa:
             attn_mask_type = config.attn_mask_type
             config.attn_mask_type = "arbitrary"
         unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
-            dtype, config, "UnfusedDotProductAttention",
-            ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+            dtype,
+            config,
+            "UnfusedDotProductAttention",
+            ckpt_attn,
+            qkv_layout,
+            workspace_opt,
+            swa,
+            pad_between_seqs,
+            is_training,
         )
         if swa:
             config.attn_mask_type = attn_mask_type
@@ -289,51 +298,79 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn,
     if fused_attn_supported:
         if len(fused_attn_backend) == 1:
             fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
-                dtype, config, "FusedAttention",
-                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+                dtype,
+                config,
+                "FusedAttention",
+                ckpt_attn,
+                qkv_layout,
+                workspace_opt,
+                swa,
+                pad_between_seqs,
+                is_training,
             )
         if len(fused_attn_backend) == 2:
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
             fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
-                dtype, config, "FusedAttention",
-                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+                dtype,
+                config,
+                "FusedAttention",
+                ckpt_attn,
+                qkv_layout,
+                workspace_opt,
+                swa,
+                pad_between_seqs,
+                is_training,
             )
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
             fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
-                dtype, config, "FusedAttention",
-                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+                dtype,
+                config,
+                "FusedAttention",
+                ckpt_attn,
+                qkv_layout,
+                workspace_opt,
+                swa,
+                pad_between_seqs,
+                is_training,
             )

     # FlashAttention backend
     if flash_attn_supported:
         flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
-            dtype, config, "FlashAttention",
-            ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+            dtype,
+            config,
+            "FlashAttention",
+            ckpt_attn,
+            qkv_layout,
+            workspace_opt,
+            swa,
+            pad_between_seqs,
+            is_training,
         )

     if unfused_attn_supported and fused_attn_supported:
         logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
         torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
-        for i,_ in enumerate(unfused_attn_bwd):
+        for i, _ in enumerate(unfused_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
     if unfused_attn_supported and flash_attn_supported:
         logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
         torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
-        for i,_ in enumerate(flash_attn_bwd):
+        for i, _ in enumerate(flash_attn_bwd):
             torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols)
     if fused_attn_supported and flash_attn_supported:
         logging.info("[test_dot_product_attention]: fused attn vs flash attn")
         torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
-        for i,_ in enumerate(flash_attn_bwd):
+        for i, _ in enumerate(flash_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
     if fused_attn_supported and len(fused_attn_backend) == 2:
         logging.info("[test_dot_product_attention]: fused attn backend 0 vs 1")
         torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
-        for i,_ in enumerate(fused_attn_bwd):
+        for i, _ in enumerate(fused_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)

-@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("model_configs", [model_configs_base])
 @pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
@@ -344,22 +381,22 @@ def test_dpa_checkpoint(dtype, model_configs, model):
 model_configs_mask = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
     "mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
     "mask_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"),
     "mask_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
     "mask_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
     "mask_3_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
     "mask_3_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
     "mask_4_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding", "no_bias"),
     "mask_4_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
     "mask_5_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
     "mask_5_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding_causal", "no_bias"),
     "mask_6_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
     "mask_6_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
 }

-@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_mask])
 @pytest.mark.parametrize("model", model_configs_mask.keys())
@@ -370,34 +407,48 @@ def test_dpa_mask(dtype, model_configs, model):
 model_configs_bias = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
     "bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
     "bias_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "post_scale_bias"),
     "bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias"),
     "bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "post_scale_bias"),
     "bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "alibi"), # skipped
     "bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "alibi"), # skipped
     "bias_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), # skipped
     "bias_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "post_scale_bias"), # skipped
-    "bias_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "post_scale_bias"), # skipped
-    "bias_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "post_scale_bias"), # skipped
-    "bias_2_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "alibi"), # skipped
-    "bias_2_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "alibi"), # skipped
-    "bias_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
-    "bias_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "post_scale_bias"),
-    "bias_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
-    "bias_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "post_scale_bias"), # skipped
-    "bias_3_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi"),
-    "bias_3_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "alibi"), # skipped
-    "bias_4_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "padding_causal", "post_scale_bias"), # skipped
-    "bias_4_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"), # skipped
-    "bias_4_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "post_scale_bias"), # skipped
-    "bias_4_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"), # skipped
-    "bias_4_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "alibi"), # skipped
-    "bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "alibi"), # skipped
+    "bias_2_2": ModelConfig(
+        4, 24, 24, 128, 2048, 2048, 0.0, "padding", "post_scale_bias"
+    ), # skipped
+    "bias_2_3": ModelConfig(
+        2, 24, 24, 128, 2048, 4096, 0.0, "padding", "post_scale_bias"
+    ), # skipped
+    "bias_2_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "alibi"), # skipped
+    "bias_2_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "alibi"), # skipped
+    "bias_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
+    "bias_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "post_scale_bias"),
+    "bias_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
+    "bias_3_3": ModelConfig(
+        2, 24, 24, 128, 2048, 4096, 0.0, "causal", "post_scale_bias"
+    ), # skipped
+    "bias_3_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi"),
+    "bias_3_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "alibi"), # skipped
+    "bias_4_0": ModelConfig(
+        4, 16, 16, 64, 128, 128, 0.0, "padding_causal", "post_scale_bias"
+    ), # skipped
+    "bias_4_1": ModelConfig(
+        2, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"
+    ), # skipped
+    "bias_4_2": ModelConfig(
+        4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "post_scale_bias"
+    ), # skipped
+    "bias_4_3": ModelConfig(
+        2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"
+    ), # skipped
+    "bias_4_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "alibi"), # skipped
+    "bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "alibi"), # skipped
 }

-@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_bias])
 @pytest.mark.parametrize("model", model_configs_bias.keys())
@@ -408,23 +459,38 @@ def test_dpa_bias(dtype, model_configs, model):
 model_configs_bias_shapes = {
     # test: b, h, hg, d, sq, skv, p,
-    "bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0,
-        # mask, bias, bias_shape,
-        "no_mask", "post_scale_bias", bias_shape='11ss'),
-    "bias_1_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0,
-        "no_mask", "post_scale_bias", bias_shape='1hss'),
-    "bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,
-        "no_mask", "post_scale_bias", bias_shape='b1ss'),
-    "bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0,
-        "no_mask", "post_scale_bias", bias_shape='bhss'),
-    "bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,
-        "causal", "alibi", bias_shape='1hss', alibi_type='custom'),
-    "bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0,
-        "causal", "alibi", bias_shape='bhss', alibi_type='custom'),
+    "bias_1_0": ModelConfig(
+        4,
+        16,
+        16,
+        64,
+        128,
+        128,
+        0.0,
+        # mask, bias, bias_shape,
+        "no_mask",
+        "post_scale_bias",
+        bias_shape="11ss",
+    ),
+    "bias_1_1": ModelConfig(
+        2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias", bias_shape="1hss"
+    ),
+    "bias_1_2": ModelConfig(
+        4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="b1ss"
+    ),
+    "bias_1_3": ModelConfig(
+        2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="bhss"
+    ),
+    "bias_1_4": ModelConfig(
+        4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="1hss", alibi_type="custom"
+    ),
+    "bias_1_5": ModelConfig(
+        2, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="bhss", alibi_type="custom"
+    ),
 }

-@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_bias_shapes])
 @pytest.mark.parametrize("model", model_configs_bias_shapes.keys())
@@ -435,10 +501,10 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
 model_configs_swa = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
     "swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
     "swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
     "swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
     "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
 }
@@ -453,10 +519,14 @@ def test_dpa_sliding_window(dtype, model_configs, model):
 model_configs_alibi_slopes = {
     # test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type
     "alibi_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "alibi", alibi_type="vanilla"),
     "alibi_1_1": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "causal", "alibi", alibi_type="vanilla"),
-    "alibi_2_0": ModelConfig(2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type= "custom"),
-    "alibi_2_1": ModelConfig(1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type= "custom"),
+    "alibi_2_0": ModelConfig(
+        2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type="custom"
+    ),
+    "alibi_2_1": ModelConfig(
+        1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type="custom"
+    ),
 }
@@ -470,27 +540,35 @@ def test_dpa_alibi_slopes(dtype, model_configs, model):
 qkv_layouts = [
-    'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd',
-    'bs3hd', 'bsh3d', 'bshd_bs2hd', 'bshd_bsh2d', 'bshd_bshd_bshd',
-    ]
+    "sb3hd",
+    "sbh3d",
+    "sbhd_sb2hd",
+    "sbhd_sbh2d",
+    "sbhd_sbhd_sbhd",
+    "bs3hd",
+    "bsh3d",
+    "bshd_bs2hd",
+    "bshd_bsh2d",
+    "bshd_bshd_bshd",
+]

 model_configs_layout = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
     "layout_0_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
     "layout_0_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
     "layout_0_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
     "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"),
     "layout_1_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
     "layout_1_1": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
     "layout_1_2": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
     "layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"),
     "layout_2_0": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"),
     "layout_2_1": ModelConfig(2, 24, 24, 256, 2048, 2048, 0.0, "causal", "post_scale_bias"),
 }

-@pytest.mark.skipif(get_cudnn_version() < (8,9,5), reason="cuDNN 8.9.5+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 5), reason="cuDNN 8.9.5+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_layout])
 @pytest.mark.parametrize("model", model_configs_layout.keys())
@@ -500,26 +578,28 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
     test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)

-qkv_layouts_thd = ['t3hd', 'th3d', 'thd_t2hd', 'thd_th2d', 'thd_thd_thd']
+qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
 model_configs_layout_thd = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
     "layout_0_1": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
     "layout_0_2": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
     "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
     "layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
     "layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
     "layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
     "layout_1_3": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
     "layout_1_4": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
     "layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"),
     "layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
     "layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
     "layout_2_4": ModelConfig(8, 16, 16, 64, 2048, 4096, 0.0, "padding_causal", "no_bias"),
 }

-@pytest.mark.skipif(get_cudnn_version() < (9,0,0), reason="cuDNN 9.0.0+ is required.")
-@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+.")
+@pytest.mark.skipif(get_cudnn_version() < (9, 0, 0), reason="cuDNN 9.0.0+ is required.")
+@pytest.mark.skipif(
+    get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+."
+)
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_layout_thd])
 @pytest.mark.parametrize("model", model_configs_layout_thd.keys())
@@ -527,24 +607,26 @@ model_configs_layout_thd = {
 def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
     """Test DotProductAttention module with different QKV layouts"""
     pad_between_seqs = False
-    test_dot_product_attention(dtype, model_configs, model, False, True,
-                               qkv_layout, False, pad_between_seqs)
+    test_dot_product_attention(
+        dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
+    )
     pad_between_seqs = True
-    test_dot_product_attention(dtype, model_configs, model, False, True,
-                               qkv_layout, False, pad_between_seqs)
+    test_dot_product_attention(
+        dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
+    )

 def _run_dot_product_attention(
     dtype: torch.dtype,
     config: ModelConfig,
     backend: str,
     ckpt_attn: bool,
     qkv_layout: str,
     workspace_opt: bool,
     swa: bool,
     pad_between_seqs: bool,
     is_training: bool,
 ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
     """Run DotProductAttention module with one forward pass and one backward pass"""
     # Set RNG and environment varables
@@ -558,22 +640,27 @@ def _run_dot_product_attention(
     os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"

     # Create seqlens
-    qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
-    if "padding" in config.attn_mask_type or qkv_format == 'thd':
-        if config.attn_type == 'self':
-            seqlens_q = torch.randint(1, config.max_seqlen_q, [config.batch_size],
-                dtype=torch.int32, device="cuda")
+    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
+    if "padding" in config.attn_mask_type or qkv_format == "thd":
+        if config.attn_type == "self":
+            seqlens_q = torch.randint(
+                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
+            )
             seqlens_kv = seqlens_q
-        if config.attn_type == 'cross':
-            seqlens_q = torch.randint(1, config.max_seqlen_q, [config.batch_size],
-                dtype=torch.int32, device="cuda")
-            seqlens_kv = torch.randint(1, config.max_seqlen_kv, [config.batch_size],
-                dtype=torch.int32, device="cuda")
+        if config.attn_type == "cross":
+            seqlens_q = torch.randint(
+                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
+            )
+            seqlens_kv = torch.randint(
+                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
+            )
     else:
-        seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
-            dtype=torch.int32, device="cuda")
-        seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv,
-            dtype=torch.int32, device="cuda")
+        seqlens_q = torch.full(
+            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
+        )
+        seqlens_kv = torch.full(
+            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
+        )
     cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
     cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
     cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
@@ -586,7 +673,7 @@ def _run_dot_product_attention(
     pad_len = [0] * config.batch_size
     if pad_between_seqs:
         max_pad_len = 3
-        pad_len = torch.randint(0, max_pad_len+1, [config.batch_size], device="cuda") #3
+        pad_len = torch.randint(0, max_pad_len + 1, [config.batch_size], device="cuda")  # 3
     seqlens_q_after_pad = seqlens_q + pad_len
     seqlens_kv_after_pad = seqlens_kv + pad_len
     cu_seqlens_q_after_pad[1:] = torch.cumsum(seqlens_q_after_pad, dim=0)
@@ -595,25 +682,58 @@ def _run_dot_product_attention(
     # Create attention mask if padding
     attention_mask = None
     if "padding" in config.attn_mask_type:
-        if config.attn_type == 'self':
+        if config.attn_type == "self":
             attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
             for i in range(config.batch_size):
-                attention_mask_q = torch.cat([attention_mask_q,
-                    torch.Tensor([False]*seqlens_q[i] + [True]*(config.max_seqlen_q-seqlens_q[i]))
-                    .to(dtype=torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
+                attention_mask_q = torch.cat(
+                    [
+                        attention_mask_q,
+                        torch.Tensor(
+                            [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
+                        )
+                        .to(dtype=torch.bool)
+                        .unsqueeze(0)
+                        .unsqueeze(0)
+                        .unsqueeze(0),
+                    ],
+                    dim=0,
+                )
             attention_mask = attention_mask_q.to(device="cuda")
-        if config.attn_type == 'cross':
+        if config.attn_type == "cross":
             attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
             attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool)
             for i in range(config.batch_size):
-                attention_mask_q = torch.cat([attention_mask_q,
-                    torch.Tensor([False]*seqlens_q[i] + [True]*(config.max_seqlen_q-seqlens_q[i]))
-                    .to(dtype=torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
-                attention_mask_kv = torch.cat([attention_mask_kv, torch.Tensor(
-                    [False]*seqlens_kv[i] + [True]*(config.max_seqlen_kv-seqlens_kv[i]))
-                    .to(dtype=torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
+                attention_mask_q = torch.cat(
+                    [
+                        attention_mask_q,
+                        torch.Tensor(
+                            [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
+                        )
+                        .to(dtype=torch.bool)
+                        .unsqueeze(0)
+                        .unsqueeze(0)
+                        .unsqueeze(0),
+                    ],
+                    dim=0,
+                )
+                attention_mask_kv = torch.cat(
+                    [
+                        attention_mask_kv,
+                        torch.Tensor(
+                            [False] * seqlens_kv[i]
+                            + [True] * (config.max_seqlen_kv - seqlens_kv[i])
+                        )
+                        .to(dtype=torch.bool)
+                        .unsqueeze(0)
+                        .unsqueeze(0)
+                        .unsqueeze(0),
+                    ],
+                    dim=0,
+                )
             attention_mask = (
-                attention_mask_q.to(device="cuda"), attention_mask_kv.to(device="cuda"))
+                attention_mask_q.to(device="cuda"),
+                attention_mask_kv.to(device="cuda"),
+            )
     window_size = None
     if swa:
         window_size, attention_mask = get_swa(config.max_seqlen_q, config.max_seqlen_kv)
@@ -623,62 +743,84 @@ def _run_dot_product_attention(
     alibi_slopes = None
     if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
         if config.bias_shape == "1hss":
-            alibi_slopes = torch.randn(
-                config.num_heads).abs().to(dtype=torch.float32, device="cuda")
+            alibi_slopes = (
+                torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda")
+            )
         if config.bias_shape == "bhss":
-            alibi_slopes = torch.randn(
-                config.batch_size, config.num_heads).abs().to(dtype=torch.float32, device="cuda")
+            alibi_slopes = (
+                torch.randn(config.batch_size, config.num_heads)
+                .abs()
+                .to(dtype=torch.float32, device="cuda")
+            )

     # Create input tensors
     dim_to_num = {
-        'b'  : config.batch_size,
-        'sq' : config.max_seqlen_q,
-        'skv': config.max_seqlen_kv,
-        'h'  : config.num_heads,
-        'hg' : config.num_gqa_groups,
-        'd'  : config.head_dim,
-        't'  : cu_seqlens_q_after_pad[-1],
-        'tg' : cu_seqlens_kv_after_pad[-1],
-        '3'  : 3,
-        '2'  : 2,
-        '1'  : 1,
+        "b": config.batch_size,
+        "sq": config.max_seqlen_q,
+        "skv": config.max_seqlen_kv,
+        "h": config.num_heads,
+        "hg": config.num_gqa_groups,
+        "d": config.head_dim,
+        "t": cu_seqlens_q_after_pad[-1],
+        "tg": cu_seqlens_kv_after_pad[-1],
+        "3": 3,
+        "2": 2,
+        "1": 1,
     }
     inp = []
     inp_orig = []
-    for i,layout in enumerate(qkv_layout.split('_')):
-        layout = '_'.join(layout)
+    for i, layout in enumerate(qkv_layout.split("_")):
+        layout = "_".join(layout)
         if i == 0:
-            layout = layout.replace('s', 'sq')
+            layout = layout.replace("s", "sq")
         else:
-            layout = layout.replace('s', 'skv')
-            layout = layout.replace('h', 'hg')
-            layout = layout.replace('t', 'tg')
-        tensor_shape = [dim_to_num[j] for j in layout.split('_')]
+            layout = layout.replace("s", "skv")
+            layout = layout.replace("h", "hg")
+            layout = layout.replace("t", "tg")
+        tensor_shape = [dim_to_num[j] for j in layout.split("_")]
         tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
         tensor_orig = tensor
-        if qkv_format == 'thd' and pad_between_seqs:
-            tensor_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
-            if layout in ['t_h_d', 't_3_h_d', 't_h_3_d']:
-                for i in range(1, config.batch_size+1):
-                    valid_range = (cu_seqlens_q_after_pad[i-1], cu_seqlens_q_after_pad[i] - pad_len[i-1])
-                    pad_range = (cu_seqlens_q_after_pad[i] - pad_len[i-1], cu_seqlens_q_after_pad[i])
-                    tensor[pad_range[0]:pad_range[1]] = 0.0
-                    tensor_orig = torch.cat([tensor_orig, tensor[valid_range[0]:valid_range[1]]], dim=0)
-            if layout in ['tg_hg_d', 'tg_2_hg_d', 'tg_hg_2_d']:
-                for i in range(1, config.batch_size+1):
-                    valid_range = (cu_seqlens_kv_after_pad[i-1], cu_seqlens_kv_after_pad[i] - pad_len[i-1])
-                    pad_range = (cu_seqlens_kv_after_pad[i] - pad_len[i-1], cu_seqlens_kv_after_pad[i])
-                    tensor[pad_range[0]:pad_range[1]] = 0.0
-                    tensor_orig = torch.cat([tensor_orig, tensor[valid_range[0]:valid_range[1]]], dim=0)
+        if qkv_format == "thd" and pad_between_seqs:
+            tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
+            if layout in ["t_h_d", "t_3_h_d", "t_h_3_d"]:
+                for i in range(1, config.batch_size + 1):
+                    valid_range = (
+                        cu_seqlens_q_after_pad[i - 1],
+                        cu_seqlens_q_after_pad[i] - pad_len[i - 1],
+                    )
+                    pad_range = (
+                        cu_seqlens_q_after_pad[i] - pad_len[i - 1],
+                        cu_seqlens_q_after_pad[i],
+                    )
+                    tensor[pad_range[0] : pad_range[1]] = 0.0
+                    tensor_orig = torch.cat(
+                        [tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
+                    )
+            if layout in ["tg_hg_d", "tg_2_hg_d", "tg_hg_2_d"]:
+                for i in range(1, config.batch_size + 1):
+                    valid_range = (
+                        cu_seqlens_kv_after_pad[i - 1],
+                        cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
+                    )
+                    pad_range = (
+                        cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
+                        cu_seqlens_kv_after_pad[i],
+                    )
+                    tensor[pad_range[0] : pad_range[1]] = 0.0
+                    tensor_orig = torch.cat(
+                        [tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
+                    )
         tensor_count = 1
         split_dim = 0
-        for dim, l in enumerate(layout.split('_')):
+        for dim, l in enumerate(layout.split("_")):
             if l.isdigit():
                 tensor_count = int(l)
                 split_dim = dim
                 break
         tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
-        tensors_orig = torch.split(tensor_orig, 1, dim=split_dim) if split_dim != 0 else [tensor_orig]
+        tensors_orig = (
+            torch.split(tensor_orig, 1, dim=split_dim) if split_dim != 0 else [tensor_orig]
+        )
         for j in range(tensor_count):
             if split_dim != 0:
                 inp.append(tensors[j].squeeze(split_dim))
...@@ -692,73 +834,77 @@ def _run_dot_product_attention( ...@@ -692,73 +834,77 @@ def _run_dot_product_attention(
# Create ragged offsets for q/k/v # Create ragged offsets for q/k/v
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o = None, None, None, None seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o = None, None, None, None
qkv_group = ''.join([x for x in qkv_layout if x not in 'bst']) qkv_group = "".join([x for x in qkv_layout if x not in "bst"])
if qkv_format == 'thd': if qkv_format == "thd":
seq_offsets_o = config.num_heads * config.head_dim * cu_seqlens_q_after_pad seq_offsets_o = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
if qkv_group == 'hd_hd_hd': if qkv_group == "hd_hd_hd":
seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q_after_pad seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
seq_offsets_k = config.num_gqa_groups * config.head_dim * cu_seqlens_kv_after_pad seq_offsets_k = config.num_gqa_groups * config.head_dim * cu_seqlens_kv_after_pad
seq_offsets_v = config.num_gqa_groups * config.head_dim * cu_seqlens_kv_after_pad seq_offsets_v = config.num_gqa_groups * config.head_dim * cu_seqlens_kv_after_pad
if qkv_group in ['3hd', 'h3d']: if qkv_group in ["3hd", "h3d"]:
seq_offsets_q = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad seq_offsets_q = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
seq_offsets_k = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad seq_offsets_k = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
seq_offsets_v = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad seq_offsets_v = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
if qkv_group in ['hd_2hd', 'hd_h2d']: if qkv_group in ["hd_2hd", "hd_h2d"]:
seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q_after_pad seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
seq_offsets_k = config.num_gqa_groups * config.head_dim * 2 * cu_seqlens_kv_after_pad seq_offsets_k = config.num_gqa_groups * config.head_dim * 2 * cu_seqlens_kv_after_pad
seq_offsets_v = config.num_gqa_groups * config.head_dim * 2 * cu_seqlens_kv_after_pad seq_offsets_v = config.num_gqa_groups * config.head_dim * 2 * cu_seqlens_kv_after_pad
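The ragged offsets are element offsets into the packed `thd` buffers: sequence `i`'s data starts `heads * head_dim * cu_seqlens[i]` elements in, scaled by 3 (or 2) when q/k/v (or k/v) share one interleaved tensor. A quick numeric illustration under those assumptions:

```python
import torch

num_heads, head_dim = 4, 16
cu_seqlens = torch.tensor([0, 5, 8, 15], dtype=torch.int32)

# Separate q/k/v ('hd_hd_hd'): each token holds num_heads * head_dim elements.
seq_offsets_q = num_heads * head_dim * cu_seqlens
# Interleaved qkv ('3hd'/'h3d'): every token holds 3x as many elements,
# so the per-sequence starting offsets scale by 3 as well.
seq_offsets_qkv = num_heads * head_dim * 3 * cu_seqlens

print(seq_offsets_q.tolist())    # [0, 320, 512, 960]
print(seq_offsets_qkv.tolist())  # [0, 960, 1536, 2880]
```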
# Create output gradient # Create output gradient
qkv_format_kv = '_'.join(qkv_format) qkv_format_kv = "_".join(qkv_format)
qkv_format_kv = qkv_format_kv.replace('s', 'sq') qkv_format_kv = qkv_format_kv.replace("s", "sq")
out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split('_')] out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]] out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda") out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda")
out_grad_orig = out_grad out_grad_orig = out_grad
if qkv_format == 'thd' and pad_between_seqs: if qkv_format == "thd" and pad_between_seqs:
out_grad_orig = torch.Tensor([]).to(device="cuda",dtype=dtype) out_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
if qkv_format_kv == 't_h_d': if qkv_format_kv == "t_h_d":
for i in range(1, config.batch_size+1): for i in range(1, config.batch_size + 1):
valid_range = (cu_seqlens_q_after_pad[i-1], cu_seqlens_q_after_pad[i] - pad_len[i-1]) valid_range = (
pad_range = (cu_seqlens_q_after_pad[i] - pad_len[i-1], cu_seqlens_q_after_pad[i]) cu_seqlens_q_after_pad[i - 1],
out_grad[pad_range[0]:pad_range[1]] = 0.0 cu_seqlens_q_after_pad[i] - pad_len[i - 1],
out_grad_orig = torch.cat([out_grad_orig, out_grad[valid_range[0]:valid_range[1]]], dim=0) )
pad_range = (cu_seqlens_q_after_pad[i] - pad_len[i - 1], cu_seqlens_q_after_pad[i])
out_grad[pad_range[0] : pad_range[1]] = 0.0
out_grad_orig = torch.cat(
[out_grad_orig, out_grad[valid_range[0] : valid_range[1]]], dim=0
)
# Create bias # Create bias
if config.attn_bias_type in ['no_bias', 'alibi']: if config.attn_bias_type in ["no_bias", "alibi"]:
bias = None bias = None
if config.attn_bias_type == 'post_scale_bias': if config.attn_bias_type == "post_scale_bias":
shape = '_'.join(config.bias_shape) shape = "_".join(config.bias_shape)
shape = shape.replace('_s_s', '_sq_skv') shape = shape.replace("_s_s", "_sq_skv")
tensor_shape = [dim_to_num[j] for j in shape.split('_')] tensor_shape = [dim_to_num[j] for j in shape.split("_")]
bias = torch.randn(tensor_shape, dtype=dtype, device="cuda") bias = torch.randn(tensor_shape, dtype=dtype, device="cuda")
if config.bias_shape != '1hss': if config.bias_shape != "1hss":
bias.requires_grad = False bias.requires_grad = False
# Create RNG # Create RNG
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"""Get cuda rng tracker.""" """Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER return _DUMMY_CUDA_RNG_STATE_TRACKER
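The dummy tracker gives dropout a named, seeded RNG stream so repeated runs of the module draw identical masks. A toy stand-in with the same shape (a sketch, not TE's `CudaRNGStatesTracker`; it uses CPU generators for portability):

```python
import torch

class MinimalRNGTracker:
    """Illustrative named-RNG-stream tracker (not the TE implementation)."""

    def __init__(self):
        self._gens = {}

    def add(self, name: str, seed: int) -> None:
        gen = torch.Generator()
        gen.manual_seed(seed)
        self._gens[name] = gen

    def get(self, name: str) -> torch.Generator:
        return self._gens[name]

tracker = MinimalRNGTracker()
tracker.add("model-parallel-rng", 1234)
# A dropout-style draw from the named stream; reproducible across setups.
mask = torch.rand(4, generator=tracker.get("model-parallel-rng")) > 0.1
print(mask)
```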
# Set up model # Set up model
block = ( block = DotProductAttention(
DotProductAttention( config.num_heads,
config.num_heads, config.head_dim,
config.head_dim, num_gqa_groups=config.num_gqa_groups,
num_gqa_groups=config.num_gqa_groups, attention_dropout=config.dropout_p,
attention_dropout=config.dropout_p, qkv_format=qkv_format,
qkv_format=qkv_format, attn_mask_type=config.attn_mask_type,
attn_mask_type=config.attn_mask_type, sequence_parallel=False,
sequence_parallel=False, tp_size=1,
tp_size=1, get_rng_state_tracker=get_dummy_cuda_rng_tracker,
get_rng_state_tracker=get_dummy_cuda_rng_tracker, tp_group=None,
tp_group=None, layer_number=1,
layer_number=1, attention_type=config.attn_type,
attention_type=config.attn_type, ).to(dtype=dtype, device="cuda")
).to(dtype=dtype, device="cuda")
)
# Run a forward and backward pass # Run a forward and backward pass
if backend in ["FlashAttention", "UnfusedDotProductAttention"]: if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
...@@ -771,24 +917,28 @@ def _run_dot_product_attention( ...@@ -771,24 +917,28 @@ def _run_dot_product_attention(
k = inp[1] k = inp[1]
v = inp[2] v = inp[2]
d_out = out_grad d_out = out_grad
out = block(q, k, v, out = block(
window_size=window_size, q,
attention_mask=attention_mask, k,
qkv_format=qkv_format, v,
max_seqlen_q=config.max_seqlen_q, window_size=window_size,
max_seqlen_kv=config.max_seqlen_kv, attention_mask=attention_mask,
cu_seqlens_q=cu_seqlens_q, qkv_format=qkv_format,
cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=config.max_seqlen_q,
seq_offsets_q=seq_offsets_q, max_seqlen_kv=config.max_seqlen_kv,
seq_offsets_k=seq_offsets_k, cu_seqlens_q=cu_seqlens_q,
seq_offsets_v=seq_offsets_v, cu_seqlens_kv=cu_seqlens_kv,
seq_offsets_o=seq_offsets_o, seq_offsets_q=seq_offsets_q,
attn_mask_type=config.attn_mask_type, seq_offsets_k=seq_offsets_k,
checkpoint_core_attention=ckpt_attn, seq_offsets_v=seq_offsets_v,
core_attention_bias_type=config.attn_bias_type, seq_offsets_o=seq_offsets_o,
core_attention_bias=bias, attn_mask_type=config.attn_mask_type,
alibi_slopes=alibi_slopes, checkpoint_core_attention=ckpt_attn,
fast_zero_fill=True) core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=True,
)
if is_training: if is_training:
out.backward(d_out) out.backward(d_out)
...@@ -798,18 +948,30 @@ def _run_dot_product_attention( ...@@ -798,18 +948,30 @@ def _run_dot_product_attention(
else: else:
return out, (None, None, None) return out, (None, None, None)
if backend == "FusedAttention": if backend == "FusedAttention":
if qkv_format == 'thd' and pad_between_seqs: if qkv_format == "thd" and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda",dtype=dtype) out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
q_grad_orig = torch.Tensor([]).to(device="cuda",dtype=dtype) q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
k_grad_orig = torch.Tensor([]).to(device="cuda",dtype=dtype) k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
v_grad_orig = torch.Tensor([]).to(device="cuda",dtype=dtype) v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
for i in range(1, config.batch_size+1): for i in range(1, config.batch_size + 1):
valid_range_q = (cu_seqlens_q_after_pad[i-1], cu_seqlens_q_after_pad[i] - pad_len[i-1]) valid_range_q = (
valid_range_kv = (cu_seqlens_kv_after_pad[i-1], cu_seqlens_kv_after_pad[i] - pad_len[i-1]) cu_seqlens_q_after_pad[i - 1],
out_orig = torch.cat([out_orig, out[valid_range_q[0]:valid_range_q[1]]], dim=0) cu_seqlens_q_after_pad[i] - pad_len[i - 1],
q_grad_orig = torch.cat([q_grad_orig, q.grad[valid_range_q[0]:valid_range_q[1]]], dim=0) )
k_grad_orig = torch.cat([k_grad_orig, k.grad[valid_range_kv[0]:valid_range_kv[1]]], dim=0) valid_range_kv = (
v_grad_orig = torch.cat([v_grad_orig, v.grad[valid_range_kv[0]:valid_range_kv[1]]], dim=0) cu_seqlens_kv_after_pad[i - 1],
cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
)
out_orig = torch.cat([out_orig, out[valid_range_q[0] : valid_range_q[1]]], dim=0)
q_grad_orig = torch.cat(
[q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0
)
k_grad_orig = torch.cat(
[k_grad_orig, k.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
)
v_grad_orig = torch.cat(
[v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
)
if is_training: if is_training:
return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig) return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig)
else: else:
...@@ -823,18 +985,18 @@ def _run_dot_product_attention( ...@@ -823,18 +985,18 @@ def _run_dot_product_attention(
model_configs_te_layer = { model_configs_te_layer = {
# test: b, h, hg, d, sq, skv, p, mask, bias # test: b, h, hg, d, sq, skv, p, mask, bias
"te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"), "te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
"te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"), "te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
"te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), "te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"),
"te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), "te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"), "te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
"te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), "te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"), "te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"),
"te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"), "te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"),
} }
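Read against the comment line, the positional config fields are batch size, heads, GQA groups, head dim, query/kv sequence lengths, dropout, mask type, and bias type. A hypothetical dataclass with those fields (names inferred from the comment and the `config.*` accesses above, not copied from the test's `ModelConfig`):

```python
from dataclasses import dataclass

@dataclass
class ModelConfigSketch:
    batch_size: int        # b
    num_heads: int         # h
    num_gqa_groups: int    # hg
    head_dim: int          # d
    max_seqlen_q: int      # sq
    max_seqlen_kv: int     # skv
    dropout_p: float       # p
    attn_mask_type: str    # mask
    attn_bias_type: str    # bias

cfg = ModelConfigSketch(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias")
print(cfg.num_heads, cfg.attn_bias_type)
```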
@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer]) @pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", model_configs_te_layer.keys()) @pytest.mark.parametrize("model", model_configs_te_layer.keys())
...@@ -842,7 +1004,9 @@ model_configs_te_layer = { ...@@ -842,7 +1004,9 @@ model_configs_te_layer = {
@pytest.mark.parametrize("qkv_format", ["sbhd"]) @pytest.mark.parametrize("qkv_format", ["sbhd"])
@pytest.mark.parametrize("fused_qkv_params", [False]) @pytest.mark.parametrize("fused_qkv_params", [False])
@pytest.mark.parametrize("RoPE", [False]) @pytest.mark.parametrize("RoPE", [False])
def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE): def test_transformer_layer(
dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
):
"""Test TransformerLayer module""" """Test TransformerLayer module"""
# Get configs # Get configs
...@@ -916,7 +1080,7 @@ def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, f ...@@ -916,7 +1080,7 @@ def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, f
torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols) torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)
@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer]) @pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_1_2", "te_2_0"]) @pytest.mark.parametrize("model", ["te_1_2", "te_2_0"])
...@@ -926,22 +1090,24 @@ def test_te_layer_misc(dtype, model_configs, model, qkv_format): ...@@ -926,22 +1090,24 @@ def test_te_layer_misc(dtype, model_configs, model, qkv_format):
ckpt_attn = True ckpt_attn = True
fused_qkv_params = True fused_qkv_params = True
RoPE = True RoPE = True
test_transformer_layer(dtype, model_configs, model, test_transformer_layer(
ckpt_attn, qkv_format, fused_qkv_params, RoPE) dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
)
@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer]) @pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_2_0", "te_2_1", "te_2_2"]) @pytest.mark.parametrize("model", ["te_2_0", "te_2_1", "te_2_2"])
def test_te_layer_mqa_gqa(dtype, model_configs, model): def test_te_layer_mqa_gqa(dtype, model_configs, model):
"""Test TransformerLayer module with MQA/GQA""" """Test TransformerLayer module with MQA/GQA"""
def find_factors(x): def find_factors(x):
f = [] f = []
for i in range(2, x + 1): for i in range(2, x + 1):
if x % i == 0: if x % i == 0:
f.append(i) f.append(i)
return f return f
ckpt_attn = True ckpt_attn = True
qkv_format = "bshd" qkv_format = "bshd"
...@@ -951,21 +1117,22 @@ def test_te_layer_mqa_gqa(dtype, model_configs, model): ...@@ -951,21 +1117,22 @@ def test_te_layer_mqa_gqa(dtype, model_configs, model):
num_querys_per_gqa_group = find_factors(config.num_heads) num_querys_per_gqa_group = find_factors(config.num_heads)
for num_q_per_gqa_group in num_querys_per_gqa_group: for num_q_per_gqa_group in num_querys_per_gqa_group:
config.num_gqa_groups=config.num_heads // num_q_per_gqa_group config.num_gqa_groups = config.num_heads // num_q_per_gqa_group
test_transformer_layer(dtype, model_configs, model, test_transformer_layer(
ckpt_attn, qkv_format, fused_qkv_params, RoPE) dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
)
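`find_factors` enumerates every query-group size that divides `num_heads` evenly, so the sweep covers GQA ratios down to single-group MQA; since the factor list starts at 2, plain MHA (one query per group) is deliberately left out. For example:

```python
def find_factors(x):
    return [i for i in range(2, x + 1) if x % i == 0]

num_heads = 16
for num_q_per_group in find_factors(num_heads):
    num_gqa_groups = num_heads // num_q_per_group
    kind = "MQA" if num_gqa_groups == 1 else "GQA"
    print(f"{num_q_per_group} queries/group -> {num_gqa_groups} KV group(s) ({kind})")
# 2 -> 8, 4 -> 4, 8 -> 2, 16 -> 1 (MQA)
```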
def _run_transformer_layer( def _run_transformer_layer(
dtype: torch.dtype, dtype: torch.dtype,
config: ModelConfig, config: ModelConfig,
backend: str, backend: str,
ckpt_attn: bool, ckpt_attn: bool,
qkv_format: str, qkv_format: str,
workspace_opt: bool, workspace_opt: bool,
fused_qkv_params: bool, fused_qkv_params: bool,
RoPE: bool, RoPE: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run TransformerLayer module with one forward pass and one backward pass""" """Run TransformerLayer module with one forward pass and one backward pass"""
# Set RNG and environment variables # Set RNG and environment variables
...@@ -978,29 +1145,47 @@ def _run_transformer_layer( ...@@ -978,29 +1145,47 @@ def _run_transformer_layer(
os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "1"
# Create input tensor # Create input tensor
inp = torch.randn(config.max_seqlen_q, config.batch_size, config.hidden_size, inp = torch.randn(
dtype=dtype, device="cuda", requires_grad = True) config.max_seqlen_q,
config.batch_size,
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)
# In case the format to be tested is batch-first, need to transpose the # In case the format to be tested is batch-first, need to transpose the
# input tensor. # input tensor.
if qkv_format == "bshd": if qkv_format == "bshd":
inp = inp.transpose(0,1) inp = inp.transpose(0, 1)
# Create seqlens # Create seqlens
if "padding" in config.attn_mask_type: if "padding" in config.attn_mask_type:
seqlens_q = torch.randint(1, config.max_seqlen_q, [config.batch_size], seqlens_q = torch.randint(
dtype=torch.int32, device="cuda") 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
)
else: else:
seqlens_q = torch.full([config.batch_size], config.max_seqlen_q, seqlens_q = torch.full(
dtype=torch.int32, device="cuda") [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
)
# Create attention mask if padding # Create attention mask if padding
attention_mask = None attention_mask = None
if "padding" in config.attn_mask_type: if "padding" in config.attn_mask_type:
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool) attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
for i in range(config.batch_size): for i in range(config.batch_size):
attention_mask_q = torch.cat([attention_mask_q, attention_mask_q = torch.cat(
torch.Tensor([False]*seqlens_q[i] + [True]*(config.max_seqlen_q-seqlens_q[i])) [
.to(torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0) attention_mask_q,
torch.Tensor(
[False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
)
.to(torch.bool)
.unsqueeze(0)
.unsqueeze(0)
.unsqueeze(0),
],
dim=0,
)
attention_mask = attention_mask_q.to(device="cuda") attention_mask = attention_mask_q.to(device="cuda")
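The mask uses `True` for padded positions, i.e., "exclude from attention". The per-sequence loop above can also be written as one broadcasted comparison; a sketch assuming the same `[b, 1, 1, sq]` boolean layout:

```python
import torch

batch_size, max_seqlen_q = 4, 8
seqlens_q = torch.tensor([8, 5, 3, 7])

# position >= seqlen  ->  padding  ->  True (masked out)
positions = torch.arange(max_seqlen_q)
attention_mask = positions.unsqueeze(0) >= seqlens_q.unsqueeze(1)
attention_mask = attention_mask.view(batch_size, 1, 1, max_seqlen_q)
print(attention_mask[1, 0, 0])   # last three positions of sequence 1 are True
```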
sigma = 0.02 sigma = 0.02
...@@ -1009,14 +1194,19 @@ def _run_transformer_layer( ...@@ -1009,14 +1194,19 @@ def _run_transformer_layer(
layer_number = 1 layer_number = 1
drop_path_rate = 0.0 drop_path_rate = 0.0
drop_path_rates = [ drop_path_rates = [rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
# Create bias # Create bias
bias = None bias = None
if config.attn_bias_type == 'post_scale_bias': if config.attn_bias_type == "post_scale_bias":
bias = torch.randn(1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv, bias = torch.randn(
dtype=dtype, device="cuda") 1,
config.num_heads,
config.max_seqlen_q,
config.max_seqlen_kv,
dtype=dtype,
device="cuda",
)
# Create RoPE # Create RoPE
rotary_pos_emb = None rotary_pos_emb = None
...@@ -1025,58 +1215,56 @@ def _run_transformer_layer( ...@@ -1025,58 +1215,56 @@ def _run_transformer_layer(
rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda") rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")
# Set up model # Set up model
block = ( block = TransformerLayer(
TransformerLayer( config.hidden_size,
config.hidden_size, 4 * config.hidden_size,
4 * config.hidden_size, config.num_heads,
config.num_heads, num_gqa_groups=config.num_gqa_groups,
num_gqa_groups=config.num_gqa_groups, layernorm_epsilon=1e-5,
layernorm_epsilon=1e-5, hidden_dropout=0.0,
hidden_dropout=0.0, attention_dropout=config.dropout_p,
attention_dropout=config.dropout_p, init_method=init_method,
init_method=init_method, output_layer_init_method=output_layer_init_method,
output_layer_init_method=output_layer_init_method, layer_number=layer_number,
layer_number=layer_number, kv_channels=config.head_dim,
kv_channels=config.head_dim, self_attn_mask_type=config.attn_mask_type,
self_attn_mask_type=config.attn_mask_type, tp_group=None,
tp_group=None, tp_size=1,
tp_size=1, params_dtype=dtype,
params_dtype=dtype, get_rng_state_tracker=None,
get_rng_state_tracker=None, fuse_wgrad_accumulation=False,
fuse_wgrad_accumulation=False, seq_length=config.max_seqlen_q,
seq_length=config.max_seqlen_q, micro_batch_size=config.batch_size,
micro_batch_size=config.batch_size, sequence_parallel=False,
sequence_parallel=False, apply_residual_connection_post_layernorm=False,
apply_residual_connection_post_layernorm=False, output_layernorm=False,
output_layernorm=False, layer_type="encoder",
layer_type="encoder", drop_path_rate=drop_path_rates[layer_number - 1],
drop_path_rate=drop_path_rates[layer_number - 1], set_parallel_mode=True,
set_parallel_mode=True, fuse_qkv_params=fused_qkv_params,
fuse_qkv_params=fused_qkv_params, zero_centered_gamma=False,
zero_centered_gamma=False, qkv_weight_interleaved=False,
qkv_weight_interleaved=False, ub_tp_comm_overlap=False,
ub_tp_comm_overlap=False, bias=True,
bias=True, attn_input_format=qkv_format,
attn_input_format=qkv_format, ).to(dtype=dtype, device="cuda")
)
.to(dtype=dtype, device="cuda")
)
# Create ALiBi slopes # Create ALiBi slopes
alibi_slopes = None alibi_slopes = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom": if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
alibi_slopes = torch.randn( alibi_slopes = torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda")
config.num_heads).abs().to(dtype=torch.float32, device="cuda")
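`alibi_type == "custom"` feeds random per-head slopes; the default ALiBi slopes are instead the fixed geometric sequence 2^(-8i/n) for head i of n. A sketch of the standard slopes and the distance-proportional bias they induce (illustrative, not the test's code path):

```python
import torch

def default_alibi_slopes(num_heads: int) -> torch.Tensor:
    """Standard ALiBi slopes 2^(-8i/n), i = 1..n (power-of-two head counts)."""
    i = torch.arange(1, num_heads + 1, dtype=torch.float32)
    return torch.pow(2.0, -8.0 * i / num_heads)

slopes = default_alibi_slopes(8)        # [0.5, 0.25, ..., 0.0039]

# Bias added to attention logits: slope * (key_pos - query_pos), clamped so
# only past keys are penalized (the causal mask handles future positions).
q_pos = torch.arange(4).view(-1, 1)
k_pos = torch.arange(4).view(1, -1)
bias = slopes.view(-1, 1, 1) * (k_pos - q_pos).clamp(max=0)
print(bias.shape)                        # torch.Size([8, 4, 4])
```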
# Run a forward and backward pass # Run a forward and backward pass
out = block(inp, out = block(
inp,
attention_mask=attention_mask, attention_mask=attention_mask,
self_attn_mask_type=config.attn_mask_type, self_attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=False, checkpoint_core_attention=False,
rotary_pos_emb=rotary_pos_emb, rotary_pos_emb=rotary_pos_emb,
core_attention_bias_type=config.attn_bias_type, core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias, core_attention_bias=bias,
alibi_slopes=alibi_slopes) alibi_slopes=alibi_slopes,
)
loss = out.sum() loss = out.sum()
loss.backward() loss.backward()
...@@ -1085,23 +1273,24 @@ def _run_transformer_layer( ...@@ -1085,23 +1273,24 @@ def _run_transformer_layer(
model_configs_fp8_vs_f16 = { model_configs_fp8_vs_f16 = {
# test: b, h, hg, d, sq, skv, p, mask, bias # test: b, h, hg, d, sq, skv, p, mask, bias
"fp8_9" : ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), "fp8_9": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_10": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), "fp8_10": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_11": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), "fp8_11": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_12": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"), "fp8_12": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_13": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"), "fp8_13": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"),
"fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), "fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),
} }
param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16] param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
qkv_layout_fp8_vs_f16 = ['sbh3d', 'bshd_bshd_bshd', 'sbhd_sbhd_sbhd'] qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"]
qkv_format_fp8_vs_f16 = ['bshd', 'sbhd'] qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]
def _rmse(a, b): def _rmse(a, b):
return math.sqrt((torch.pow((a-b), 2)/a.numel()).sum()) return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum())
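The pass criterion used below is range-relative: `rmse < rmse_tol * (max - min)` over both tensors, so the threshold tracks the activation scale instead of being absolute. Worked numerically:

```python
import math
import torch

def _rmse(a, b):
    return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum())

a = torch.tensor([0.0, 1.0, 2.0, 3.0])
b = torch.tensor([0.1, 0.9, 2.2, 2.8])

rmse = _rmse(a, b)   # sqrt(0.10 / 4) ~= 0.158
value_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
assert rmse < 0.1 * value_range, f"RMSE {rmse:.3f} over tolerance"   # 0.158 < 0.3
```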
@pytest.mark.skipif(get_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.") @pytest.mark.skipif(get_cudnn_version() < (8, 9, 3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.") @pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16) @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
...@@ -1118,58 +1307,78 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd): ...@@ -1118,58 +1307,78 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd):
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16( fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm) dtype, config, True, qkv_format, input_layernorm
)
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False") logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16( fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
dtype, config, False, qkv_format, input_layernorm) dtype, config, False, qkv_format, input_layernorm
)
tols = dict(atol=5e-1, rtol=5e-1) tols = dict(atol=5e-1, rtol=5e-1)
rmse_tol = 0.1 rmse_tol = 0.1
fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16) fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16)
fwd_range = max(fused_attn_fwd_fp8.max().item(), fwd_range = max(fused_attn_fwd_fp8.max().item(), fused_attn_fwd_f16.max().item()) - min(
fused_attn_fwd_f16.max().item()) - min(fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.min().item(), fused_attn_fwd_f16.min().item()
fused_attn_fwd_f16.min().item()) )
logging.debug('========== {:^25s} =========='.format('forward output')) logging.debug("========== {:^25s} ==========".format("forward output"))
logging.debug('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format( logging.debug(
fused_attn_fwd_fp8.min().item(),fused_attn_fwd_fp8.max().item())) "fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format(
logging.debug('fused_attn_fwd_f16 min {:.6f} max {:.6f}'.format( fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()
fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item())) )
logging.debug('fused_attn_fwd RMSE: {:.6f}'.format(fwd_rmse)) )
logging.debug(
"fused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()
)
)
logging.debug("fused_attn_fwd RMSE: {:.6f}".format(fwd_rmse))
try: try:
torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols) torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols)
except Exception as e: except Exception as e:
logging.debug(e) logging.debug(e)
assert(fwd_rmse < rmse_tol * fwd_range assert (
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( fwd_rmse < rmse_tol * fwd_range
fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range) ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
)
for i in range(len(param_names[:1])): for i in range(len(param_names[:1])):
bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i]) bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i])
bwd_range = max(fused_attn_bwd_fp8[i].max().item(), bwd_range = max(
fused_attn_bwd_f16[i].max().item()) - min(fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item(), fused_attn_bwd_f16[i].max().item()
fused_attn_bwd_f16[i].min().item()) ) - min(fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_f16[i].min().item())
logging.debug('========== {:^25s} =========='.format(param_names[i])) logging.debug("========== {:^25s} ==========".format(param_names[i]))
logging.debug('fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}'.format(i, logging.debug(
fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item())) "fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}".format(
logging.debug('fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}'.format(i, i, fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()
fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item())) )
logging.debug('fused_attn_bwd RMSE[{}]: {:.6f}'.format(i, bwd_rmse)) )
logging.debug(
"fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}".format(
i, fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()
)
)
logging.debug("fused_attn_bwd RMSE[{}]: {:.6f}".format(i, bwd_rmse))
try: try:
torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols) torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols)
except Exception as e: except Exception as e:
logging.debug(e) logging.debug(e)
assert(bwd_rmse < rmse_tol * bwd_range assert (
), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( bwd_rmse < rmse_tol * bwd_range
bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range) ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
)
def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm): def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
reset_rng_states() reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"""Get cuda rng tracker.""" """Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER return _DUMMY_CUDA_RNG_STATE_TRACKER
...@@ -1184,7 +1393,7 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm): ...@@ -1184,7 +1393,7 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
) )
with fp8_model_init(enabled=fp8_mha): with fp8_model_init(enabled=fp8_mha):
mha = (MultiheadAttention( mha = MultiheadAttention(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
num_attention_heads=config.num_heads, num_attention_heads=config.num_heads,
kv_channels=config.head_dim, kv_channels=config.head_dim,
...@@ -1199,34 +1408,35 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm): ...@@ -1199,34 +1408,35 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
attention_type="self", attention_type="self",
qkv_weight_interleaved=True, qkv_weight_interleaved=True,
qkv_format=qkv_format, qkv_format=qkv_format,
).to(dtype=dtype, device="cuda") ).to(dtype=dtype, device="cuda")
)
seqlens_q = torch.full([config.batch_size], config.max_seqlen_q, seqlens_q = torch.full(
dtype=torch.int32, device="cuda") [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv, )
dtype=torch.int32, device="cuda") seqlens_kv = torch.full(
[config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
)
cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0) cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
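`cu_seqlens_*` are the cumulative-length vectors that varlen attention kernels consume: entry 0 is 0 and entry i is the total token count of the first i sequences, so sequence i lives at `cu_seqlens[i]:cu_seqlens[i+1]` in a packed buffer. For example:

```python
import torch

seqlens = torch.tensor([128, 128, 128], dtype=torch.int32)   # full-length batch
cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
print(cu_seqlens.tolist())   # [0, 128, 256, 384]
```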
dim_to_num = { dim_to_num = {
'b' : config.batch_size, "b": config.batch_size,
'sq' : config.max_seqlen_q, "sq": config.max_seqlen_q,
'skv': config.max_seqlen_kv, "skv": config.max_seqlen_kv,
'h' : config.num_heads, "h": config.num_heads,
'hg' : config.num_gqa_groups, "hg": config.num_gqa_groups,
'd' : config.head_dim, "d": config.head_dim,
't' : cu_seqlens_q[-1], "t": cu_seqlens_q[-1],
'tg' : cu_seqlens_kv[-1], "tg": cu_seqlens_kv[-1],
'3' : 3, "3": 3,
'2' : 2, "2": 2,
'1' : 1, "1": 1,
} }
layout = '_'.join(qkv_format) layout = "_".join(qkv_format)
layout = layout.replace('s', 'sq') layout = layout.replace("s", "sq")
tensor_shape = [dim_to_num[j] for j in layout.split('_')] tensor_shape = [dim_to_num[j] for j in layout.split("_")]
tensor = 0.01 * torch.randint(-100, 100, tensor_shape, dtype=dtype, device="cuda") tensor = 0.01 * torch.randint(-100, 100, tensor_shape, dtype=dtype, device="cuda")
hidden_states = tensor.view(*tensor.shape[:-2], -1) hidden_states = tensor.view(*tensor.shape[:-2], -1)
hidden_states.requires_grad = True hidden_states.requires_grad = True
...@@ -1234,27 +1444,28 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm): ...@@ -1234,27 +1444,28 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
out_grad = tensor.view(*tensor.shape[:-2], -1) out_grad = tensor.view(*tensor.shape[:-2], -1)
with fp8_autocast(enabled=fp8_mha, fp8_recipe=fp8_recipe): with fp8_autocast(enabled=fp8_mha, fp8_recipe=fp8_recipe):
out = mha(hidden_states, out = mha(
hidden_states,
attn_mask_type=config.attn_mask_type, attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=False, checkpoint_core_attention=False,
core_attention_bias_type=config.attn_bias_type, core_attention_bias_type=config.attn_bias_type,
is_first_microbatch=None, is_first_microbatch=None,
) )
out.backward(out_grad) out.backward(out_grad)
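The comparison flips a single switch: the same module and data run once with FP8 enabled and once in plain f16/bf16. `fp8_model_init` controls whether parameters are kept in FP8, while `fp8_autocast` scopes FP8 execution under a delayed-scaling recipe. A minimal usage sketch (the module choice and recipe arguments are illustrative, and FP8 execution needs supported hardware):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Delayed scaling derives scale factors from an amax history rather than the
# current tensor -- the standard FP8 training recipe in Transformer Engine.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

linear = te.Linear(1024, 1024, bias=True).cuda()
inp = torch.randn(16, 1024, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = linear(inp)
out.sum().backward()   # backward may run outside the autocast scope
```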
param_names = [] param_names = []
param_names.append('hidden_states.grad') param_names.append("hidden_states.grad")
params = [] params = []
params.append(hidden_states) params.append(hidden_states)
for name, param in mha.named_parameters(): for name, param in mha.named_parameters():
if param.requires_grad: if param.requires_grad:
param_names.append(name+'.grad') param_names.append(name + ".grad")
params.append(param) params.append(param)
return out, param_names, tuple(x.grad for x in params) return out, param_names, tuple(x.grad for x in params)
@pytest.mark.skipif(get_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.") @pytest.mark.skipif(get_cudnn_version() < (8, 9, 3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.") @pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16) @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
...@@ -1264,62 +1475,75 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm): ...@@ -1264,62 +1475,75 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd): def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd):
config = model_configs_fp8_vs_f16[model] config = model_configs_fp8_vs_f16[model]
if (config.num_heads != config.num_gqa_groups and '3' in qkv_layout): if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
pytest.skip("qkv_layout not applicable for MQA/GQA"); pytest.skip("qkv_layout not applicable for MQA/GQA")
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True") logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(dtype, config, True, qkv_layout)
dtype, config, True, qkv_layout)
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False") logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False")
fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(dtype, config, False, qkv_layout)
dtype, config, False, qkv_layout)
tols = dict(atol=5e-1, rtol=5e-2) tols = dict(atol=5e-1, rtol=5e-2)
rmse_tol = 0.1 rmse_tol = 0.1
bwd_names = ['dq', 'dk', 'dv'] bwd_names = ["dq", "dk", "dv"]
fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16) fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16)
fwd_range = max(fused_attn_fwd_fp8.max().item(), fwd_range = max(fused_attn_fwd_fp8.max().item(), fused_attn_fwd_f16.max().item()) - min(
fused_attn_fwd_f16.max().item()) - min(fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.min().item(), fused_attn_fwd_f16.min().item()
fused_attn_fwd_f16.min().item()) )
logging.debug('========== {:^25s} =========='.format('forward output')) logging.debug("========== {:^25s} ==========".format("forward output"))
logging.debug('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format( logging.debug(
fused_attn_fwd_fp8.min().item(),fused_attn_fwd_fp8.max().item())) "fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format(
logging.debug('fused_attn_fwd_f16 min {:.6f} max {:.6f}'.format( fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()
fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item())) )
logging.debug('fused_attn_fwd RMSE: {:.6f}'.format(fwd_rmse)) )
logging.debug(
"fused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()
)
)
logging.debug("fused_attn_fwd RMSE: {:.6f}".format(fwd_rmse))
try: try:
torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols) torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols)
except Exception as e: except Exception as e:
logging.debug(e) logging.debug(e)
assert(fwd_rmse < rmse_tol * fwd_range assert (
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( fwd_rmse < rmse_tol * fwd_range
fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range) ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
for i,_ in enumerate(fused_attn_bwd_f16): fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
)
for i, _ in enumerate(fused_attn_bwd_f16):
bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i]) bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i])
bwd_range = max(fused_attn_bwd_fp8[i].max().item(), bwd_range = max(
fused_attn_bwd_f16[i].max().item()) - min(fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item(), fused_attn_bwd_f16[i].max().item()
fused_attn_bwd_f16[i].min().item()) ) - min(fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_f16[i].min().item())
logging.debug('========== {:^25s} =========='.format(bwd_names[i])) logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
logging.debug('fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}'.format(i, logging.debug(
fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item())) "fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}".format(
logging.debug('fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}'.format(i, i, fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()
fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item())) )
logging.debug('fused_attn_bwd RMSE[{}]: {:.6f}'.format(i, bwd_rmse)) )
logging.debug(
"fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}".format(
i, fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()
)
)
logging.debug("fused_attn_bwd RMSE[{}]: {:.6f}".format(i, bwd_rmse))
try: try:
torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols) torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols)
except Exception as e: except Exception as e:
logging.debug(e) logging.debug(e)
assert(bwd_rmse < rmse_tol * bwd_range assert (
), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( bwd_rmse < rmse_tol * bwd_range
bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range) ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
)
def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout): def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
...@@ -1327,6 +1551,7 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout): ...@@ -1327,6 +1551,7 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
reset_rng_states() reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"""Get cuda rng tracker.""" """Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER return _DUMMY_CUDA_RNG_STATE_TRACKER
...@@ -1339,60 +1564,60 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout): ...@@ -1339,60 +1564,60 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
fp8_dpa=fp8_dpa, fp8_dpa=fp8_dpa,
) )
qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()]) qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
with fp8_model_init(enabled=fp8_dpa): with fp8_model_init(enabled=fp8_dpa):
dpa = ( dpa = DotProductAttention(
DotProductAttention( config.num_heads,
config.num_heads, config.head_dim,
config.head_dim, num_gqa_groups=config.num_gqa_groups,
num_gqa_groups=config.num_gqa_groups, attention_dropout=config.dropout_p,
attention_dropout=config.dropout_p, sequence_parallel=False,
sequence_parallel=False, tp_size=1,
tp_size=1, get_rng_state_tracker=get_dummy_cuda_rng_tracker,
get_rng_state_tracker=get_dummy_cuda_rng_tracker, tp_group=None,
tp_group=None, layer_number=1,
layer_number=1, attention_type="self",
attention_type="self", qkv_format=qkv_format,
qkv_format=qkv_format, ).to(dtype=dtype, device="cuda")
).to(dtype=dtype, device="cuda")
)
seqlens_q = torch.full([config.batch_size], config.max_seqlen_q, seqlens_q = torch.full(
dtype=torch.int32, device="cuda") [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv, )
dtype=torch.int32, device="cuda") seqlens_kv = torch.full(
[config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
)
cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0) cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
dim_to_num = { dim_to_num = {
'b' : config.batch_size, "b": config.batch_size,
'sq' : config.max_seqlen_q, "sq": config.max_seqlen_q,
'skv': config.max_seqlen_kv, "skv": config.max_seqlen_kv,
'h' : config.num_heads, "h": config.num_heads,
'hg' : config.num_gqa_groups, "hg": config.num_gqa_groups,
'd' : config.head_dim, "d": config.head_dim,
't' : cu_seqlens_q[-1], "t": cu_seqlens_q[-1],
'tg' : cu_seqlens_kv[-1], "tg": cu_seqlens_kv[-1],
'3' : 3, "3": 3,
'2' : 2, "2": 2,
'1' : 1, "1": 1,
} }
inp = [] inp = []
for i,layout in enumerate(qkv_layout.split('_')): for i, layout in enumerate(qkv_layout.split("_")):
layout = '_'.join(layout) layout = "_".join(layout)
if i == 0: if i == 0:
layout = layout.replace('s', 'sq') layout = layout.replace("s", "sq")
else: else:
layout = layout.replace('s', 'skv') layout = layout.replace("s", "skv")
layout = layout.replace('h', 'hg') layout = layout.replace("h", "hg")
layout = layout.replace('t', 'tg') layout = layout.replace("t", "tg")
tensor_shape = [dim_to_num[j] for j in layout.split('_')] tensor_shape = [dim_to_num[j] for j in layout.split("_")]
tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda") tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda")
tensor_count = 1 tensor_count = 1
split_dim = 0 split_dim = 0
for dim, l in enumerate(layout.split('_')): for dim, l in enumerate(layout.split("_")):
if l.isdigit(): if l.isdigit():
tensor_count = int(l) tensor_count = int(l)
split_dim = dim split_dim = dim
...@@ -1406,14 +1631,17 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout): ...@@ -1406,14 +1631,17 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
for i in range(3): for i in range(3):
inp[i].requires_grad = True inp[i].requires_grad = True
qkv_format_kv = '_'.join(qkv_format) qkv_format_kv = "_".join(qkv_format)
qkv_format_kv = qkv_format_kv.replace('s', 'sq') qkv_format_kv = qkv_format_kv.replace("s", "sq")
out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split('_')] out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]] out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda") out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda")
with fp8_autocast(enabled=fp8_dpa, fp8_recipe=fp8_recipe): with fp8_autocast(enabled=fp8_dpa, fp8_recipe=fp8_recipe):
out = dpa(inp[0], inp[1], inp[2], out = dpa(
inp[0],
inp[1],
inp[2],
qkv_format=qkv_format, qkv_format=qkv_format,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_kv=cu_seqlens_kv,
...@@ -1423,7 +1651,7 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout): ...@@ -1423,7 +1651,7 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
checkpoint_core_attention=False, checkpoint_core_attention=False,
core_attention_bias_type=config.attn_bias_type, core_attention_bias_type=config.attn_bias_type,
is_first_microbatch=True, is_first_microbatch=True,
) )
out.backward(out_grad) out.backward(out_grad)
return out, (inp[0].grad, inp[1].grad, inp[2].grad) return out, (inp[0].grad, inp[1].grad, inp[2].grad)
...@@ -1431,22 +1659,22 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout): ...@@ -1431,22 +1659,22 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
model_configs_fp8 = { model_configs_fp8 = {
# test: b, h, hg, d, sq, skv, p, mask, bias # test: b, h, hg, d, sq, skv, p, mask, bias
"fp8_1": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "no_mask", "no_bias"), "fp8_1": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "no_mask", "no_bias"),
"fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), "fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
"fp8_3": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), "fp8_3": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_4": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), "fp8_4": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_5": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "causal", "no_bias"), "fp8_5": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "causal", "no_bias"),
"fp8_6": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "causal", "no_bias"), "fp8_6": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "causal", "no_bias"),
"fp8_7": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"), "fp8_7": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_8": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), "fp8_8": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
} }
param_types_fp8 = [torch.float16, torch.bfloat16] param_types_fp8 = [torch.float16, torch.bfloat16]
cudnn_frontend_version = int(os.getenv('NVTE_FUSED_ATTN_FE_VER','1')) cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1"))
models_v0 = ['fp8_1', 'fp8_2', 'fp8_5', 'fp8_6'] models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"]
models_v1 = ['fp8_3', 'fp8_4', 'fp8_7', 'fp8_8'] models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]
@pytest.mark.skipif(get_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.") @pytest.mark.skipif(get_cudnn_version() < (8, 9, 3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.") @pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8) @pytest.mark.parametrize("dtype", param_types_fp8)
...@@ -1460,50 +1688,62 @@ def test_custom_mha_fp8_vs_f16(dtype, model): ...@@ -1460,50 +1688,62 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
config = model_configs_fp8[model] config = model_configs_fp8[model]
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8( fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
dtype, config, "FusedAttention") unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")
unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(
dtype, config, "UnfusedAttention")
tols = dict(atol=5e-1, rtol=5e-1) tols = dict(atol=5e-1, rtol=5e-1)
rmse_tol = 0.1 rmse_tol = 0.1
fwd_rmse = _rmse(fused_attn_fwd_fp8, unfused_attn_fwd_f16) fwd_rmse = _rmse(fused_attn_fwd_fp8, unfused_attn_fwd_f16)
fwd_range = max(fused_attn_fwd_fp8.max().item(), fwd_range = max(fused_attn_fwd_fp8.max().item(), unfused_attn_fwd_f16.max().item()) - min(
unfused_attn_fwd_f16.max().item()) - min(fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.min().item(), unfused_attn_fwd_f16.min().item()
unfused_attn_fwd_f16.min().item()) )
bwd_rmse = _rmse(fused_attn_bwd_fp8, unfused_attn_bwd_f16) bwd_rmse = _rmse(fused_attn_bwd_fp8, unfused_attn_bwd_f16)
bwd_range = max(fused_attn_bwd_fp8.max().item(), bwd_range = max(fused_attn_bwd_fp8.max().item(), unfused_attn_bwd_f16.max().item()) - min(
unfused_attn_bwd_f16.max().item()) - min(fused_attn_bwd_fp8.min().item(), fused_attn_bwd_fp8.min().item(), unfused_attn_bwd_f16.min().item()
unfused_attn_bwd_f16.min().item()) )
logging.debug('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format( logging.debug(
fused_attn_fwd_fp8.min().item(),fused_attn_fwd_fp8.max().item())) "fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format(
logging.debug('unfused_attn_fwd_f16 min {:.6f} max {:.6f}'.format( fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()
unfused_attn_fwd_f16.min().item(), unfused_attn_fwd_f16.max().item())) )
logging.debug('fused_attn_fwd_fp8 vs unfused_attn_fwd_f16 RMSE: {:.6f}'.format( )
fwd_rmse)) logging.debug(
"unfused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
unfused_attn_fwd_f16.min().item(), unfused_attn_fwd_f16.max().item()
)
)
logging.debug("fused_attn_fwd_fp8 vs unfused_attn_fwd_f16 RMSE: {:.6f}".format(fwd_rmse))
try: try:
torch.testing.assert_close(fused_attn_fwd_fp8, unfused_attn_fwd_f16, **tols) torch.testing.assert_close(fused_attn_fwd_fp8, unfused_attn_fwd_f16, **tols)
except Exception as e: except Exception as e:
logging.debug(e) logging.debug(e)
logging.debug('fused_attn_bwd_fp8 min {:.6f} max {:.6f}'.format( logging.debug(
fused_attn_bwd_fp8.min().item(), fused_attn_bwd_fp8.max().item())) "fused_attn_bwd_fp8 min {:.6f} max {:.6f}".format(
logging.debug('unfused_attn_bwd_f16 min {:.6f} max {:.6f}'.format( fused_attn_bwd_fp8.min().item(), fused_attn_bwd_fp8.max().item()
unfused_attn_bwd_f16.min().item(), unfused_attn_bwd_f16.max().item())) )
logging.debug('fused_attn_bwd_fp8 vs unfused_attn_bwd_f16 RMSE: {:.6f}'.format( )
bwd_rmse)) logging.debug(
"unfused_attn_bwd_f16 min {:.6f} max {:.6f}".format(
unfused_attn_bwd_f16.min().item(), unfused_attn_bwd_f16.max().item()
)
)
logging.debug("fused_attn_bwd_fp8 vs unfused_attn_bwd_f16 RMSE: {:.6f}".format(bwd_rmse))
try: try:
torch.testing.assert_close(fused_attn_bwd_fp8, unfused_attn_bwd_f16, **tols) torch.testing.assert_close(fused_attn_bwd_fp8, unfused_attn_bwd_f16, **tols)
except Exception as e: except Exception as e:
logging.debug(e) logging.debug(e)
assert(fwd_rmse < rmse_tol * fwd_range assert (
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( fwd_rmse < rmse_tol * fwd_range
fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range) ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
assert(bwd_rmse < rmse_tol * bwd_range fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( )
bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range) assert (
bwd_rmse < rmse_tol * bwd_range
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
)
def _run_custom_mha_fp8(dtype, config, backend): def _run_custom_mha_fp8(dtype, config, backend):
...@@ -1517,18 +1757,25 @@ def _run_custom_mha_fp8(dtype, config, backend): ...@@ -1517,18 +1757,25 @@ def _run_custom_mha_fp8(dtype, config, backend):
if backend == "FusedAttention": if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "1"
inp = 0.0001 * torch.randint(-100, 100, inp = 0.0001 * torch.randint(
(config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim), -100,
dtype=dtype, device="cuda", requires_grad=True) 100,
seqlens = torch.full([config.batch_size], config.max_seqlen_q, (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim),
dtype=torch.int32, device="cuda") dtype=dtype,
device="cuda",
requires_grad=True,
)
seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32) cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
out_grad = 0.01 * torch.randn( out_grad = 0.01 * torch.randn(
config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim, config.batch_size * config.max_seqlen_q,
dtype=dtype, device="cuda") config.num_heads * config.head_dim,
torch.save(out_grad, 'out_grad.pt') dtype=dtype,
device="cuda",
)
torch.save(out_grad, "out_grad.pt")
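Persisting tensors to disk (`out_grad.pt` here, plus `qkv.pt`, `out.pt`, and `dqkv.pt` below) is how the custom FP8 path and the f16 reference path are fed byte-identical data across separate runs. The pattern is simply:

```python
import torch

out_grad = 0.01 * torch.randn(8, 16)
torch.save(out_grad, "out_grad.pt")    # producer run writes the tensor
same = torch.load("out_grad.pt")       # reference run reads it back
assert torch.equal(out_grad, same)
```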
fp8_recipe = recipe.DelayedScaling( fp8_recipe = recipe.DelayedScaling(
margin=0, margin=0,
...@@ -1543,10 +1790,13 @@ def _run_custom_mha_fp8(dtype, config, backend): ...@@ -1543,10 +1790,13 @@ def _run_custom_mha_fp8(dtype, config, backend):
out.backward(out_grad) out.backward(out_grad)
out = torch.load("out.pt") out = torch.load("out.pt")
dqkv = torch.load('dqkv.pt') dqkv = torch.load("dqkv.pt")
return (out.view(config.batch_size, config.max_seqlen_q, -1), return (
dqkv.view(config.batch_size, config.max_seqlen_q, 3, out.view(config.batch_size, config.max_seqlen_q, -1),
config.num_heads, config.head_dim).contiguous()) dqkv.view(
config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim
).contiguous(),
)
def _run_ref_mha_f16(dtype, config, backend): def _run_ref_mha_f16(dtype, config, backend):
...@@ -1560,13 +1810,14 @@ def _run_ref_mha_f16(dtype, config, backend): ...@@ -1560,13 +1810,14 @@ def _run_ref_mha_f16(dtype, config, backend):
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"

    inp = torch.load("qkv.pt").to(device="cuda")
    inp.requires_grad = True
    seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    out_grad = (
        torch.load("out_grad.pt").to(device="cuda").view(config.batch_size, config.max_seqlen_q, -1)
    )
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
@@ -1582,24 +1833,22 @@ def _run_ref_mha_f16(dtype, config, backend):
"""Get cuda rng tracker.""" """Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER return _DUMMY_CUDA_RNG_STATE_TRACKER
block = ( block = DotProductAttention(
DotProductAttention( config.num_heads,
config.num_heads, config.head_dim,
config.head_dim, attention_dropout=config.dropout_p,
attention_dropout=config.dropout_p, sequence_parallel=False,
sequence_parallel=False, tp_size=1,
tp_size=1, get_rng_state_tracker=get_dummy_cuda_rng_tracker,
get_rng_state_tracker=get_dummy_cuda_rng_tracker, tp_group=None,
tp_group=None, layer_number=1,
layer_number=1, attention_type="self",
attention_type="self", qkv_format="bshd",
qkv_format="bshd", ).to(dtype=dtype, device="cuda")
).to(dtype=dtype, device="cuda")
) q = inp[:, :, 0, :, :]
k = inp[:, :, 1, :, :]
q = inp[:,:,0,:,:] v = inp[:, :, 2, :, :]
k = inp[:,:,1,:,:]
v = inp[:,:,2,:,:]
out = block(q, k, v, attn_mask_type=config.attn_mask_type) out = block(q, k, v, attn_mask_type=config.attn_mask_type)
out.backward(out_grad) out.backward(out_grad)
@@ -1611,12 +1860,12 @@ _2X_ACC_FPROP = False
_2X_ACC_DGRAD = False
_2X_ACC_WGRAD = False

META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3
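
# Note (illustrative): the META_* constants are slot indices into the delayed-scaling
# FP8 metadata used below, e.g. fp8_meta["scaling_fwd"].scale_inv[META_QKV] reads the
# dequantization scale recorded for the fused-attention QKV tensor.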


class _custom_mha_fp8(torch.autograd.Function):
@@ -1682,46 +1931,57 @@ class _custom_mha_fp8(torch.autograd.Function):
            D_dtype=fp8_dtype_forward,
        )
        qkv = qkv.view(-1, 3, h, d)
        qkv_fp16 = (
            ext.cast_from_fp8(
                qkv, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward, tex.DType.kFloat16
            )
            .view(b, max_s, 3, h, d)
            .contiguous()
        )
        torch.save(qkv_fp16, "qkv.pt")
        if cudnn_frontend_version == 1:
            qkv = qkv.view(b, max_s, 3, h, d)  # bs3hd
        # FMHA
        out, aux_ctx_tensors, *rest = fused_attn_fwd(
            is_training,
            max_s,
            max_s,
            cu_seqlens,
            cu_seqlens,
            qkv[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv[:, 0, :, :],
            qkv[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv[:, 1, :, :],
            qkv[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv[:, 2, :, :],
            fp8_dtype_forward,
            FusedAttnBackend["FP8"],
            None,
            None,
            None,
            None,
            None,
            fp8_meta["scaling_fwd"].scale_inv[META_QKV],
            fp8_meta["scaling_fwd"].scale_inv[META_S],
            fp8_meta["scaling_fwd"].scale[META_S],
            fp8_meta["scaling_fwd"].scale[META_O],
            fp8_meta["scaling_fwd"].amax_history[0][META_S],
            fp8_meta["scaling_fwd"].amax_history[0][META_O],
            attn_scale=None,
            dropout=p_dropout,
            fast_zero_fill=fast_zero_fill,
            qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
            attn_bias_type="no_bias",
            attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding",
            rng_gen=None,
        )
        M, ZInv, philox_unpacked = aux_ctx_tensors

        ctx.save_for_backward(
            inp_t_fp8,
            qkv_weight_t_fp8,
            workspace,
            qkv,
            out,
            fp8_meta["scaling_fwd"].scale,
            fp8_meta["scaling_fwd"].scale_inv,
        )
@@ -1736,82 +1996,84 @@ class _custom_mha_fp8(torch.autograd.Function):
        ctx.mask_type = mask_type
        ctx.dtype = inp.dtype

        out = out.view(-1, in_features)  # (bs)(hd)
        out_fp16 = ext.cast_from_fp8(
            out, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward, tex.DType.kFloat16
        )
        torch.save(out_fp16, "out.pt")  # (bs)(hd)
        return out_fp16
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        with torch.cuda.nvtx.range("_DPA"):
            (
                inp_t_fp8,
                qkv_weight_t_fp8,
                workspace,
                qkv,
                out,
                fwd_scales,
                fwd_scale_inverses,
            ) = ctx.saved_tensors
            fp8_dtype_forward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
            fp8_dtype_backward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)

            proj_dgrad = ext.cast_to_fp8(
                grad_output, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
            )  # (bs)(hd)
            dq, dk, dv, *rest = fused_attn_bwd(
                ctx.max_s,
                ctx.max_s,
                ctx.cu_seqlens,
                ctx.cu_seqlens,
                qkv[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv[:, 0, :, :],
                qkv[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv[:, 1, :, :],
                qkv[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv[:, 2, :, :],
                out,
                proj_dgrad.view_as(out),
                fp8_dtype_forward,
                fp8_dtype_backward,
                ctx.aux_ctx_tensors,
                FusedAttnBackend["FP8"],
                None,
                None,
                None,
                None,
                fwd_scale_inverses[META_QKV],  # d_scale_qkv,
                fwd_scale_inverses[META_S],  # d_scale_s,
                fwd_scale_inverses[META_O],  # d_scale_o,
                ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO],  # d_scale_do
                ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP],  # d_scale_dp
                fwd_scales[META_S],  # q_scale_s
                ctx.fp8_meta["scaling_bwd"].scale[META_DP],  # q_scale_dp
                ctx.fp8_meta["scaling_bwd"].scale[META_DQKV],  # q_scale_dqkv
                ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP],  # amax_dp
                ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV],  # amax_dqkv
                attn_scale=None,
                dropout=ctx.p_dropout,
                fast_zero_fill=ctx.fast_zero_fill,
                qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
                attn_bias_type="no_bias",
                attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding",
            )
            dim = 2 if cudnn_frontend_version == 1 else 1
            dqkv = torch.Tensor().to(device=dq.device, dtype=dq.dtype)
            dqkv_shape = list(dq.shape)
            dqkv_shape.insert(dim, 3)
            dqkv_stride = list(dq.stride())
            dqkv_stride.insert(dim, int(dqkv_stride[-3] / 3))
            dqkv.set_(dq.untyped_storage(), dq.storage_offset(), dqkv_shape, dqkv_stride)  # bs3hd
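            # Note (assumption): this aliasing relies on fused_attn_bwd returning
            # dq/dk/dv as strided views into one interleaved buffer; re-inserting the
            # "3" dim into dq's shape/stride exposes the combined dqkv without a copy.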
            dqkv_c = dqkv.view(-1, 3 * ctx.hidden_size)
            dqkv_c_fp16 = ext.cast_from_fp8(
                dqkv_c,
                ctx.fp8_meta["scaling_bwd"],
                META_DQKV,
                fp8_dtype_backward,
                tex.DType.kFloat16,
            )
            torch.save(dqkv_c_fp16, "dqkv.pt")
            qkv_bgrad, dqkv_t = ext.fp8_transpose_bgrad_fused(
                dqkv_c,
@@ -1850,7 +2112,8 @@ class _custom_mha_fp8(torch.autograd.Function):
                use_split_accumulator=_2X_ACC_WGRAD,
            )

        return (
            qkv_dgrad,
            qkv_wgrad,
            qkv_bgrad,
            None,
@@ -1862,14 +2125,12 @@ class _custom_mha_fp8(torch.autograd.Function):
            None,
            None,
            None,
            None,
        )


class Custom_MHA_FP8(TransformerEngineBaseModule):
    def __init__(self, config, params_dtype: torch.dtype = torch.float32):
        super().__init__()
        self.p_dropout = config.dropout_p
        self.h = config.num_heads
@@ -1901,8 +2162,10 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
        )

    def forward(
        self,
        inp: torch.Tensor,
        cu_seqlens,
        max_s,
    ) -> torch.Tensor:
        with self.prepare_forward(inp, None, num_gemms=3) as inp:
            out = _custom_mha_fp8.apply(
@@ -1917,5 +2180,6 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
                self.fp8_meta,
                self.workspace,
                self.training,
                self.mask_type,
            )
        return out
@@ -14,12 +14,13 @@ from transformer_engine.pytorch.utils import get_device_compute_capability
model_configs_flash_attn = {
    # test: b, h, hg, d, sq, skv, p, mask, bias
    "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # MHA
    "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # MHA
    "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # GQA
    "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # GQA
}


def get_bash_arguments(**kwargs):
    args = ["python", "-m", "torch.distributed.launch", "--nproc-per-node=2"]
    te_path = os.getenv("TE_PATH", "/opt/transformerengine")
@@ -29,46 +30,43 @@ def get_bash_arguments(**kwargs):
args.append(f"{k}={v}") args.append(f"{k}={v}")
return args return args
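
# Illustrative usage (assuming a 2-GPU environment, per --nproc-per-node=2): the tests
# below launch this module under torch.distributed, e.g.
#   subprocess.run(
#       get_bash_arguments(dtype="bf16", model="cp_1_0", qkv_format="bshd",
#                          kernel_backend="FlashAttention"),
#       check=True,
#   )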


@pytest.mark.skipif(not _is_flash_attention_2_available(), reason="Flash-attn 2.0+ is required.")
@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
def test_cp_with_flash_attention(dtype, model, qkv_format):
    subprocess.run(
        get_bash_arguments(
            dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FlashAttention"
        ),
        check=True,
    )


model_configs_fused_attn = {
    # test: b, h, hg, d, sq, skv, p, mask, bias
    "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # MHA
    "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # MHA
    "cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"),  # MHA
    "cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"),  # MHA
    "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # GQA
    "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"),  # GQA
    "cp_2_2": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"),  # GQA
    "cp_2_3": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"),  # GQA
}
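# Decoding the tuples above (per the field comment): e.g. "cp_2_2" is a GQA case with
# batch 2, 12 query heads, 1 KV group, head_dim 128, 4096-token q/kv sequences,
# no dropout, a causal mask, and a post_scale_bias attention bias.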


@pytest.mark.skipif(_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
def test_cp_with_fused_attention(dtype, model, qkv_format):
    subprocess.run(
        get_bash_arguments(
            dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FusedAttention"
        ),
        check=True,
    )
@@ -8,8 +8,15 @@ import pytest
import torch

from transformer_engine.pytorch import (
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
    make_graphed_callables,
    MultiheadAttention,
    TransformerLayer,
    fp8_autocast,
    fp8_model_init,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
@@ -26,15 +33,18 @@ torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()


@dataclass
class ModelConfig:
    """Data tensor dimensions within Transformer model"""

    sequence_length: int
    batch_size: int
    hidden_size: int
    num_heads: int
    kv_channels: int


model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}

modules = ["transformer", "layernorm_mlp", "layernorm_linear", "linear", "mha", "dpa"]
@@ -66,7 +76,9 @@ def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None)
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        if not torch.equal(t1, t2):
            failed = True
            failed_tensors += (
                f" {names[i]}\n" if names is not None else f" tensor at idx={i}\n"
            )
    assert not failed, "Output mismatches in:\n" + failed_tensors
@@ -157,41 +169,51 @@ def _test_cuda_graphs(
    with fp8_model_init(enabled=fp8_params):
        # Create modules.
        if module == "transformer":
            modules = [
                TransformerLayer(
                    config.hidden_size,
                    config.hidden_size,
                    config.num_heads,
                    hidden_dropout=0.0,
                    attention_dropout=0.0,
                    fuse_qkv_params=True,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "layernorm_mlp":
            modules = [
                LayerNormMLP(config.hidden_size, config.hidden_size, params_dtype=dtype)
                for _ in range(num_layers)
            ]
        elif module == "layernorm_linear":
            modules = [
                LayerNormLinear(config.hidden_size, config.hidden_size, params_dtype=dtype)
                for _ in range(num_layers)
            ]
        elif module == "mha":
            modules = [
                MultiheadAttention(
                    config.hidden_size,
                    config.num_heads,
                    attention_dropout=0.0,
                    params_dtype=dtype,
                    fuse_qkv_params=True,
                )
                for _ in range(num_layers)
            ]
        elif dpa:
            assert config.hidden_size % config.num_heads == 0, "Err."
            assert num_layers == 1, "Err."
            modules = [
                DotProductAttention(config.num_heads, config.kv_channels, attention_dropout=0.0)
                for _ in range(num_layers)
            ]
        else:
            modules = [
                Linear(config.hidden_size, config.hidden_size, device="cuda", params_dtype=dtype)
                for _ in range(num_layers)
            ]

    # Initialize gradient buffers.
    for module in modules:
@@ -238,7 +260,7 @@ def _test_cuda_graphs(
            with fp8_autocast(enabled=fp8):
                kwargs = {}
                if fp8_weight_caching:
                    kwargs["is_first_microbatch"] = grad_accumulation_step == 0
                output = model(*inputs, **kwargs)
            output.backward(grad_output)

    if not dpa:
@@ -27,33 +27,31 @@ num_heads = 16
head_dim = 64
dtype = torch.bfloat16


class TestDeferredInit:

    @staticmethod
    def get_module_args(module):
        hidden_size = num_heads * head_dim
        args = (hidden_size,)
        kwargs = {"params_dtype": dtype, "device": "meta"}
        if module in [te.Linear, te.LayerNormLinear, te.LayerNormMLP]:
            ffn_hidden_size = 2 * hidden_size
            args += (ffn_hidden_size,)
            kwargs["bias"] = True
            if module == te.LayerNormMLP:
                kwargs["seq_length"] = seq_length
        elif module == te.MultiheadAttention:
            args += (num_heads,)
            kwargs["fuse_qkv_params"] = True
        elif module == te.TransformerLayer:
            args += (3 * hidden_size, num_heads)
            kwargs["fuse_qkv_params"] = True
            kwargs["seq_length"] = seq_length

        return args, kwargs
@pytest.mark.parametrize("module_type", _core_modules+_composed_modules) @pytest.mark.parametrize("module_type", _core_modules + _composed_modules)
def test_zero_memory_init( def test_zero_memory_init(
self, self,
module_type: torch.nn.Module, module_type: torch.nn.Module,
@@ -26,6 +26,7 @@ _tols: Dict[tex.DType, Dict[str, float]] = {
    tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125),  # epsilon = 0.125
}


def _to_list(x: Union[Iterable, Any]) -> List:
    """Convert to list if iterable, otherwise put in singleton list"""
    if isinstance(x, Iterable):
@@ -33,12 +34,14 @@ def _to_list(x: Union[Iterable, Any]) -> List:
    else:
        return [x]
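
# Illustrative behavior, per the docstring: _to_list((1, 2)) -> [1, 2], while a
# non-iterable scalar is wrapped, _to_list(5) -> [5].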

# Types that can be interpreted as tensor dims
DimsType = Union[Iterable[int], int]

# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFloat8Tensor:
@@ -108,7 +111,7 @@ class TestFloat8Tensor:
    def test_quantize_dequantize_scales(self, scale: float) -> None:
        self._test_quantize_dequantize(scale=scale)

    @pytest.mark.parametrize("dims", [[], 1, 311, [7, 11], [7, 5, 3], [2, 3, 5, 3]])
    def test_quantize_dequantize_dims(self, dims: DimsType) -> None:
        self._test_quantize_dequantize(dims=dims)
@@ -310,7 +313,7 @@ class TestFloat8Tensor:
    def test_serialization(
        self,
        dims: DimsType = [2, 3, 5],
        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
        scale: float = 0.5,
        dtype: torch.dtype = torch.float32,
@@ -117,9 +117,7 @@ class TestFusedAdam(TestFusedOptimizer):
        tensors = []
        for size in sizes:
            tensors.append(torch.rand(size, dtype=torch.float, device="cuda"))

        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(tensors, self.options)
        for i in range(self.iters):
            self.gen_grad(ref_param, tst_param)
@@ -139,9 +137,7 @@ class TestFusedAdam(TestFusedOptimizer):
        }

        tensor = torch.rand(nelem, dtype=torch.float, device="cuda")
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim([tensor], adam_option)
        for i in range(self.iters):
            self.gen_grad(ref_param, tst_param)
@@ -161,9 +157,7 @@ class TestFusedAdam(TestFusedOptimizer):
        }

        tensor = torch.rand(nelem, dtype=torch.float, device="cuda")
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim([tensor], adam_option)
        # Add an empty param group which may occur for pipeline parallel p-tuning
        tst_optim.add_param_group({"params": []})
@@ -175,10 +169,11 @@ class TestFusedAdam(TestFusedOptimizer):
        torch.testing.assert_close(ref_param, tst_param)


class TestFusedSGD(TestFusedOptimizer):
    def __init__(self, *args, **kwargs):
        super(TestFusedSGD, self).__init__(*args, **kwargs)
        self.options = {"lr": 0.25, "momentum": 0.125}
        self.ref_optim = torch.optim.SGD
        self.fused_optim = te.optimizers.FusedSGD
@@ -188,7 +183,7 @@ class TestFusedSGD(TestFusedOptimizer):
    def test_half(self):
        self.gen_single_type_test(param_type=torch.float16)

    @unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required")
    def test_multi_device(self):
        devices = ("cuda:0", "cuda:1")
        for current_dev, tensor_dev in product(devices, devices):
@@ -452,8 +447,8 @@ class AdamTest(unittest.TestCase):
@largeTensorTest("60GB", "cuda") @largeTensorTest("60GB", "cuda")
def testLargeTensor(self): def testLargeTensor(self):
t = torch.zeros(2359332864, dtype=torch.half, device='cuda') t = torch.zeros(2359332864, dtype=torch.half, device="cuda")
t2 = torch.zeros(2359332864, dtype=torch.half, device='cuda') t2 = torch.zeros(2359332864, dtype=torch.half, device="cuda")
grad = torch.randn_like(t) grad = torch.randn_like(t)
t.grad = grad t.grad = grad
t2.grad = grad t2.grad = grad
@@ -26,10 +26,7 @@ def apply_rotary_pos_emb_thd(
""" """
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return torch.cat( return torch.cat(
[ [apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)]) for x in torch.split(t, seqlens)]
apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)])
for x in torch.split(t, seqlens)
]
).squeeze(1) ).squeeze(1)
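
# Note (illustrative): in the "thd" layout, variable-length sequences are packed along
# the token dim and delimited by cu_seqlens; the helper above splits them apart, applies
# RoPE with only the first seqlen frequencies of each sequence, and re-concatenates.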
@@ -45,6 +42,7 @@ def get_tol(dtype: torch.dtype) -> Dict:
def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
    return output.sum() * 2


# Gradient is a full tensor
def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
    t = torch.ones_like(output)
@@ -86,9 +84,7 @@ def test_fused_rope(
    emb = rotary_pos_emb(seq_length)

    # unfused
    output_unfused = apply_rotary_pos_emb(t, emb, tensor_format=tensor_format, fused=False)
    loss_unfused = loss_func(output_unfused)
    loss_unfused.backward()
    grad_unfused = t.grad.detach().clone()
@@ -13,23 +13,16 @@ num_heads = 16
head_dim = 64
dtype = torch.bfloat16
num_attn_head = 16
ffn_hidden_size = 1024


@pytest.mark.parametrize("kv_channels", [128, 256])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("num_gqa_groups", [1, 2, 4, 8, 16])
def test_gqa(kv_channels, hidden_size, num_gqa_groups) -> None:
    model = te.TransformerLayer(
        hidden_size, ffn_hidden_size, num_attn_head, num_gqa_groups, kv_channels=kv_channels
    )

    # Run forward pass
@@ -42,10 +35,9 @@ def test_gqa(
    assert model.self_attention.layernorm_qkv.query_weight.shape[0] == kv_channels * num_attn_head
    assert model.self_attention.layernorm_qkv.query_weight.shape[1] == hidden_size
    assert model.self_attention.layernorm_qkv.value_weight.shape[0] == kv_channels * num_gqa_groups
    assert model.self_attention.layernorm_qkv.value_weight.shape[1] == hidden_size
    assert model.self_attention.proj.weight.shape[0] == hidden_size
    assert model.self_attention.proj.weight.shape[1] == kv_channels * num_attn_head
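    # Illustrative: with GQA the query projection keeps num_attn_head heads while the
    # key/value projections shrink to num_gqa_groups heads, which is why the value
    # weight asserted above has kv_channels * num_gqa_groups output rows.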
@@ -11,11 +11,11 @@ import transformer_engine.pytorch as te
# Model names for test_torch_dynamo
_model_factory = {
    "Linear": [(lambda: te.Linear(16, 16)), [16, 16]],
    "LayerNorm": [(lambda: te.LayerNorm(16)), [16, 16]],
    "LayerNormLinear": [(lambda: te.LayerNormLinear(16, 16)), [16, 16]],
    "LayerNormMLP": [(lambda: te.LayerNormMLP(16, 16)), [16, 16]],
    "TransformerLayer": [(lambda: te.TransformerLayer(128, 128, 2)), [4, 1, 128]],
}
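# Illustrative reading: each entry pairs a zero-arg constructor with the input dims the
# dynamo test feeds it, e.g. "Linear" builds te.Linear(16, 16) and runs it on a 16x16 tensor.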
@@ -31,11 +31,11 @@ def test_torch_dynamo(model_name: str):
# Helper function to construct tensor with default options
def make_tensor(
    dims: Tuple[int],
    dtype: torch.dtype = torch.float32,
    device: torch.device = "cuda",
    requires_grad: bool = True,
    **kwargs,
):
    return torch.zeros(
        dims,
@@ -28,9 +28,7 @@ appliers = [MultiTensorApply(2048 * 32), MultiTensorApply(333), MultiTensorApply
@pytest.mark.parametrize("in_type", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("in_type", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("out_type", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("out_type", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("inplace", [False, True]) @pytest.mark.parametrize("inplace", [False, True])
def test_multi_tensor_scale( def test_multi_tensor_scale(input_size_pair, applier, repeat, in_type, out_type, inplace):
input_size_pair, applier, repeat, in_type, out_type, inplace
):
if inplace is True and (out_type is not in_type): if inplace is True and (out_type is not in_type):
pytest.skip("inplace=True and out_type != in_type is not supported.") pytest.skip("inplace=True and out_type != in_type is not supported.")
elif (in_type == torch.float16 and out_type == torch.bfloat16) or ( elif (in_type == torch.float16 and out_type == torch.bfloat16) or (
@@ -154,9 +152,7 @@ def test_multi_tensor_l2norm(input_size_pair, applier, repeat, in_type, per_tens
        in_list += [a.clone().to(in_type), b.clone().to(in_type)]

    if per_tensor:
        norm, norm_per_tensor = applier(tex.multi_tensor_l2norm, overflow_buf, [in_list], True)
        normab = torch.cat((a.norm().view(1), b.norm().view(1)))
        norm_per_tensor = norm_per_tensor.view(-1, 2)
    else:
@@ -168,9 +164,7 @@ def test_multi_tensor_l2norm(input_size_pair, applier, repeat, in_type, per_tens
    torch.testing.assert_close(norm, reference.broadcast_to(norm.shape))
    if per_tensor:
        torch.testing.assert_close(norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape))

    assert overflow_buf.item() == 0
@@ -179,9 +173,7 @@ def test_multi_tensor_l2norm(input_size_pair, applier, repeat, in_type, per_tens
@pytest.mark.parametrize("repeat", [1, 55]) @pytest.mark.parametrize("repeat", [1, 55])
@pytest.mark.parametrize("in_type", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("in_type", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("per_tensor", [False, True]) @pytest.mark.parametrize("per_tensor", [False, True])
def test_multi_tensor_unscale_l2norm( def test_multi_tensor_unscale_l2norm(input_size_pair, applier, repeat, in_type, per_tensor):
input_size_pair, applier, repeat, in_type, per_tensor
):
sizea, sizeb = input_size_pair sizea, sizeb = input_size_pair
device = torch.device("cuda") device = torch.device("cuda")
val = 4.0 val = 4.0
@@ -205,9 +197,7 @@ def test_multi_tensor_unscale_l2norm(
            inv_scale_cuda,
            True,
        )
        normab = torch.cat(((a * inv_scale).norm().view(1), (b * inv_scale).norm().view(1)))
        norm_per_tensor = norm_per_tensor.view(-1, 2)
    else:
        norm, _ = applier(
@@ -224,7 +214,5 @@ def test_multi_tensor_unscale_l2norm(
    torch.testing.assert_close(norm, reference.broadcast_to(norm.shape))
    if per_tensor:
        torch.testing.assert_close(norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape))

    assert overflow_buf.item() == 0
@@ -20,8 +20,15 @@ from transformer_engine.pytorch.utils import (
    is_bf16_compatible,
)
from transformer_engine.pytorch import (
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
    MultiheadAttention,
    RMSNorm,
    TransformerLayer,
    LayerNorm,
    InferenceParams,
)
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
@@ -106,10 +113,11 @@ def assert_allclose(
        if not result:
            diff = torch.abs(t1 - t2).flatten()
            m = torch.argmax(diff)
            msg = (
                f"Outputs not close enough in tensor at idx={i}. "
                f"Location of the maximum difference: {m.item()} "
                f"with {t1.flatten()[m].item()} vs {t2.flatten()[m].item()} "
                f"(diff {diff[m].item()})."
            )
            raise AssertionError(msg)
@@ -175,9 +183,7 @@ class TorchDotProductAttention(torch.nn.Module):
        )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)
@@ -216,14 +222,10 @@ class TorchDotProductAttention(torch.nn.Module):
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
@@ -241,9 +243,7 @@ class TorchDotProductAttention(torch.nn.Module):


class TorchLayerNorm(nn.Module):
    def __init__(self, in_features: int, eps: float, zero_centered_gamma: bool):
        super().__init__()
        self.eps = eps
        self.in_features = in_features
@@ -260,10 +260,12 @@ class TorchLayerNorm(nn.Module):
        w = w.to(torch.float32)
        b = self.bias.to(torch.float32)
        inp = x.to(torch.float32)
        out = torch.nn.functional.layer_norm(
            inp, (self.in_features,), weight=w, bias=b, eps=self.eps
        )
        return out.to(x.dtype)
# Adapted from https://github.com/bzhangGo/rmsnorm/blob/c6691f20ec0af4128c8159c903071f7575404295/rmsnorm_torch.py # Adapted from https://github.com/bzhangGo/rmsnorm/blob/c6691f20ec0af4128c8159c903071f7575404295/rmsnorm_torch.py
class TorchRMSNorm(nn.Module):
    def __init__(self, in_features, zero_centered_gamma, eps=1e-5):
@@ -278,11 +280,11 @@ class TorchRMSNorm(nn.Module):
self.register_parameter("weight", self.weight) self.register_parameter("weight", self.weight)
    def forward(self, x):
        norm_x2 = torch.sum(x.float() ** 2, dim=-1, keepdim=True)
        d_x = self.in_features

        rms_x2 = norm_x2 / d_x + self.eps
        r_rms_x = rms_x2 ** (-1.0 / 2)
        x_normed = x * r_rms_x
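        # Reference math (illustrative): x_normed = x / sqrt(mean(x^2) + eps); the
        # learned weight is applied afterwards (shifted by 1 when zero_centered_gamma).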

        w = self.weight.float()
@@ -292,17 +294,24 @@ class TorchRMSNorm(nn.Module):


class TorchLayerNormLinear(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        eps: float,
        bias: bool = True,
        normalization: str = "LayerNorm",
        zero_centered_gamma: bool = False,
    ):
        super().__init__()
        if normalization == "LayerNorm":
            self.layernorm = TorchLayerNorm(
                in_features, eps=eps, zero_centered_gamma=zero_centered_gamma
            )
        elif normalization == "RMSNorm":
            self.layernorm = TorchRMSNorm(
                in_features, eps=eps, zero_centered_gamma=zero_centered_gamma
            )
        else:
            raise RuntimeError("Unsupported normalization")
@@ -329,21 +338,26 @@ class TorchMHA(nn.Module):
        output = output[0]
        return output


class TorchQuickGELU(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input * torch.sigmoid(1.702 * input)


class TorchSquaredRELU(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return (input > 0) * input * input


_supported_act = {
    "geglu": nn.GELU(approximate="tanh"),
    "gelu": nn.GELU(approximate="tanh"),
    "reglu": nn.ReLU(),
    "relu": nn.ReLU(),
    "swiglu": nn.SiLU(),
    "qgelu": TorchQuickGELU(),
    "srelu": TorchSquaredRELU(),
}


class TorchGLU(nn.Module):
@@ -353,26 +367,29 @@ class TorchGLU(nn.Module):
    def forward(self, x):
        shape = x.size(-1)
        a = x[..., : shape // 2]
        b = x[..., (shape // 2) :]
        a = self.act(a)
        return a * b
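
# Illustrative: TorchGLU halves the last dim and gates the second half with the
# activated first half, so an input of shape (..., 2 * F) yields output (..., F).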


class TorchLayerNormMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        eps: float = 1e-5,
        activation="gelu",
        normalization: str = "LayerNorm",
    ):
        super().__init__()
        if normalization == "LayerNorm":
            self.ln = TorchLayerNorm(hidden_size, eps=eps, zero_centered_gamma=False)
        elif normalization == "RMSNorm":
            self.ln = TorchRMSNorm(hidden_size, eps=eps, zero_centered_gamma=False)
        else:
            raise RuntimeError("Unsupported normalization")
        if "glu" in activation:
            fc1_output_features = 2 * ffn_hidden_size
            self.gelu = TorchGLU(activation)
        else:
@@ -387,7 +404,9 @@ class TorchLayerNormMLP(nn.Module):


class TorchGPT(nn.Module):
    def __init__(
        self, hidden_size: int, eps: float, num_attention_heads: int, parallel_attention_mlp: bool
    ):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_size, eps=eps)
        self.causal_attn = TorchMHA(hidden_size, num_attention_heads)
@@ -411,7 +430,6 @@ class TorchGPT(nn.Module):
        return x


def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False):
    reset_rng_states()
    FP8GlobalStateManager.reset()
@@ -421,23 +439,21 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    with fp8_model_init(enabled=fp8 and fp8_model_params):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            params_dtype=dtype,
            fuse_qkv_params=True,
            device="cuda",
        )

    te_inp_hidden_states = torch.randn(
@@ -477,8 +493,12 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_par
    config = model_configs[model]

    outputs = _test_e2e_selective_recompute(
        bs, dtype, config, fp8, fp8_model_params, recompute=False
    )
    outputs_recompute = _test_e2e_selective_recompute(
        bs, dtype, config, fp8, fp8_model_params, recompute=True
    )

    # Check that results match
    tols = dtype_tols(dtype)
@@ -496,10 +516,7 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_par
def _test_e2e_full_recompute(
    bs, dtype, config, fp8, fp8_model_params=False, recompute=False, use_reentrant=True
):
    reset_rng_states()
    FP8GlobalStateManager.reset()
@@ -586,10 +603,12 @@ def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params,
        # Non-reentrant checkpoint becomes non-deterministic with bias+GELU fusion
        os.environ["NVTE_BIAS_GELU_NVFUSION"] = "0"

    outputs, names = _test_e2e_full_recompute(
        bs, dtype, config, fp8, fp8_model_params, recompute=False, use_reentrant=use_reentrant
    )
    outputs_recompute, _ = _test_e2e_full_recompute(
        bs, dtype, config, fp8, fp8_model_params, recompute=True, use_reentrant=use_reentrant
    )

    if not use_reentrant:
        # Reset bias+GELU fusion flag to avoid contaminating other tests
@@ -753,22 +772,19 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
    config = model_configs[model]

    te_gpt = TransformerLayer(
        hidden_size=config.hidden_size,
        ffn_hidden_size=4 * config.hidden_size,
        num_attention_heads=config.num_attention_heads,
        layernorm_epsilon=config.eps,
        attention_dropout=0.1,
        hidden_dropout=0.1,
        params_dtype=dtype,
        fuse_qkv_params=True,
        qkv_weight_interleaved=False,
        parallel_attention_mlp=parallel_attention_mlp,
        device="cuda",
    ).eval()

    torch_gpt = (
        TorchGPT(
@@ -853,18 +869,15 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
def test_mha_accuracy(dtype, bs, model, mask_type):
    config = model_configs[model]

    te_mha = MultiheadAttention(
        config.hidden_size,
        config.num_attention_heads,
        fuse_qkv_params=True,
        params_dtype=dtype,
        qkv_weight_interleaved=False,
        input_layernorm=False,
        device="cuda",
    ).eval()

    torch_mha = (
        TorchMHA(
@@ -919,7 +932,9 @@ def _test_granular_accuracy(block, bs, dtype, config):
def _test_dpa_accuracy(block, bs, dtype, config):
    reset_rng_states()

    mask = torch.triu(
        torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda"), diagonal=1
    )
    query, key, value = [
        torch.randn(
            (config.seq_len, bs, config.num_attention_heads, config.embed),
...@@ -953,7 +968,7 @@ def test_dpa_accuracy(dtype, bs, model): ...@@ -953,7 +968,7 @@ def test_dpa_accuracy(dtype, bs, model):
DotProductAttention( DotProductAttention(
config.num_attention_heads, config.num_attention_heads,
config.embed, config.embed,
attention_dropout=0.0, # disable dropout, FU uses rng differently attention_dropout=0.0, # disable dropout, FU uses rng differently
) )
.to(dtype=dtype) .to(dtype=dtype)
.cuda() .cuda()
...@@ -962,7 +977,7 @@ def test_dpa_accuracy(dtype, bs, model): ...@@ -962,7 +977,7 @@ def test_dpa_accuracy(dtype, bs, model):
torch_dpa = ( torch_dpa = (
TorchDotProductAttention( TorchDotProductAttention(
config.embed, config.embed,
0.0, # dropout 0.0, # dropout
) )
.to(dtype=dtype) .to(dtype=dtype)
.cuda() .cuda()
...@@ -984,27 +999,21 @@ def test_dpa_accuracy(dtype, bs, model): ...@@ -984,27 +999,21 @@ def test_dpa_accuracy(dtype, bs, model):
def test_linear_accuracy(dtype, bs, model): def test_linear_accuracy(dtype, bs, model):
config = model_configs[model] config = model_configs[model]
te_linear = ( te_linear = Linear(
Linear( config.hidden_size,
config.hidden_size, 4 * config.hidden_size,
4 * config.hidden_size, bias=True,
bias=True, params_dtype=dtype,
params_dtype=dtype, device="cuda",
device="cuda", ).eval()
)
.eval()
)
torch_linear = ( torch_linear = torch.nn.Linear(
torch.nn.Linear( config.hidden_size,
config.hidden_size, 4 * config.hidden_size,
4 * config.hidden_size, bias=True,
bias=True, device="cuda",
device="cuda", dtype=dtype,
dtype=dtype, ).eval()
)
.eval()
)
# Share params # Share params
with torch.no_grad(): with torch.no_grad():
...@@ -1029,23 +1038,16 @@ def test_linear_accuracy(dtype, bs, model): ...@@ -1029,23 +1038,16 @@ def test_linear_accuracy(dtype, bs, model):
def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
config = model_configs[model] config = model_configs[model]
te_rmsnorm = ( te_rmsnorm = RMSNorm(
RMSNorm( config.hidden_size,
config.hidden_size, eps=eps,
eps=eps, params_dtype=dtype,
params_dtype=dtype, zero_centered_gamma=zero_centered_gamma,
zero_centered_gamma=zero_centered_gamma, device="cuda",
device="cuda", ).eval()
)
.eval()
)
torch_rmsnorm = ( torch_rmsnorm = (
TorchRMSNorm( TorchRMSNorm(config.hidden_size, eps=eps, zero_centered_gamma=zero_centered_gamma)
config.hidden_size,
eps=eps,
zero_centered_gamma=zero_centered_gamma
)
.to(dtype=dtype) .to(dtype=dtype)
.cuda() .cuda()
.eval() .eval()
...@@ -1059,12 +1061,14 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): ...@@ -1059,12 +1061,14 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
torch_outputs = _test_granular_accuracy(torch_rmsnorm, bs, dtype, config) torch_outputs = _test_granular_accuracy(torch_rmsnorm, bs, dtype, config)
# Check output. # Check output.
atol = {torch.float32 : 1e-7, atol = {
torch.half : 2e-3, torch.float32: 1e-7,
torch.bfloat16: 2e-2, torch.half: 2e-3,
torch.bfloat16: 2e-2,
} }
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", model_configs.keys())
...@@ -1073,23 +1077,16 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): ...@@ -1073,23 +1077,16 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
config = model_configs[model] config = model_configs[model]
te_layernorm = ( te_layernorm = LayerNorm(
LayerNorm( config.hidden_size,
config.hidden_size, eps=eps,
eps=eps, params_dtype=dtype,
params_dtype=dtype, zero_centered_gamma=zero_centered_gamma,
zero_centered_gamma=zero_centered_gamma, device="cuda",
device="cuda", ).eval()
)
.eval()
)
torch_layernorm = ( torch_layernorm = (
TorchLayerNorm( TorchLayerNorm(config.hidden_size, eps=eps, zero_centered_gamma=zero_centered_gamma)
config.hidden_size,
eps=eps,
zero_centered_gamma=zero_centered_gamma
)
.to(dtype=dtype) .to(dtype=dtype)
.cuda() .cuda()
.eval() .eval()
...@@ -1104,9 +1101,10 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): ...@@ -1104,9 +1101,10 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
torch_outputs = _test_granular_accuracy(torch_layernorm, bs, dtype, config) torch_outputs = _test_granular_accuracy(torch_layernorm, bs, dtype, config)
# Check output. # Check output.
atol = {torch.float32 : 1e-7, atol = {
torch.half : 2e-3, torch.float32: 1e-7,
torch.bfloat16: 2e-2, torch.half: 2e-3,
torch.bfloat16: 2e-2,
} }
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
...@@ -1119,19 +1117,16 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): ...@@ -1119,19 +1117,16 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centered_gamma): def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centered_gamma):
config = model_configs[model] config = model_configs[model]
te_ln_linear = ( te_ln_linear = LayerNormLinear(
LayerNormLinear( config.hidden_size,
config.hidden_size, 4 * config.hidden_size,
4 * config.hidden_size, config.eps,
config.eps, bias=True,
bias=True, normalization=normalization,
normalization=normalization, params_dtype=dtype,
params_dtype=dtype, zero_centered_gamma=zero_centered_gamma,
zero_centered_gamma=zero_centered_gamma, device="cuda",
device="cuda", ).eval()
)
.eval()
)
torch_ln_linear = ( torch_ln_linear = (
TorchLayerNormLinear( TorchLayerNormLinear(
...@@ -1159,9 +1154,10 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere ...@@ -1159,9 +1154,10 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere
torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config) torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)
# Check output. # Check output.
atol = {torch.float32 : 2.5e-4, atol = {
torch.half : 2e-3, torch.float32: 2.5e-4,
torch.bfloat16: 2e-2, torch.half: 2e-3,
torch.bfloat16: 2e-2,
} }
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
...@@ -1174,17 +1170,14 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere ...@@ -1174,17 +1170,14 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere
def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization): def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
config = model_configs[model] config = model_configs[model]
te_ln_mlp = ( te_ln_mlp = LayerNormMLP(
LayerNormMLP( config.hidden_size,
config.hidden_size, 4 * config.hidden_size,
4 * config.hidden_size, activation=activation,
activation=activation, normalization=normalization,
normalization=normalization, params_dtype=dtype,
params_dtype=dtype, device="cuda",
device="cuda", ).eval()
)
.eval()
)
torch_ln_mlp = ( torch_ln_mlp = (
TorchLayerNormMLP( TorchLayerNormMLP(
...@@ -1226,8 +1219,10 @@ def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph): ...@@ -1226,8 +1219,10 @@ def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
optimizer = torch.optim.SGD(block.parameters(), lr=0.1) optimizer = torch.optim.SGD(block.parameters(), lr=0.1)
# Placeholders used for graph capture. # Placeholders used for graph capture.
static_input = torch.randn(config.seq_len, bs, config.hidden_size, device='cuda', dtype=dtype, requires_grad=True) static_input = torch.randn(
static_target = torch.randn(config.seq_len, bs, config.hidden_size, device='cuda', dtype=dtype) config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True
)
static_target = torch.randn(config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype)
real_input = torch.rand_like(static_input) real_input = torch.rand_like(static_input)
real_target = torch.rand_like(static_target) real_target = torch.rand_like(static_target)
...@@ -1286,22 +1281,20 @@ def test_gpt_cuda_graph(dtype, bs, model): ...@@ -1286,22 +1281,20 @@ def test_gpt_cuda_graph(dtype, bs, model):
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = ( block = TransformerLayer(
TransformerLayer( config.hidden_size,
config.hidden_size, 4 * config.hidden_size,
4 * config.hidden_size, config.num_attention_heads,
config.num_attention_heads, layernorm_epsilon=config.eps,
layernorm_epsilon=config.eps, init_method=init_method,
init_method=init_method, output_layer_init_method=output_layer_init_method,
output_layer_init_method=output_layer_init_method, hidden_dropout=0.1,
hidden_dropout=0.1, attention_dropout=0.1,
attention_dropout=0.1, kv_channels=config.embed,
kv_channels=config.embed, params_dtype=dtype,
params_dtype=dtype, apply_residual_connection_post_layernorm=False,
apply_residual_connection_post_layernorm=False, output_layernorm=False,
output_layernorm=False, device="cuda",
device="cuda",
)
) )
graphed_block = copy.deepcopy(block) graphed_block = copy.deepcopy(block)
...@@ -1388,7 +1381,6 @@ def test_gpt_fp8_parameters(dtype, bs, model): ...@@ -1388,7 +1381,6 @@ def test_gpt_fp8_parameters(dtype, bs, model):
) )
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", model_configs.keys())
...@@ -1451,7 +1443,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): ...@@ -1451,7 +1443,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
requires_grad=True, requires_grad=True,
) )
x_bshd = x_sbhd.transpose(0,1).contiguous() x_bshd = x_sbhd.transpose(0, 1).contiguous()
# To make sure forward is also identical (just in case some module decides # To make sure forward is also identical (just in case some module decides
# to act fancy) # to act fancy)
...@@ -1466,7 +1458,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): ...@@ -1466,7 +1458,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
# Check that results match # Check that results match
torch.testing.assert_close( torch.testing.assert_close(
y_bshd, y_bshd,
y_sbhd.transpose(0,1).contiguous(), y_sbhd.transpose(0, 1).contiguous(),
) )
...@@ -1500,19 +1492,16 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, ...@@ -1500,19 +1492,16 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
S_max = S + 2 S_max = S + 2
if module == "TransformerLayer": if module == "TransformerLayer":
model = ( model = TransformerLayer(
TransformerLayer( hidden_size=D,
hidden_size=D, ffn_hidden_size=4 * D,
ffn_hidden_size= 4 * D, num_attention_heads=H,
num_attention_heads=H, attn_input_format=input_format,
attn_input_format=input_format, layer_number=layer_number,
layer_number=layer_number, attention_dropout=0.0,
attention_dropout = 0.0, params_dtype=dtype,
params_dtype=dtype, device="cuda",
device="cuda", ).eval()
)
.eval()
)
else: else:
model = ( model = (
MultiheadAttention( MultiheadAttention(
...@@ -1520,7 +1509,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, ...@@ -1520,7 +1509,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
num_attention_heads=H, num_attention_heads=H,
qkv_format=input_format, qkv_format=input_format,
layer_number=layer_number, layer_number=layer_number,
attention_dropout = 0.0, attention_dropout=0.0,
params_dtype=dtype, params_dtype=dtype,
) )
.cuda() .cuda()
...@@ -1537,39 +1526,38 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, ...@@ -1537,39 +1526,38 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
incremental_output = torch.zeros_like(input) incremental_output = torch.zeros_like(input)
# Generate output for the entire sequence # Generate output for the entire sequence
full_output = model( full_output = model(hidden_states=input, rotary_pos_emb=rotary_freqs if use_RoPE else None)
hidden_states=input,
rotary_pos_emb=rotary_freqs if use_RoPE else None)
# Incrementally generate outputs using KV-cache # Incrementally generate outputs using KV-cache
for i in range(S): for i in range(S):
if input_format == "sbhd": if input_format == "sbhd":
incremental_input = input[i].view(1,B,D) incremental_input = input[i].view(1, B, D)
else: else:
incremental_input = input[:, i, :].view(B,1,D) incremental_input = input[:, i, :].view(B, 1, D)
line_output = model( line_output = model(
hidden_states=incremental_input, hidden_states=incremental_input,
inference_params=inference_params, inference_params=inference_params,
rotary_pos_emb=rotary_freqs if use_RoPE else None) rotary_pos_emb=rotary_freqs if use_RoPE else None,
)
inference_params.sequence_len_offset += 1 inference_params.sequence_len_offset += 1
if input_format == "sbhd": if input_format == "sbhd":
incremental_output[i] = line_output.view(B,D) incremental_output[i] = line_output.view(B, D)
else: else:
incremental_output[:, i, :] = line_output.view(B,D) incremental_output[:, i, :] = line_output.view(B, D)
if module == "TransformerLayer": if module == "TransformerLayer":
atol = { atol = {
torch.float32 : 5e-3, torch.float32: 5e-3,
torch.half : 5e-3, torch.half: 5e-3,
torch.bfloat16: 5e-2, torch.bfloat16: 5e-2,
} }
else: else:
atol = { atol = {
torch.float32 : 1e-3, torch.float32: 1e-3,
torch.half : 1e-3, torch.half: 1e-3,
torch.bfloat16: 1e-2, torch.bfloat16: 1e-2,
} }
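The incremental loop above advances inference_params.sequence_len_offset by one per generated token. A minimal sketch of the assumed cache mechanics (names and shapes here are illustrative only, not TE's internal layout):

import torch

S_max, B, H, Dh = 10, 2, 4, 16
key_cache = torch.zeros(S_max, B, H, Dh)      # pre-allocated, like the buffers behind inference_params
offset = 0                                    # plays the role of sequence_len_offset
for step in range(3):
    new_key = torch.randn(1, B, H, Dh)        # key projected from the single new token
    key_cache[offset : offset + 1] = new_key  # written at the current offset
    offset += 1                               # so the next token lands right after it

Because each step only appends to the cache, replaying the sequence token by token should reproduce the full-sequence output within the atol values chosen above.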
......
...@@ -32,7 +32,13 @@ from typing import Optional, Union, Tuple, List ...@@ -32,7 +32,13 @@ from typing import Optional, Union, Tuple, List
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
from transformer_engine.common import recipe from transformer_engine.common import recipe
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, gelu, cast_to_fp8, cast_from_fp8 from transformer_engine.pytorch.cpp_extensions import (
gemm,
fp8_gemm,
gelu,
cast_to_fp8,
cast_from_fp8,
)
from transformer_engine.pytorch.module.base import get_workspace from transformer_engine.pytorch.module.base import get_workspace
import transformer_engine.pytorch.cpp_extensions as texcpp import transformer_engine.pytorch.cpp_extensions as texcpp
import transformer_engine.pytorch.softmax as softmax_defs import transformer_engine.pytorch.softmax as softmax_defs
...@@ -50,8 +56,10 @@ if SAVE_TEST_IO: ...@@ -50,8 +56,10 @@ if SAVE_TEST_IO:
from polygraphy.comparator import RunResults from polygraphy.comparator import RunResults
# The directory where generated ONNX test models are stored. # The directory where generated ONNX test models are stored.
NVTE_TEST_ARTIFACTS_DIR = os.environ.get('NVTE_TEST_ARTIFACTS_DIR') NVTE_TEST_ARTIFACTS_DIR = os.environ.get("NVTE_TEST_ARTIFACTS_DIR")
NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join(tempfile.gettempdir(), "./gen_onnx_models") NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join(
tempfile.gettempdir(), "./gen_onnx_models"
)
# The directory where this file is stored. # The directory where this file is stored.
...@@ -100,23 +108,21 @@ def do_export( ...@@ -100,23 +108,21 @@ def do_export(
model: torch.nn.Module, model: torch.nn.Module,
inp: torch.Tensor, inp: torch.Tensor,
fname: str, fname: str,
use_fp8: bool=True, use_fp8: bool = True,
opset: int=OPSET, opset: int = OPSET,
input_names: List[str]=None, input_names: List[str] = None,
output_names: List[str]=None, output_names: List[str] = None,
dynamic_axes: List[str]=None dynamic_axes: List[str] = None,
): ):
"""Export to ONNX""" """Export to ONNX"""
fp8_recipe = create_fp8_recipe() fp8_recipe = create_fp8_recipe()
input_names = input_names or ["input"] input_names = input_names or ["input"]
output_names = output_names or ["output"] output_names = output_names or ["output"]
with torch.inference_mode(), te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), warnings.catch_warnings(): with torch.inference_mode(), te.fp8_autocast(
warnings.filterwarnings( enabled=use_fp8, fp8_recipe=fp8_recipe
action='ignore', ), warnings.catch_warnings():
category=torch.jit.TracerWarning, warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning, module=r".*")
module=r'.*'
)
model.cuda().eval() model.cuda().eval()
os.makedirs(NVTE_TEST_ARTIFACTS_DIR, exist_ok=True) os.makedirs(NVTE_TEST_ARTIFACTS_DIR, exist_ok=True)
...@@ -138,7 +144,8 @@ def do_export( ...@@ -138,7 +144,8 @@ def do_export(
input_names=input_names, input_names=input_names,
output_names=output_names, output_names=output_names,
do_constant_folding=True, do_constant_folding=True,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH) operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
)
def to_numpy(tensor): def to_numpy(tensor):
...@@ -154,24 +161,30 @@ def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int): ...@@ -154,24 +161,30 @@ def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int):
NB_SCALES_PER_GEMM = 3 # One scale per: input, weights, and output GEMM tensors. NB_SCALES_PER_GEMM = 3 # One scale per: input, weights, and output GEMM tensors.
nb_total_scales = num_gemms * NB_SCALES_PER_GEMM nb_total_scales = num_gemms * NB_SCALES_PER_GEMM
module.init_fp8_metadata(num_gemms) module.init_fp8_metadata(num_gemms)
module.fp8_meta["scaling_fwd"].scale = torch.ones( module.fp8_meta["scaling_fwd"].scale = (
nb_total_scales, dtype=torch.float32, device="cuda") / scale torch.ones(nb_total_scales, dtype=torch.float32, device="cuda") / scale
module.fp8_meta["scaling_fwd"].scale_inv = torch.ones( )
nb_total_scales, dtype=torch.float32, device="cuda") * scale module.fp8_meta["scaling_fwd"].scale_inv = (
torch.ones(nb_total_scales, dtype=torch.float32, device="cuda") * scale
)
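set_layer_scale writes reciprocal values into scaling_fwd.scale and scaling_fwd.scale_inv. A minimal sketch of that relationship, assuming scale_factor=224 as in the parametrizations used elsewhere in this file:

import torch

scale_factor = 224.0
nb_total_scales = 3                                      # NB_SCALES_PER_GEMM for a single GEMM
scale = torch.ones(nb_total_scales) / scale_factor       # applied when casting to FP8
scale_inv = torch.ones(nb_total_scales) * scale_factor   # applied when casting back
assert torch.allclose(scale * scale_inv, torch.ones(nb_total_scales))

Keeping the two tensors exact reciprocals means a cast_to_fp8/cast_from_fp8 round trip returns to the original value range, up to FP8 rounding.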
def te_infer(model: torch.nn.Module, inps: Union[Tuple[torch.tensor], torch.tensor], is_fp8: bool): def te_infer(model: torch.nn.Module, inps: Union[Tuple[torch.tensor], torch.tensor], is_fp8: bool):
"""Transformer Engine forward propagation.""" """Transformer Engine forward propagation."""
fp8_recipe = create_fp8_recipe() fp8_recipe = create_fp8_recipe()
with torch.inference_mode(), te.fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe), warnings.catch_warnings(): with torch.inference_mode(), te.fp8_autocast(
enabled=is_fp8, fp8_recipe=fp8_recipe
), warnings.catch_warnings():
te_outputs = model(*inps if isinstance(inps, tuple) else (inps,)) te_outputs = model(*inps if isinstance(inps, tuple) else (inps,))
if not isinstance(te_outputs, tuple): if not isinstance(te_outputs, tuple):
te_outputs = (te_outputs,) te_outputs = (te_outputs,)
return te_outputs return te_outputs
def compare_outputs(onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname): def compare_outputs(
""" Compare ORT and TE outputs.""" onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname
):
"""Compare ORT and TE outputs."""
assert len(onnx_outputs) == len(te_outputs) assert len(onnx_outputs) == len(te_outputs)
# Compare ORT and PyTorch outputs. # Compare ORT and PyTorch outputs.
for onnx_output, te_output in zip(onnx_outputs, te_outputs): for onnx_output, te_output in zip(onnx_outputs, te_outputs):
...@@ -192,11 +205,15 @@ def compare_outputs(onnx_outputs, te_outputs, atol, rtol, max_errors_printed, al ...@@ -192,11 +205,15 @@ def compare_outputs(onnx_outputs, te_outputs, atol, rtol, max_errors_printed, al
errors = abs_err[mismatches] errors = abs_err[mismatches]
for loc in mismatched_ids[:nb_vals]: for loc in mismatched_ids[:nb_vals]:
ref = te_output[loc] ref = te_output[loc]
print(f"{onnx_output[loc]} -- {te_output[loc]} err={abs_err[loc]} > {atol + rtol * abs(ref)}") print(
f"{onnx_output[loc]} -- {te_output[loc]} err={abs_err[loc]} >"
f" {atol + rtol * abs(ref)}"
)
print(f"Max error: {np.max(errors)}") print(f"Max error: {np.max(errors)}")
if nb_errors > allow_cnt_errors: if nb_errors > allow_cnt_errors:
raise ValueError(f"Output validation of {fname} failed with {nb_errors} errors") raise ValueError(f"Output validation of {fname} failed with {nb_errors} errors")
def serialize_inputs_outputs( def serialize_inputs_outputs(
fname: str, fname: str,
inputs: Union[Tuple[torch.Tensor], torch.Tensor], inputs: Union[Tuple[torch.Tensor], torch.Tensor],
...@@ -214,10 +231,10 @@ def serialize_inputs_outputs( ...@@ -214,10 +231,10 @@ def serialize_inputs_outputs(
inputs = inputs if isinstance(inputs, list) or isinstance(inputs, tuple) else (inputs,) inputs = inputs if isinstance(inputs, list) or isinstance(inputs, tuple) else (inputs,)
named_inputs = zip(input_names, inputs) named_inputs = zip(input_names, inputs)
input_data = [{k: v.cpu() for k, v in named_inputs if v is not None}] input_data = [{k: v.cpu() for k, v in named_inputs if v is not None}]
json_fname = fname[:-len(".onnx")] + "_inputs.json" json_fname = fname[: -len(".onnx")] + "_inputs.json"
save_json(input_data, json_fname, description="custom input data") save_json(input_data, json_fname, description="custom input data")
json_fname = fname[:-len(".onnx")] + "_output.json" json_fname = fname[: -len(".onnx")] + "_output.json"
named_outputs = zip(output_names, te_outputs) named_outputs = zip(output_names, te_outputs)
output_data = {k: v.detach().cpu() for k, v in named_outputs if v is not None} output_data = {k: v.detach().cpu() for k, v in named_outputs if v is not None}
custom_outputs = RunResults() custom_outputs = RunResults()
...@@ -229,14 +246,14 @@ def validate_result( ...@@ -229,14 +246,14 @@ def validate_result(
fname: str, fname: str,
inps: Union[Tuple[torch.Tensor], torch.Tensor], inps: Union[Tuple[torch.Tensor], torch.Tensor],
model: torch.nn.Module, model: torch.nn.Module,
atol: float=1.e-8, # np.isclose default atol atol: float = 1.0e-8, # np.isclose default atol
rtol: float=1.e-5, # np.isclose default rtol rtol: float = 1.0e-5, # np.isclose default rtol
max_errors_printed: int=10, max_errors_printed: int = 10,
is_fp8: bool=False, is_fp8: bool = False,
allow_cnt_errors: int=0, allow_cnt_errors: int = 0,
input_names: List[str]=None, input_names: List[str] = None,
output_names: List[str]=None, output_names: List[str] = None,
te_outputs: List[torch.Tensor]=None, te_outputs: List[torch.Tensor] = None,
): ):
"""Compare the outputs of a Transformer Engine (TE) module vs the outputs of its ONNX """Compare the outputs of a Transformer Engine (TE) module vs the outputs of its ONNX
representation using ONNX Runtime (ORT) and ensure they are close. representation using ONNX Runtime (ORT) and ensure they are close.
...@@ -262,7 +279,7 @@ def validate_result( ...@@ -262,7 +279,7 @@ def validate_result(
print("registered custom FP8 Q/DQ ops!") print("registered custom FP8 Q/DQ ops!")
"""Create an ONNX Runtime session for validation.""" """Create an ONNX Runtime session for validation."""
kwargs = {"providers": ['CUDAExecutionProvider', 'CPUExecutionProvider']} kwargs = {"providers": ["CUDAExecutionProvider", "CPUExecutionProvider"]}
if is_fp8: if is_fp8:
sess_options = ort.SessionOptions() sess_options = ort.SessionOptions()
load_custom_ops(sess_options) load_custom_ops(sess_options)
...@@ -288,10 +305,12 @@ def validate_result( ...@@ -288,10 +305,12 @@ def validate_result(
ort_s = create_ort_session(fname, is_fp8) ort_s = create_ort_session(fname, is_fp8)
input_feed = create_ort_input_dict(ort_s, inps) input_feed = create_ort_input_dict(ort_s, inps)
onnx_outputs = ort_s.run(None, input_feed=input_feed) onnx_outputs = ort_s.run(None, input_feed=input_feed)
compare_outputs(onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname) compare_outputs(
onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname
)
def create_meta(scale_factor: float, size: int=1): def create_meta(scale_factor: float, size: int = 1):
meta = tex.FP8TensorMeta() meta = tex.FP8TensorMeta()
meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda") meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda")
meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor
...@@ -324,7 +343,9 @@ def get_attn_mask_str(use_mask, attn_mask_type): ...@@ -324,7 +343,9 @@ def get_attn_mask_str(use_mask, attn_mask_type):
return "_mask" if use_mask else "_no-mask" return "_mask" if use_mask else "_no-mask"
attn_mask_str = "_arbitrary-no-mask" attn_mask_str = "_arbitrary-no-mask"
attn_mask_str = "_causal-mask" if attn_mask_type == "causal" else attn_mask_str attn_mask_str = "_causal-mask" if attn_mask_type == "causal" else attn_mask_str
attn_mask_str = "_arbitrary-mask" if use_mask and attn_mask_type == "arbitrary" else attn_mask_str attn_mask_str = (
"_arbitrary-mask" if use_mask and attn_mask_type == "arbitrary" else attn_mask_str
)
return attn_mask_str return attn_mask_str
...@@ -351,17 +372,11 @@ class FP8GemmModule(nn.Module): ...@@ -351,17 +372,11 @@ class FP8GemmModule(nn.Module):
self.outp_type = precision self.outp_type = precision
def forward(self, inp, weight): def forward(self, inp, weight):
inp_fp8 = cast_to_fp8( inp_fp8 = cast_to_fp8(inp, self.meta_inp, self.fp8_tensor_inp, self.inp_type)
inp,
self.meta_inp,
self.fp8_tensor_inp,
self.inp_type)
weight_fp8 = cast_to_fp8( weight_fp8 = cast_to_fp8(
weight, weight, self.meta_weight, self.fp8_tensor_weight, self.weights_type
self.meta_weight, )
self.fp8_tensor_weight,
self.weights_type)
ret, _ = fp8_gemm( ret, _ = fp8_gemm(
weight_fp8, weight_fp8,
...@@ -376,9 +391,11 @@ class FP8GemmModule(nn.Module): ...@@ -376,9 +391,11 @@ class FP8GemmModule(nn.Module):
get_workspace(), get_workspace(),
bias=self.bias, bias=self.bias,
use_bias=self.use_bias, use_bias=self.use_bias,
use_split_accumulator=False) use_split_accumulator=False,
)
return ret return ret
""" """
Test cases begin here. Test cases begin here.
""" """
...@@ -387,13 +404,17 @@ Tests cases begin here. ...@@ -387,13 +404,17 @@ Tests cases begin here.
@skip_FP8 @skip_FP8
@pytest.mark.parametrize("scale_factor", [1, 224]) @pytest.mark.parametrize("scale_factor", [1, 224])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"precision, atol", [ "precision, atol",
[torch.float32, 1e-7], [
[torch.float16, 1e-7], [torch.float32, 1e-7],
[torch.bfloat16, 5e-3], [torch.float16, 1e-7],
["fake-torch.bfloat16", 5e-3], [torch.bfloat16, 5e-3],
]) ["fake-torch.bfloat16", 5e-3],
def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, precision: torch.dtype): ],
)
def test_export_cast_ops(
seed_default_rng, scale_factor: float, atol: float, precision: torch.dtype
):
fake_bf16_io = precision == "fake-torch.bfloat16" fake_bf16_io = precision == "fake-torch.bfloat16"
# reset precision to torch.bfloat16 after capturing fake BF16 mode # reset precision to torch.bfloat16 after capturing fake BF16 mode
precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision
...@@ -408,18 +429,9 @@ def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, pre ...@@ -408,18 +429,9 @@ def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, pre
self.fake_bf16_io = fake_bf16_io self.fake_bf16_io = fake_bf16_io
def forward(self, inp): def forward(self, inp):
ret = cast_to_fp8( ret = cast_to_fp8(inp, self.meta, self.fp8_tensor, self.fp8_type)
inp,
self.meta,
self.fp8_tensor,
self.fp8_type)
ret = cast_from_fp8( ret = cast_from_fp8(ret, self.meta, self.fp8_tensor, self.fp8_type, self.highprec_type)
ret,
self.meta,
self.fp8_tensor,
self.fp8_type,
self.highprec_type)
if self.fake_bf16_io: if self.fake_bf16_io:
ret = ret.type(torch.float32) ret = ret.type(torch.float32)
return ret return ret
...@@ -427,8 +439,9 @@ def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, pre ...@@ -427,8 +439,9 @@ def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, pre
# Set dimensions (these are arbitrary). # Set dimensions (these are arbitrary).
in_features = 64 in_features = 64
hidden_size = 256 hidden_size = 256
inp = torch.randn(hidden_size, in_features, device="cuda", inp = torch.randn(
dtype=torch.float if fake_bf16_io else precision) hidden_size, in_features, device="cuda", dtype=torch.float if fake_bf16_io else precision
)
high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io) high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
fname = f"te.cast_fp8_{scale_factor}{high_prec_str}.onnx" fname = f"te.cast_fp8_{scale_factor}{high_prec_str}.onnx"
model = TestFP8_QDQ(fake_bf16_io) model = TestFP8_QDQ(fake_bf16_io)
...@@ -439,15 +452,18 @@ def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, pre ...@@ -439,15 +452,18 @@ def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, pre
if fake_bf16_io or precision != torch.bfloat16: if fake_bf16_io or precision != torch.bfloat16:
validate_result(fname, inp, model, atol=atol, is_fp8=True, te_outputs=te_outputs) validate_result(fname, inp, model, atol=atol, is_fp8=True, te_outputs=te_outputs)
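test_export_cast_ops round-trips the input through cast_to_fp8 and cast_from_fp8, so exact equality cannot be expected; the per-precision atol accounts for FP8 rounding. A conceptual sketch of why (a uniform grid as a stand-in for the real E4M3 spacing, which TE's kernels handle):

import torch

x = torch.randn(256, 64)
step = 1 / 8                                        # pretend quantization step
x_roundtrip = (x / step).round() * step
assert (x - x_roundtrip).abs().max() <= step / 2    # reconstruction error bounded by half a step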
@skip_FP8 @skip_FP8
@pytest.mark.parametrize("scale_factor", [448]) @pytest.mark.parametrize("scale_factor", [448])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"precision, atol", [ "precision, atol",
[torch.float32, 1e-5], [
[torch.float16, 1e-5], [torch.float32, 1e-5],
[torch.bfloat16, 5e-3], [torch.float16, 1e-5],
["fake-torch.bfloat16", 5e-3] [torch.bfloat16, 5e-3],
]) ["fake-torch.bfloat16", 5e-3],
],
)
def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: float): def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: float):
fake_bf16_io = precision == "fake-torch.bfloat16" fake_bf16_io = precision == "fake-torch.bfloat16"
# reset precision to torch.bfloat16 after capturing fake BF16 mode # reset precision to torch.bfloat16 after capturing fake BF16 mode
...@@ -463,17 +479,8 @@ def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: floa ...@@ -463,17 +479,8 @@ def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: floa
self.fake_bf16_io = fake_bf16_io self.fake_bf16_io = fake_bf16_io
def forward(self, inp): def forward(self, inp):
ret = gelu( ret = gelu(inp, self.meta, self.fp8_tensor, self.fp8_type)
inp, ret = cast_from_fp8(ret, self.meta, self.fp8_tensor, self.fp8_type, self.highprec_type)
self.meta,
self.fp8_tensor,
self.fp8_type)
ret = cast_from_fp8(
ret,
self.meta,
self.fp8_tensor,
self.fp8_type,
self.highprec_type)
if self.fake_bf16_io: if self.fake_bf16_io:
ret = ret.type(torch.float32) ret = ret.type(torch.float32)
return ret return ret
...@@ -481,8 +488,9 @@ def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: floa ...@@ -481,8 +488,9 @@ def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: floa
# Set dimensions (these are arbitrary). # Set dimensions (these are arbitrary).
in_features = 64 in_features = 64
hidden_size = 256 hidden_size = 256
inp = torch.randn(hidden_size, in_features, device="cuda", inp = torch.randn(
dtype=torch.float if fake_bf16_io else precision) hidden_size, in_features, device="cuda", dtype=torch.float if fake_bf16_io else precision
)
high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io) high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
fname = f"te.gelu_fp8_{scale_factor}{high_prec_str}.onnx" fname = f"te.gelu_fp8_{scale_factor}{high_prec_str}.onnx"
model = TestFP8_Gelu(fake_bf16_io) model = TestFP8_Gelu(fake_bf16_io)
...@@ -490,39 +498,55 @@ def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: floa ...@@ -490,39 +498,55 @@ def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: floa
te_outputs = te_infer(model, inp, is_fp8=True) te_outputs = te_infer(model, inp, is_fp8=True)
serialize_inputs_outputs(fname, inp, te_outputs) serialize_inputs_outputs(fname, inp, te_outputs)
if fake_bf16_io or precision != torch.bfloat16: if fake_bf16_io or precision != torch.bfloat16:
validate_result(fname, inp, model, rtol=0, atol=atol, is_fp8=True, allow_cnt_errors=2, te_outputs=te_outputs) validate_result(
fname,
inp,
model,
rtol=0,
atol=atol,
is_fp8=True,
allow_cnt_errors=2,
te_outputs=te_outputs,
)
@pytest.mark.parametrize("scale_factors",
[(224, 224,),
])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"precision, use_fp8, use_bias, use_gelu", [ "scale_factors",
(torch.float32, False, False, False), [
(torch.float16, False, False, False), (
(torch.bfloat16, False, False, False), 224,
(torch.float32, False, True, False), 224,
(torch.float16, False, True, False), ),
(torch.bfloat16, False, True, False), ],
(torch.float32, False, True, True), )
(torch.float16, False, True, True), @pytest.mark.parametrize(
(torch.bfloat16, False, True, True), "precision, use_fp8, use_bias, use_gelu",
[
# For FP8 GEMM GeLU is not used. (torch.float32, False, False, False),
(torch.float32, True, False, False), (torch.float16, False, False, False),
(torch.float16, True, False, False), (torch.bfloat16, False, False, False),
(torch.bfloat16, True, False, False), (torch.float32, False, True, False),
# When enabling bias we must use float16 or bfloat16 (because of kernel limitations) (torch.float16, False, True, False),
(torch.float16, True, True, False), (torch.bfloat16, False, True, False),
(torch.bfloat16, True, True, False), (torch.float32, False, True, True),
]) (torch.float16, False, True, True),
(torch.bfloat16, False, True, True),
# For FP8 GEMM GeLU is not used.
(torch.float32, True, False, False),
(torch.float16, True, False, False),
(torch.bfloat16, True, False, False),
# When enabling bias we must use float16 or bfloat16 (because of kernel limitations)
(torch.float16, True, True, False),
(torch.bfloat16, True, True, False),
],
)
def test_export_gemm( def test_export_gemm(
seed_default_rng, seed_default_rng,
precision, # Precision of inputs, weights, output and bias precision, # Precision of inputs, weights, output and bias
use_fp8, use_fp8,
use_bias, use_bias,
use_gelu, use_gelu,
scale_factors scale_factors,
): ):
# Skip FP8 tests on non-hopper devices # Skip FP8 tests on non-hopper devices
if use_fp8 and not fp8_available: if use_fp8 and not fp8_available:
...@@ -548,21 +572,20 @@ def test_export_gemm( ...@@ -548,21 +572,20 @@ def test_export_gemm(
inp, inp,
outp_type, outp_type,
get_workspace(), get_workspace(),
# test bias # test bias
bias=self.bias, bias=self.bias,
use_bias=self.use_bias, use_bias=self.use_bias,
# test gelu # test gelu
gelu=self.gelu, gelu=self.gelu,
gelu_input=self.gelu_input, gelu_input=self.gelu_input,
grad=False, # only True for backward pass grad=False, # only True for backward pass
accumulate=False, accumulate=False,
) )
return ret return ret
# If gelu is applied then bias must be added, as defined by TE kernel. # If gelu is applied then bias must be added, as defined by TE kernel.
if use_gelu: assert use_bias if use_gelu:
assert use_bias
# Set dimensions (these are arbitrary). # Set dimensions (these are arbitrary).
out_features = 128 out_features = 128
hidden_size = 256 hidden_size = 256
...@@ -574,45 +597,64 @@ def test_export_gemm( ...@@ -574,45 +597,64 @@ def test_export_gemm(
gelu_str = "_gelu" if use_gelu else "" gelu_str = "_gelu" if use_gelu else ""
high_prec_str = dtype2str(precision) high_prec_str = dtype2str(precision)
fname = f"te.gemm{fp8_str}{bias_str}{gelu_str}{high_prec_str}.onnx" fname = f"te.gemm{fp8_str}{bias_str}{gelu_str}{high_prec_str}.onnx"
input_names = ['input', 'weight'] input_names = ["input", "weight"]
if use_fp8: if use_fp8:
model = FP8GemmModule(precision, use_bias, use_gelu, scale_factors, hidden_size, out_features) model = FP8GemmModule(
precision, use_bias, use_gelu, scale_factors, hidden_size, out_features
)
do_export(model, (inp, weight), fname, use_fp8, input_names=input_names) do_export(model, (inp, weight), fname, use_fp8, input_names=input_names)
te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8) te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8)
serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names) serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names)
if precision != torch.bfloat16: if precision != torch.bfloat16:
validate_result(fname, (inp, weight), model, rtol=1e-2, atol=2e-2, validate_result(
is_fp8=True, input_names=input_names, te_outputs=te_outputs) fname,
(inp, weight),
model,
rtol=1e-2,
atol=2e-2,
is_fp8=True,
input_names=input_names,
te_outputs=te_outputs,
)
else: else:
model = Test_GEMM(precision, use_bias, use_gelu) model = Test_GEMM(precision, use_bias, use_gelu)
do_export(model, (inp, weight), fname, use_fp8, input_names=input_names) do_export(model, (inp, weight), fname, use_fp8, input_names=input_names)
te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8) te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8)
serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names) serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names)
if precision != torch.bfloat16: if precision != torch.bfloat16:
validate_result(fname, (inp, weight), model, rtol=1e-2, atol=2e-2, validate_result(
input_names=input_names, te_outputs=te_outputs) fname,
(inp, weight),
model,
rtol=1e-2,
atol=2e-2,
input_names=input_names,
te_outputs=te_outputs,
)
@pytest.mark.parametrize("scale_factor", [448, 112]) @pytest.mark.parametrize("scale_factor", [448, 112])
@pytest.mark.parametrize("zero_centered_gamma", [False, True]) @pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"use_fp8, precision, atol", [ "use_fp8, precision, atol",
[False, torch.float32, 1e-7], [
[False, torch.float16, 1e-7], [False, torch.float32, 1e-7],
[False, torch.bfloat16, 1e-7], [False, torch.float16, 1e-7],
[False, "fake-torch.bfloat16", 1e-7], [False, torch.bfloat16, 1e-7],
[True, torch.float32, 1e-7], [False, "fake-torch.bfloat16", 1e-7],
[True, torch.float16, 1e-7], [True, torch.float32, 1e-7],
[True, torch.bfloat16, 1e-2], [True, torch.float16, 1e-7],
[True, "fake-torch.bfloat16", 1e-2] [True, torch.bfloat16, 1e-2],
]) [True, "fake-torch.bfloat16", 1e-2],
],
)
def test_export_layernorm( def test_export_layernorm(
seed_default_rng, seed_default_rng,
use_fp8: bool, use_fp8: bool,
scale_factor: float, scale_factor: float,
precision: torch.dtype, precision: torch.dtype,
zero_centered_gamma: bool, zero_centered_gamma: bool,
atol: float atol: float,
): ):
fake_bf16_io = precision == "fake-torch.bfloat16" fake_bf16_io = precision == "fake-torch.bfloat16"
# reset precision to torch.bfloat16 after capturing fake BF16 mode # reset precision to torch.bfloat16 after capturing fake BF16 mode
...@@ -628,10 +670,15 @@ def test_export_layernorm( ...@@ -628,10 +670,15 @@ def test_export_layernorm(
class Test_Layernorm(nn.Module): class Test_Layernorm(nn.Module):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
eps = 1e-6 # An arbitrary small value eps = 1e-6 # An arbitrary small value
dtype = torch.float if fake_bf16_io else precision dtype = torch.float if fake_bf16_io else precision
self.ln = te.LayerNorm(inp_shape[1], eps, params_dtype=dtype, self.ln = (
zero_centered_gamma=zero_centered_gamma).eval().cuda() te.LayerNorm(
inp_shape[1], eps, params_dtype=dtype, zero_centered_gamma=zero_centered_gamma
)
.eval()
.cuda()
)
def forward(self, inp): def forward(self, inp):
ret = self.ln(inp) ret = self.ln(inp)
...@@ -641,11 +688,13 @@ def test_export_layernorm( ...@@ -641,11 +688,13 @@ def test_export_layernorm(
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
normalized_shape = torch.Size(inp.shape[1:]) normalized_shape = torch.Size(inp.shape[1:])
self.weight = torch.randn(*normalized_shape, device="cuda", self.weight = torch.randn(
dtype=torch.float32 if fake_bf16_io else precision) *normalized_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision
self.bias = torch.zeros(*normalized_shape, device="cuda", )
dtype=torch.float32 if fake_bf16_io else precision) self.bias = torch.zeros(
self.eps = 1e-6 # An arbitrary small value *normalized_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision
)
self.eps = 1e-6 # An arbitrary small value
self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
self.meta = create_meta(scale_factor) self.meta = create_meta(scale_factor)
...@@ -661,14 +710,12 @@ def test_export_layernorm( ...@@ -661,14 +710,12 @@ def test_export_layernorm(
self.fp8_tensor, self.fp8_tensor,
self.fp8_type, self.fp8_type,
0, 0,
zero_centered_gamma) zero_centered_gamma,
)
ret = cast_from_fp8( ret = cast_from_fp8(
ret, ret, self.meta, self.fp8_tensor, self.fp8_type, as_te_type(precision)
self.meta, )
self.fp8_tensor,
self.fp8_type,
as_te_type(precision))
if fake_bf16_io: if fake_bf16_io:
ret = ret.type(torch.float32) ret = ret.type(torch.float32)
return ret return ret
...@@ -683,28 +730,32 @@ def test_export_layernorm( ...@@ -683,28 +730,32 @@ def test_export_layernorm(
serialize_inputs_outputs(fname, inp, te_outputs) serialize_inputs_outputs(fname, inp, te_outputs)
if fake_bf16_io or precision != torch.bfloat16: if fake_bf16_io or precision != torch.bfloat16:
validate_result( validate_result(
fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs) fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs
)
@pytest.mark.parametrize("scale_factor", [448, 112]) @pytest.mark.parametrize("scale_factor", [448, 112])
@pytest.mark.parametrize("zero_centered_gamma", [False, True]) @pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"use_fp8, precision, atol", [ "use_fp8, precision, atol",
[False, torch.float32, 1e-7], [
[False, torch.float16, 1e-7], [False, torch.float32, 1e-7],
[False, torch.bfloat16, 1e-7], [False, torch.float16, 1e-7],
[False, "fake-torch.bfloat16", 1e-7], [False, torch.bfloat16, 1e-7],
[True, torch.float32, 1e-7], [False, "fake-torch.bfloat16", 1e-7],
[True, torch.float16, 1e-7], [True, torch.float32, 1e-7],
[True, torch.bfloat16, 1e-2], [True, torch.float16, 1e-7],
[True, "fake-torch.bfloat16", 1e-2] [True, torch.bfloat16, 1e-2],
]) [True, "fake-torch.bfloat16", 1e-2],
],
)
def test_export_rmsnorm( def test_export_rmsnorm(
seed_default_rng, seed_default_rng,
use_fp8: bool, use_fp8: bool,
scale_factor: float, scale_factor: float,
precision: torch.dtype, precision: torch.dtype,
zero_centered_gamma: bool, zero_centered_gamma: bool,
atol: float atol: float,
): ):
fake_bf16_io = precision == "fake-torch.bfloat16" fake_bf16_io = precision == "fake-torch.bfloat16"
# reset precision to torch.bfloat16 after capturing fake BF16 mode # reset precision to torch.bfloat16 after capturing fake BF16 mode
...@@ -720,10 +771,15 @@ def test_export_rmsnorm( ...@@ -720,10 +771,15 @@ def test_export_rmsnorm(
class Test_RMSnorm(nn.Module): class Test_RMSnorm(nn.Module):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
eps = 1e-6 # An arbitrary small value eps = 1e-6 # An arbitrary small value
dtype = torch.float if fake_bf16_io else precision dtype = torch.float if fake_bf16_io else precision
self.ln = te.RMSNorm(inp_shape[1], eps, params_dtype=dtype, self.ln = (
zero_centered_gamma=zero_centered_gamma).eval().cuda() te.RMSNorm(
inp_shape[1], eps, params_dtype=dtype, zero_centered_gamma=zero_centered_gamma
)
.eval()
.cuda()
)
def forward(self, inp): def forward(self, inp):
ret = self.ln(inp) ret = self.ln(inp)
...@@ -733,9 +789,10 @@ def test_export_rmsnorm( ...@@ -733,9 +789,10 @@ def test_export_rmsnorm(
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
normalized_shape = torch.Size(inp.shape[1:]) normalized_shape = torch.Size(inp.shape[1:])
self.weight = torch.randn(*normalized_shape, device="cuda", self.weight = torch.randn(
dtype=torch.float32 if fake_bf16_io else precision) *normalized_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision
self.eps = 1e-6 # An arbitrary small value )
self.eps = 1e-6 # An arbitrary small value
self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
self.meta = create_meta(scale_factor) self.meta = create_meta(scale_factor)
...@@ -750,14 +807,12 @@ def test_export_rmsnorm( ...@@ -750,14 +807,12 @@ def test_export_rmsnorm(
self.fp8_tensor, self.fp8_tensor,
self.fp8_type, self.fp8_type,
0, 0,
zero_centered_gamma) zero_centered_gamma,
)
ret = cast_from_fp8( ret = cast_from_fp8(
ret, ret, self.meta, self.fp8_tensor, self.fp8_type, as_te_type(precision)
self.meta, )
self.fp8_tensor,
self.fp8_type,
as_te_type(precision))
if fake_bf16_io: if fake_bf16_io:
ret = ret.type(torch.float32) ret = ret.type(torch.float32)
return ret return ret
...@@ -772,7 +827,8 @@ def test_export_rmsnorm( ...@@ -772,7 +827,8 @@ def test_export_rmsnorm(
serialize_inputs_outputs(fname, inp, te_outputs) serialize_inputs_outputs(fname, inp, te_outputs)
if fake_bf16_io or precision != torch.bfloat16: if fake_bf16_io or precision != torch.bfloat16:
validate_result( validate_result(
fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs) fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs
)
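Both export tests above parametrize zero_centered_gamma; when it is enabled the learned weight is stored as an offset from one. A minimal reference sketch of that convention (assumed to mirror what the TE normalization layers compute, up to epsilon placement):

import torch

def rmsnorm_ref(x, gamma, eps=1e-6, zero_centered_gamma=True):
    # Root-mean-square normalization; with zero-centered gamma the effective
    # scale is (1 + gamma), so a freshly zero-initialized gamma acts like ones.
    scale = gamma + 1.0 if zero_centered_gamma else gamma
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * scale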
@pytest.mark.parametrize("scale_factor", [1]) @pytest.mark.parametrize("scale_factor", [1])
...@@ -780,23 +836,25 @@ def test_export_rmsnorm( ...@@ -780,23 +836,25 @@ def test_export_rmsnorm(
# Returning the bias is a TE fusion optimization we don't care about. # Returning the bias is a TE fusion optimization we don't care about.
@pytest.mark.parametrize("return_bias", [False]) @pytest.mark.parametrize("return_bias", [False])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"precision, use_bias",[ "precision, use_bias",
(torch.float32, False), [
(torch.float32, True), (torch.float32, False),
(torch.float16, False), (torch.float32, True),
(torch.float16, True), (torch.float16, False),
# Todo: cannot configure BF16 when bias is disabled (ORT issue?) (torch.float16, True),
(torch.bfloat16, False), # Todo: cannot configure BF16 when bias is disabled (ORT issue?)
# Todo: cannot configure BF16 when bias is enabled (ORT issue?) (torch.bfloat16, False),
(torch.bfloat16, True), # Todo: cannot configure BF16 when bias is enabled (ORT issue?)
]) (torch.bfloat16, True),
],
)
def test_export_linear( def test_export_linear(
seed_default_rng, seed_default_rng,
scale_factor: float, scale_factor: float,
use_fp8: bool, use_fp8: bool,
use_bias: bool, use_bias: bool,
return_bias: bool, return_bias: bool,
precision: torch.dtype precision: torch.dtype,
): ):
# Skip FP8 tests on non-hopper devices # Skip FP8 tests on non-hopper devices
if use_fp8 and not fp8_available: if use_fp8 and not fp8_available:
...@@ -808,20 +866,14 @@ def test_export_linear( ...@@ -808,20 +866,14 @@ def test_export_linear(
hidden_size = 256 hidden_size = 256
class Test_Linear(nn.Module): class Test_Linear(nn.Module):
def __init__(self, def __init__(self, in_features, out_features, use_bias, return_bias, precision):
in_features,
out_features,
use_bias,
return_bias,
precision
):
super().__init__() super().__init__()
self.linear = te.Linear( self.linear = te.Linear(
in_features, in_features,
out_features, out_features,
bias=use_bias, bias=use_bias,
return_bias=return_bias, return_bias=return_bias,
params_dtype=precision params_dtype=precision,
) )
def forward(self, inp): def forward(self, inp):
...@@ -834,20 +886,16 @@ def test_export_linear( ...@@ -834,20 +886,16 @@ def test_export_linear(
high_prec_str = dtype2str(precision) high_prec_str = dtype2str(precision)
fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx" fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx"
with te.fp8_autocast(enabled=use_fp8): with te.fp8_autocast(enabled=use_fp8):
model = Test_Linear( model = Test_Linear(in_features, out_features, use_bias, return_bias, precision).to(
in_features, device="cuda"
out_features, )
use_bias,
return_bias,
precision
).to(device='cuda')
if use_fp8: if use_fp8:
set_layer_scale(model.linear, scale_factor, num_gemms=1) set_layer_scale(model.linear, scale_factor, num_gemms=1)
do_export(model, inp, fname, use_fp8) do_export(model, inp, fname, use_fp8)
te_outputs = te_infer(model, inp, is_fp8=use_fp8) te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs) serialize_inputs_outputs(fname, inp, te_outputs)
if precision in (torch.bfloat16, ): if precision in (torch.bfloat16,):
return return
if not use_fp8: if not use_fp8:
validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs) validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs)
...@@ -861,14 +909,16 @@ def test_export_linear( ...@@ -861,14 +909,16 @@ def test_export_linear(
@pytest.mark.parametrize("return_bias", [False]) @pytest.mark.parametrize("return_bias", [False])
@pytest.mark.parametrize("return_layernorm_output", [False]) @pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"precision, use_bias",[ "precision, use_bias",
(torch.float32, False), [
(torch.float32, True), (torch.float32, False),
(torch.float16, True), (torch.float32, True),
(torch.float16, False), (torch.float16, True),
(torch.bfloat16, True), (torch.float16, False),
(torch.bfloat16, False), (torch.bfloat16, True),
]) (torch.bfloat16, False),
],
)
@pytest.mark.parametrize("zero_centered_gamma", [False, True]) @pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("normalization", all_normalizations)
def test_export_layernorm_linear( def test_export_layernorm_linear(
...@@ -907,13 +957,13 @@ def test_export_layernorm_linear( ...@@ -907,13 +957,13 @@ def test_export_layernorm_linear(
params_dtype=precision, params_dtype=precision,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
normalization=normalization, normalization=normalization,
).to(device='cuda') ).to(device="cuda")
if use_fp8: if use_fp8:
set_layer_scale(model, scale_factor, num_gemms=1) set_layer_scale(model, scale_factor, num_gemms=1)
do_export(model, inp, fname, use_fp8) do_export(model, inp, fname, use_fp8)
te_outputs = te_infer(model, inp, is_fp8=use_fp8) te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs) serialize_inputs_outputs(fname, inp, te_outputs)
if precision in (torch.bfloat16, ): if precision in (torch.bfloat16,):
return return
if not use_fp8: if not use_fp8:
validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs) validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs)
...@@ -927,14 +977,16 @@ def test_export_layernorm_linear( ...@@ -927,14 +977,16 @@ def test_export_layernorm_linear(
@pytest.mark.parametrize("return_bias", [False]) @pytest.mark.parametrize("return_bias", [False])
@pytest.mark.parametrize("return_layernorm_output", [False]) @pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"precision, use_bias",[ "precision, use_bias",
(torch.float32, False), [
(torch.float32, True), (torch.float32, False),
(torch.float16, True), (torch.float32, True),
(torch.float16, False), (torch.float16, True),
(torch.bfloat16, True), (torch.float16, False),
(torch.bfloat16, False), (torch.bfloat16, True),
]) (torch.bfloat16, False),
],
)
@pytest.mark.parametrize("zero_centered_gamma", [False, True]) @pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("activation", supported_activations) @pytest.mark.parametrize("activation", supported_activations)
@pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("normalization", all_normalizations)
...@@ -954,7 +1006,6 @@ def test_export_layernorm_mlp( ...@@ -954,7 +1006,6 @@ def test_export_layernorm_mlp(
if use_fp8 and not fp8_available: if use_fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
# Set dimensions (these are arbitrary). # Set dimensions (these are arbitrary).
in_features = 64 in_features = 64
out_features = 256 out_features = 256
...@@ -977,30 +1028,32 @@ def test_export_layernorm_mlp( ...@@ -977,30 +1028,32 @@ def test_export_layernorm_mlp(
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
activation=activation, activation=activation,
normalization=normalization, normalization=normalization,
).to(device='cuda') ).to(device="cuda")
if use_fp8: if use_fp8:
set_layer_scale(model, scale_factor, num_gemms=2) set_layer_scale(model, scale_factor, num_gemms=2)
do_export(model, inp, fname, use_fp8) do_export(model, inp, fname, use_fp8)
te_outputs = te_infer(model, inp, is_fp8=use_fp8) te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs) serialize_inputs_outputs(fname, inp, te_outputs)
if precision in (torch.bfloat16, ): if precision in (torch.bfloat16,):
return return
atol = 1e-6 if use_fp8 else (5e-1 if activation=="swiglu" else 1e-3) atol = 1e-6 if use_fp8 else (5e-1 if activation == "swiglu" else 1e-3)
validate_result(fname, inp, model, atol=atol, is_fp8=use_fp8, te_outputs=te_outputs) validate_result(fname, inp, model, atol=atol, is_fp8=use_fp8, te_outputs=te_outputs)
@skip_FP8 @skip_FP8
@pytest.mark.parametrize( @pytest.mark.parametrize(
"precision, use_mask, attn_mask_type", [ "precision, use_mask, attn_mask_type",
(torch.float32, True, "arbitrary"), # calls forward_torch_softmax (apply user mask) [
(torch.float32, False, "no_mask"), # calls forward_torch_softmax (apply no mask) (torch.float32, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
(torch.float16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask) (torch.float32, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
(torch.float16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask) (torch.float16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
(torch.float16, False, "no_mask"), # calls forward_torch_softmax (apply no mask) (torch.float16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
(torch.bfloat16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask) (torch.float16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
(torch.bfloat16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask) (torch.bfloat16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
(torch.bfloat16, False, "no_mask"), # calls forward_torch_softmax (apply no mask) (torch.bfloat16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
]) (torch.bfloat16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
],
)
def test_export_core_attention( def test_export_core_attention(
seed_default_rng, seed_default_rng,
set_max_seq_len, set_max_seq_len,
...@@ -1034,40 +1087,42 @@ def test_export_core_attention( ...@@ -1034,40 +1087,42 @@ def test_export_core_attention(
attention_dropout=0.5, attention_dropout=0.5,
qkv_format=qkv_format, qkv_format=qkv_format,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
).to(device='cuda') ).to(device="cuda")
do_export(model, do_export(model, inp, fname, input_names=input_names, use_fp8=True)
inp,
fname,
input_names=input_names,
use_fp8=True)
te_outputs = te_infer(model, inp, is_fp8=True) te_outputs = te_infer(model, inp, is_fp8=True)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
if precision in (torch.bfloat16, ): if precision in (torch.bfloat16,):
return return
validate_result(fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs) validate_result(
fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs
)
test_configs_multihead_attention = [ test_configs_multihead_attention = [
#"use_mask, attn_mask_type" # "use_mask, attn_mask_type"
(False, "no_mask"), # calls ScaledSoftmax (False, "no_mask"), # calls ScaledSoftmax
(True, "arbitrary"), # calls ScaledMaskedSoftmax (True, "arbitrary"), # calls ScaledMaskedSoftmax
] ]
test_configs_attention_type = [ test_configs_attention_type = [
#"input_layernorm, attention_type, fuse_qkv_params" # "input_layernorm, attention_type, fuse_qkv_params"
(True, "self", True), (True, "self", True),
(False, "self", True), (False, "self", True),
(True, "self", False), (True, "self", False),
(False, "self", False), (False, "self", False),
(True, "cross", True), (True, "cross", True),
(False, "cross", True), (False, "cross", True),
(True, "cross", False), (True, "cross", False),
(False, "cross", False), (False, "cross", False),
] ]
@pytest.mark.parametrize("use_fp8", [False, True]) @pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention) @pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("return_layernorm_output", [False]) @pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize("input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type) @pytest.mark.parametrize(
"input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type
)
def test_export_multihead_attention( def test_export_multihead_attention(
seed_default_rng, seed_default_rng,
set_max_seq_len, set_max_seq_len,
...@@ -1078,7 +1133,7 @@ def test_export_multihead_attention( ...@@ -1078,7 +1133,7 @@ def test_export_multihead_attention(
return_layernorm_output: bool, return_layernorm_output: bool,
input_layernorm: bool, input_layernorm: bool,
attention_type: str, attention_type: str,
fuse_qkv_params: bool fuse_qkv_params: bool,
): ):
# Skip FP8 tests on non-hopper devices # Skip FP8 tests on non-hopper devices
if use_fp8 and not fp8_available: if use_fp8 and not fp8_available:
...@@ -1102,17 +1157,23 @@ def test_export_multihead_attention( ...@@ -1102,17 +1157,23 @@ def test_export_multihead_attention(
output_layer_init_method, output_layer_init_method,
) )
hidden_states_context = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda") hidden_states_context = torch.randn(
sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
)
attention_mask = None attention_mask = None
if use_mask and attn_mask_type != "causal": if use_mask and attn_mask_type != "causal":
# Generate a random mask with 50% probability for 0 or 1. # Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision) probs = 0.5 * torch.ones(
batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision
)
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
encoder_output = None encoder_output = None
if attention_type == "cross": if attention_type == "cross":
encoder_output = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda") encoder_output = torch.randn(
sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
)
fp8_str = "_fp8" if use_fp8 else "" fp8_str = "_fp8" if use_fp8 else ""
dtype_str = dtype2str(precision) dtype_str = dtype2str(precision)
...@@ -1131,49 +1192,98 @@ def test_export_multihead_attention( ...@@ -1131,49 +1192,98 @@ def test_export_multihead_attention(
attention_type=attention_type, attention_type=attention_type,
fuse_qkv_params=fuse_qkv_params, fuse_qkv_params=fuse_qkv_params,
return_bias=True, return_bias=True,
).to(device='cuda') ).to(device="cuda")
inp_context = (hidden_states_context, attention_mask, encoder_output) inp_context = (hidden_states_context, attention_mask, encoder_output)
input_names = ["hidden_states", "attention_mask", "encoder_output"] input_names = ["hidden_states", "attention_mask", "encoder_output"]
output_names=["attention_output", "attention_bias"] output_names = ["attention_output", "attention_bias"]
do_export(model, inp_context, fname, use_fp8, input_names=input_names, output_names=output_names, do_export(
dynamic_axes={"hidden_states": {0: "seq", 1:"bs"}, model,
"attention_output": {0: "seq", 1:"bs"}}) inp_context,
fname,
use_fp8,
input_names=input_names,
output_names=output_names,
dynamic_axes={
"hidden_states": {0: "seq", 1: "bs"},
"attention_output": {0: "seq", 1: "bs"},
},
)
te_outputs = te_infer(model, inp_context, is_fp8=use_fp8) te_outputs = te_infer(model, inp_context, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp_context, te_outputs, input_names=input_names, output_names=output_names) serialize_inputs_outputs(
if precision in (torch.bfloat16, ): fname, inp_context, te_outputs, input_names=input_names, output_names=output_names
)
if precision in (torch.bfloat16,):
return return
if not use_fp8: if not use_fp8:
validate_result(fname, inp_context, model, atol=1e-3, input_names=input_names, validate_result(
output_names=output_names, te_outputs=te_outputs) fname,
inp_context,
model,
atol=1e-3,
input_names=input_names,
output_names=output_names,
te_outputs=te_outputs,
)
else: else:
validate_result(fname, inp_context, model, atol=1e-2, is_fp8=use_fp8, validate_result(
input_names=input_names, output_names=output_names, allow_cnt_errors=3, fname,
te_outputs=te_outputs) inp_context,
model,
atol=1e-2,
is_fp8=use_fp8,
input_names=input_names,
output_names=output_names,
allow_cnt_errors=3,
te_outputs=te_outputs,
)
# In GPT generative phase (inference) the input sequence is smaller than the maximum # In GPT generative phase (inference) the input sequence is smaller than the maximum
# allowed sequence length and we want to test this condition. # allowed sequence length and we want to test this condition.
# Pretend that we're in generative phase when it makes sense (causal mask and self-attention). # Pretend that we're in generative phase when it makes sense (causal mask and self-attention).
is_generative_phase = (attn_mask_type == "causal" and attention_type == "self") is_generative_phase = attn_mask_type == "causal" and attention_type == "self"
if is_generative_phase: if is_generative_phase:
seq_len_offset = 8 seq_len_offset = 8
hidden_states_generative = torch.randn(sequence_length-seq_len_offset, batch_size, hidden_size, dtype=precision, device="cuda") hidden_states_generative = torch.randn(
sequence_length - seq_len_offset,
batch_size,
hidden_size,
dtype=precision,
device="cuda",
)
inp_generative = (hidden_states_generative, attention_mask, encoder_output) inp_generative = (hidden_states_generative, attention_mask, encoder_output)
if not use_fp8: if not use_fp8:
validate_result(fname, inp_generative, model, atol=1e-3, input_names=input_names, output_names=output_names) validate_result(
fname,
inp_generative,
model,
atol=1e-3,
input_names=input_names,
output_names=output_names,
)
else: else:
validate_result(fname, inp_generative, model, atol=1e-2, is_fp8=use_fp8, validate_result(
input_names=input_names, output_names=output_names, allow_cnt_errors=3) fname,
inp_generative,
model,
atol=1e-2,
is_fp8=use_fp8,
input_names=input_names,
output_names=output_names,
allow_cnt_errors=3,
)
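The generative-phase check reuses the exported graph with a shorter input, which is possible only because do_export declared "seq" and "bs" as dynamic axes. A minimal, self-contained sketch of that pattern, using a toy stand-in module and a hypothetical file name (not the TE test harness):

```python
import torch
import torch.nn as nn
import onnxruntime as ort

# Toy stand-in for the exported attention module (hypothetical, for illustration only).
class TinyProj(nn.Module):
    def forward(self, hidden_states):
        return hidden_states * 2.0

seq, bs, hidden = 32, 4, 64
model = TinyProj().eval()
torch.onnx.export(
    model,
    (torch.randn(seq, bs, hidden),),
    "tiny_proj.onnx",
    input_names=["hidden_states"],
    output_names=["attention_output"],
    dynamic_axes={
        "hidden_states": {0: "seq", 1: "bs"},
        "attention_output": {0: "seq", 1: "bs"},
    },
)

sess = ort.InferenceSession("tiny_proj.onnx", providers=["CPUExecutionProvider"])
# "Generative phase": run the same graph on a shorter sequence (seq - 8).
short = torch.randn(seq - 8, bs, hidden).numpy()
(out,) = sess.run(None, {"hidden_states": short})
assert out.shape == (seq - 8, bs, hidden)
```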
@pytest.mark.parametrize("use_fp8", [False, True]) @pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention) @pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("output_layernorm", [ @pytest.mark.parametrize(
#True, # TO DO: handle this "output_layernorm",
False [
]) # True, # TO DO: handle this
False
],
)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("fuse_qkv_params", [False, True]) @pytest.mark.parametrize("fuse_qkv_params", [False, True])
@pytest.mark.parametrize("zero_centered_gamma", [False, True]) @pytest.mark.parametrize("zero_centered_gamma", [False, True])
...@@ -1201,12 +1311,16 @@ def test_export_transformer_layer( ...@@ -1201,12 +1311,16 @@ def test_export_transformer_layer(
ffn_hidden_size = 256 ffn_hidden_size = 256
num_attention_heads = 4 num_attention_heads = 4
input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda") input_tensor = torch.rand(
sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
)
input_names = ["input", "attention_mask"] input_names = ["input", "attention_mask"]
attention_mask = None attention_mask = None
if use_mask and attn_mask_type != "causal": if use_mask and attn_mask_type != "causal":
# Generate a random mask with 50% probability for 0 or 1. # Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision) probs = 0.5 * torch.ones(
batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision
)
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
inp = (input_tensor, attention_mask) inp = (input_tensor, attention_mask)
...@@ -1225,19 +1339,30 @@ def test_export_transformer_layer( ...@@ -1225,19 +1339,30 @@ def test_export_transformer_layer(
params_dtype=precision, params_dtype=precision,
fuse_qkv_params=fuse_qkv_params, fuse_qkv_params=fuse_qkv_params,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
activation=activation).to(device='cuda') activation=activation,
).to(device="cuda")
do_export(model, inp, fname, use_fp8, input_names=input_names) do_export(model, inp, fname, use_fp8, input_names=input_names)
te_outputs = te_infer(model, inp, is_fp8=use_fp8) te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
if precision in (torch.bfloat16, ): if precision in (torch.bfloat16,):
return return
atol = 5e-1 if use_fp8 else (5e-1 if activation=="swiglu" else 1e-3) atol = 5e-1 if use_fp8 else (5e-1 if activation == "swiglu" else 1e-3)
validate_result(fname, inp, model, atol=atol, is_fp8=use_fp8, input_names=input_names, te_outputs=te_outputs) validate_result(
fname, inp, model, atol=atol, is_fp8=use_fp8, input_names=input_names, te_outputs=te_outputs
)
@pytest.mark.parametrize("use_fp8", [True]) @pytest.mark.parametrize("use_fp8", [True])
@pytest.mark.parametrize("ln_scale_factor", [448*2]) @pytest.mark.parametrize("ln_scale_factor", [448 * 2])
@pytest.mark.parametrize("gemm_scale_factors", [(224, 224,),]) @pytest.mark.parametrize(
"gemm_scale_factors",
[
(
224,
224,
),
],
)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("zero_centered_gamma", [False, True]) @pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_gemm_layernorm( def test_export_gemm_layernorm(
...@@ -1246,7 +1371,7 @@ def test_export_gemm_layernorm( ...@@ -1246,7 +1371,7 @@ def test_export_gemm_layernorm(
ln_scale_factor: float, ln_scale_factor: float,
gemm_scale_factors: Tuple[float, float], gemm_scale_factors: Tuple[float, float],
precision: torch.dtype, precision: torch.dtype,
zero_centered_gamma: bool zero_centered_gamma: bool,
): ):
"""This is a regression test for testing that all LN inputs have the same type. """This is a regression test for testing that all LN inputs have the same type.
...@@ -1260,20 +1385,26 @@ def test_export_gemm_layernorm( ...@@ -1260,20 +1385,26 @@ def test_export_gemm_layernorm(
# Skip FP8 tests on non-hopper devices # Skip FP8 tests on non-hopper devices
if use_fp8 and not fp8_available: if use_fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
class TestFP8_GemmLayernorm(nn.Module): class TestFP8_GemmLayernorm(nn.Module):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
normalized_shape = torch.Size(inp.shape[1:]) normalized_shape = torch.Size(inp.shape[1:])
self.weight = torch.randn(*normalized_shape, dtype=precision, device="cuda") self.weight = torch.randn(*normalized_shape, dtype=precision, device="cuda")
self.bias = torch.zeros(*normalized_shape, dtype=precision, device="cuda") self.bias = torch.zeros(*normalized_shape, dtype=precision, device="cuda")
self.eps = 1e-6 # An arbitrary small value self.eps = 1e-6 # An arbitrary small value
self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
self.meta = create_meta(ln_scale_factor) self.meta = create_meta(ln_scale_factor)
self.fp8_type = tex.DType.kFloat8E4M3 self.fp8_type = tex.DType.kFloat8E4M3
self.gemm = FP8GemmModule( self.gemm = FP8GemmModule(
precision, use_bias=False, gelu=False, scale_factors=gemm_scale_factors, precision,
hidden_size=hidden_size, out_features=out_features) use_bias=False,
gelu=False,
scale_factors=gemm_scale_factors,
hidden_size=hidden_size,
out_features=out_features,
)
def forward(self, inp, weight): def forward(self, inp, weight):
x = self.gemm(inp, weight) x = self.gemm(inp, weight)
...@@ -1286,14 +1417,16 @@ def test_export_gemm_layernorm( ...@@ -1286,14 +1417,16 @@ def test_export_gemm_layernorm(
self.fp8_tensor, self.fp8_tensor,
self.fp8_type, self.fp8_type,
0, 0,
zero_centered_gamma) zero_centered_gamma,
)
x = cast_from_fp8( x = cast_from_fp8(
x, x,
self.meta, self.meta,
self.fp8_tensor, self.fp8_tensor,
self.fp8_type, self.fp8_type,
tex.DType.kFloat32 if precision == torch.float32 else tex.DType.kFloat16) tex.DType.kFloat32 if precision == torch.float32 else tex.DType.kFloat16,
)
return x return x
inp = torch.randn(hidden_size, in_features, dtype=precision, device="cuda") inp = torch.randn(hidden_size, in_features, dtype=precision, device="cuda")
...@@ -1302,14 +1435,21 @@ def test_export_gemm_layernorm( ...@@ -1302,14 +1435,21 @@ def test_export_gemm_layernorm(
high_prec_str = dtype2str(precision) high_prec_str = dtype2str(precision)
fp8_str = f"_fp8" if use_fp8 else "" fp8_str = f"_fp8" if use_fp8 else ""
fname = f"te.gemm_layernorm{fp8_str}{high_prec_str}.onnx" fname = f"te.gemm_layernorm{fp8_str}{high_prec_str}.onnx"
input_names = ['input', 'weight'] input_names = ["input", "weight"]
do_export(model, (inp, weight), fname, use_fp8=use_fp8, input_names=input_names) do_export(model, (inp, weight), fname, use_fp8=use_fp8, input_names=input_names)
te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8) te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8)
serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names) serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names)
if precision not in (torch.bfloat16, ): if precision not in (torch.bfloat16,):
validate_result( validate_result(
fname, (inp, weight), model, atol=5e-2, is_fp8=use_fp8, allow_cnt_errors=2, fname,
input_names=input_names, te_outputs=te_outputs) (inp, weight),
model,
atol=5e-2,
is_fp8=use_fp8,
allow_cnt_errors=2,
input_names=input_names,
te_outputs=te_outputs,
)
@skip_FP8 @skip_FP8
...@@ -1357,32 +1497,61 @@ def test_export_gpt_generation( ...@@ -1357,32 +1497,61 @@ def test_export_gpt_generation(
output_layernorm=output_layernorm, output_layernorm=output_layernorm,
params_dtype=precision, params_dtype=precision,
fuse_qkv_params=fuse_qkv_params, fuse_qkv_params=fuse_qkv_params,
zero_centered_gamma=zero_centered_gamma).to(device='cuda') zero_centered_gamma=zero_centered_gamma,
).to(device="cuda")
# "Context phase": use full input sequence length # "Context phase": use full input sequence length
input_names = ["input"] input_names = ["input"]
output_names = ["output"] output_names = ["output"]
input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda") input_tensor = torch.rand(
sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
)
inp = (input_tensor,) inp = (input_tensor,)
do_export(model, inp, fname, use_fp8, do_export(
input_names=input_names, output_names=output_names, model,
dynamic_axes={"input": {0: "seq", 1:"bs"}, inp,
"output": {0: "seq", 1:"bs"}, }) fname,
use_fp8,
input_names=input_names,
output_names=output_names,
dynamic_axes={
"input": {0: "seq", 1: "bs"},
"output": {0: "seq", 1: "bs"},
},
)
te_outputs = te_infer(model, inp, is_fp8=use_fp8) te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names, output_names=output_names) serialize_inputs_outputs(
if precision not in (torch.bfloat16, ): fname, inp, te_outputs, input_names=input_names, output_names=output_names
validate_result(fname, inp, model, atol=6e-3, is_fp8=use_fp8, input_names=input_names, )
te_outputs=te_outputs) if precision not in (torch.bfloat16,):
validate_result(
fname,
inp,
model,
atol=6e-3,
is_fp8=use_fp8,
input_names=input_names,
te_outputs=te_outputs,
)
# "Generative phase": use a single input (sequence len=1). For FP8 we need to pad the sequence to mult of 8. # "Generative phase": use a single input (sequence len=1). For FP8 we need to pad the sequence to mult of 8.
sequence_length = 1 if not use_fp8 else 8 sequence_length = 1 if not use_fp8 else 8
input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda") input_tensor = torch.rand(
sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
)
inp = (input_tensor, attention_mask) inp = (input_tensor, attention_mask)
te_outputs = te_infer(model, inp, is_fp8=use_fp8) te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
if precision not in (torch.bfloat16, ): if precision not in (torch.bfloat16,):
validate_result(fname, inp, model, atol=6e-3, is_fp8=use_fp8, input_names=input_names, validate_result(
te_outputs=te_outputs) fname,
inp,
model,
atol=6e-3,
is_fp8=use_fp8,
input_names=input_names,
te_outputs=te_outputs,
)
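The comment above notes that the FP8 path needs the sequence length to be a multiple of 8; the test sidesteps this by using 8 directly instead of 1. A minimal sketch of explicit padding, assuming plain zero-padding along the sequence dimension of a [seq, batch, hidden] tensor:

```python
import torch
import torch.nn.functional as F

def pad_seq_to_multiple(x, multiple=8):
    """Zero-pad dim 0 (sequence) of a [seq, batch, hidden] tensor up to a multiple."""
    pad = (-x.size(0)) % multiple  # e.g. seq=1 -> pad=7, seq=8 -> pad=0
    if pad == 0:
        return x
    # F.pad orders pads from the last dimension backwards:
    # (hidden_left, hidden_right, batch_left, batch_right, seq_left, seq_right)
    return F.pad(x, (0, 0, 0, 0, 0, pad))

x = torch.randn(1, 4, 64)
assert pad_seq_to_multiple(x).shape == (8, 4, 64)
```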
@pytest.mark.parametrize("enabled", [True, False]) @pytest.mark.parametrize("enabled", [True, False])

...@@ -19,6 +19,7 @@ from transformer_engine.pytorch.fp8 import ( ...@@ -19,6 +19,7 @@ from transformer_engine.pytorch.fp8 import (
# Check if FP8 is supported # Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFP8Recipe: class TestFP8Recipe:
...@@ -95,8 +96,8 @@ class TestFP8Recipe: ...@@ -95,8 +96,8 @@ class TestFP8Recipe:
ref_amax_backward = amax_history_backward[-1] ref_amax_backward = amax_history_backward[-1]
else: else:
raise ValueError(f"{amax_compute_algo=} is not supported") raise ValueError(f"{amax_compute_algo=} is not supported")
ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2 ** margin) ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2**margin)
ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2 ** margin) ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2**margin)
ref_scale_inv_forward = torch.reciprocal(ref_scale_forward) ref_scale_inv_forward = torch.reciprocal(ref_scale_forward)
update_weight_amax = is_first_microbatch is None or is_first_microbatch update_weight_amax = is_first_microbatch is None or is_first_microbatch
if not update_weight_amax: if not update_weight_amax:
...@@ -128,8 +129,8 @@ class TestFP8Recipe: ...@@ -128,8 +129,8 @@ class TestFP8Recipe:
ref_amax_backward = amax_history_backward[-1] ref_amax_backward = amax_history_backward[-1]
else: else:
raise ValueError(f"{amax_compute_algo=} is not supported") raise ValueError(f"{amax_compute_algo=} is not supported")
ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2 ** margin) ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2**margin)
ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2 ** margin) ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2**margin)
ref_scale_inv_forward = torch.reciprocal(ref_scale_forward) ref_scale_inv_forward = torch.reciprocal(ref_scale_forward)
ref_scale_inv_backward = torch.reciprocal(ref_scale_backward) ref_scale_inv_backward = torch.reciprocal(ref_scale_backward)
...@@ -180,8 +181,9 @@ class TestFP8Recipe: ...@@ -180,8 +181,9 @@ class TestFP8Recipe:
scaling_factor_compute_algo = None scaling_factor_compute_algo = None
if fused_update: if fused_update:
scaling_factor_compute_algo = ( scaling_factor_compute_algo = (
lambda amax, scale, fp8_max, recipe: lambda amax, scale, fp8_max, recipe: te.fp8._default_sf_compute(
te.fp8._default_sf_compute(amax, scale, fp8_max, recipe.margin) amax, scale, fp8_max, recipe.margin
)
) )
recipe = transformer_engine.common.recipe.DelayedScaling( recipe = transformer_engine.common.recipe.DelayedScaling(
fp8_format=fp8_format, scaling_factor_compute_algo=scaling_factor_compute_algo fp8_format=fp8_format, scaling_factor_compute_algo=scaling_factor_compute_algo
...@@ -205,7 +207,9 @@ class TestFP8Recipe: ...@@ -205,7 +207,9 @@ class TestFP8Recipe:
# test different scenarios # test different scenarios
if amax_case == "zero": if amax_case == "zero":
fp8_meta[forward_key].amax_history = torch.tensor([[0]], dtype=torch.float32, device="cuda") fp8_meta[forward_key].amax_history = torch.tensor(
[[0]], dtype=torch.float32, device="cuda"
)
expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda") expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
elif amax_case == "tiny": elif amax_case == "tiny":
# calculate the minimum amax value that results in a FP32 maximum scale # calculate the minimum amax value that results in a FP32 maximum scale
...@@ -254,4 +258,6 @@ class TestFP8Recipe: ...@@ -254,4 +258,6 @@ class TestFP8Recipe:
) )
torch.testing.assert_close(fp8_meta[forward_key].scale, expected_scale) torch.testing.assert_close(fp8_meta[forward_key].scale, expected_scale)
torch.testing.assert_close(fp8_meta[forward_key].scale_inv, torch.reciprocal(expected_scale)) torch.testing.assert_close(
fp8_meta[forward_key].scale_inv, torch.reciprocal(expected_scale)
)
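The assertions above encode the delayed-scaling update: the reference scale is (fp8_max / amax) / 2**margin, scale_inv is its reciprocal, amax is reduced from the history either by "max" or by taking the most recent entry, and an all-zero amax history is expected to leave the scale at 1.0. A small reference sketch of that computation (fp8_max=448.0 below is just the E4M3 example value used elsewhere in these tests):

```python
import torch

def ref_delayed_scaling(amax_history, fp8_max, margin=0, amax_compute_algo="max"):
    """Reference scale/scale_inv computation mirroring the assertions above.

    amax_history is a 1D tensor of recent amax values, with the most recent entry
    last (as indexed by [-1] in the "most_recent" branch of the test).
    """
    if amax_compute_algo == "max":
        amax = amax_history.max()
    elif amax_compute_algo == "most_recent":
        amax = amax_history[-1]
    else:
        raise ValueError(f"{amax_compute_algo=} is not supported")
    scale = (fp8_max / amax) / (2**margin)
    # A zero amax would give an infinite scale; the "zero" amax case above expects
    # the scale to stay at its default of 1.0 instead.
    if amax == 0 or not torch.isfinite(scale):
        scale = torch.tensor(1.0)
    return scale, torch.reciprocal(scale)

# Example: E4M3 forward max of 448, margin 0.
scale, scale_inv = ref_delayed_scaling(torch.tensor([2.0, 4.0, 1.0]), fp8_max=448.0)
assert scale.item() == 112.0 and abs(scale_inv.item() - 1 / 112.0) < 1e-6
```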
...@@ -30,7 +30,13 @@ from transformer_engine.pytorch import ( ...@@ -30,7 +30,13 @@ from transformer_engine.pytorch import (
) )
from transformer_engine.common import recipe from transformer_engine.common import recipe
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, gelu, cast_to_fp8, cast_from_fp8 from transformer_engine.pytorch.cpp_extensions import (
gemm,
fp8_gemm,
gelu,
cast_to_fp8,
cast_from_fp8,
)
from transformer_engine.pytorch.module.base import get_workspace from transformer_engine.pytorch.module.base import get_workspace
from test_onnx_export import create_meta from test_onnx_export import create_meta
...@@ -75,6 +81,7 @@ class ModelConfig: ...@@ -75,6 +81,7 @@ class ModelConfig:
return False return False
return True return True
model_configs = { model_configs = {
"126m": ModelConfig(12, 2048, 2, 768, 12), "126m": ModelConfig(12, 2048, 2, 768, 12),
"small": ModelConfig(2, 32, 2, 64, 2), "small": ModelConfig(2, 32, 2, 64, 2),
...@@ -82,7 +89,7 @@ model_configs = { ...@@ -82,7 +89,7 @@ model_configs = {
} }
fp8_recipes = [ fp8_recipes = [
None, # Handles non-FP8 case None, # Handles non-FP8 case
recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3), recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3),
recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID), recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID),
recipe.DelayedScaling( recipe.DelayedScaling(
...@@ -126,6 +133,7 @@ batch_sizes_with_zero = [0, 1, 2] ...@@ -126,6 +133,7 @@ batch_sizes_with_zero = [0, 1, 2]
all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu"] all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu"]
all_normalizations = ["LayerNorm", "RMSNorm"] all_normalizations = ["LayerNorm", "RMSNorm"]
def _disable_wgrads(block): def _disable_wgrads(block):
for p in block.parameters(): for p in block.parameters():
p.requires_grad = False p.requires_grad = False
...@@ -143,8 +151,17 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): ...@@ -143,8 +151,17 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
optimizer = torch.optim.SGD(block.parameters(), lr=0.1) optimizer = torch.optim.SGD(block.parameters(), lr=0.1)
# Placeholders used for capture. # Placeholders used for capture.
static_input = torch.randn(config.seq_len, config.batch_size, config.hidden_size, device='cuda', dtype=dtype, requires_grad=True) static_input = torch.randn(
static_target = torch.randn(config.seq_len, config.batch_size, config.hidden_size, device='cuda', dtype=dtype) config.seq_len,
config.batch_size,
config.hidden_size,
device="cuda",
dtype=dtype,
requires_grad=True,
)
static_target = torch.randn(
config.seq_len, config.batch_size, config.hidden_size, device="cuda", dtype=dtype
)
real_input = torch.rand_like(static_input) real_input = torch.rand_like(static_input)
real_target = torch.rand_like(static_target) real_target = torch.rand_like(static_target)
...@@ -403,11 +420,7 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz ...@@ -403,11 +420,7 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz
config = model_configs[model] config = model_configs[model]
module = RMSNorm if normalization == "RMSNorm" else LayerNorm module = RMSNorm if normalization == "RMSNorm" else LayerNorm
block = ( block = module(config.hidden_size).to(dtype=torch.float32).cuda()
module(config.hidden_size)
.to(dtype=torch.float32)
.cuda()
)
_test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad) _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)
...@@ -418,9 +431,9 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz ...@@ -418,9 +431,9 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz
@pytest.mark.parametrize("zero_centered_gamma", all_boolean) @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean) @pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_linear(dtype, fp8_recipe, model, skip_wgrad, def test_sanity_layernorm_linear(
zero_centered_gamma, skip_dgrad, dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, normalization
normalization): ):
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
...@@ -480,7 +493,7 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad): ...@@ -480,7 +493,7 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias): def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
config = model_configs[model] config = model_configs[model]
ffn_hidden_size = 4 * config.hidden_size ffn_hidden_size = 4 * config.hidden_size
num_tokens = bs*config.seq_len num_tokens = bs * config.seq_len
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not fp8_available:
...@@ -490,15 +503,9 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ ...@@ -490,15 +503,9 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
use_fp8 = fp8_recipe is not None use_fp8 = fp8_recipe is not None
with fp8_model_init(enabled=use_fp8 and fp8_model_params): with fp8_model_init(enabled=use_fp8 and fp8_model_params):
te_linear = ( te_linear = Linear(
Linear( config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
config.hidden_size, ).cuda()
ffn_hidden_size,
bias=use_bias,
params_dtype=dtype
)
.cuda()
)
inp_hidden_states = torch.randn( inp_hidden_states = torch.randn(
num_tokens, config.hidden_size, dtype=dtype, requires_grad=True num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
...@@ -518,9 +525,9 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ ...@@ -518,9 +525,9 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
@pytest.mark.parametrize("skip_dgrad", all_boolean) @pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations) @pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad, def test_sanity_layernorm_mlp(
zero_centered_gamma, skip_dgrad, activation, dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, activation, normalization
normalization): ):
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
...@@ -557,10 +564,18 @@ def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad, ...@@ -557,10 +564,18 @@ def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad,
@pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean) @pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
@pytest.mark.parametrize("cpu_offload", all_boolean) @pytest.mark.parametrize("cpu_offload", all_boolean)
def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad, def test_sanity_gpt(
zero_centered_gamma, bias, activation, dtype,
normalization, parallel_attention_mlp, fp8_recipe,
cpu_offload): model,
skip_wgrad,
zero_centered_gamma,
bias,
activation,
normalization,
parallel_attention_mlp,
cpu_offload,
):
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
...@@ -625,8 +640,7 @@ def test_sanity_gpt_126m(): ...@@ -625,8 +640,7 @@ def test_sanity_gpt_126m():
@pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean) @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
normalization):
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
...@@ -683,8 +697,7 @@ def test_sanity_bert_126m(): ...@@ -683,8 +697,7 @@ def test_sanity_bert_126m():
@pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean) @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
normalization):
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
...@@ -845,7 +858,9 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): ...@@ -845,7 +858,9 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
@pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean) @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma): def test_sanity_gradient_accumulation_fusion(
dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma
):
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
...@@ -885,8 +900,7 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra ...@@ -885,8 +900,7 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra
@pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean) @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("normalization", all_normalizations)
def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
normalization):
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
...@@ -919,9 +933,10 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm ...@@ -919,9 +933,10 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm
_test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad) _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad)
def test_model_multiple_cast(): def test_model_multiple_cast():
a = torch.zeros((16,16), device="cuda") a = torch.zeros((16, 16), device="cuda")
m = Linear(16,32) m = Linear(16, 32)
y = m(a) y = m(a)
assert y.dtype == torch.float32 assert y.dtype == torch.float32
...@@ -937,15 +952,11 @@ def test_model_multiple_cast(): ...@@ -937,15 +952,11 @@ def test_model_multiple_cast():
@pytest.mark.parametrize("offset", [1, 3, 5]) @pytest.mark.parametrize("offset", [1, 3, 5])
@pytest.mark.parametrize("datatype", param_types) @pytest.mark.parametrize("datatype", param_types)
def test_sanity_gemm_with_unalignment(N, offset, datatype): def test_sanity_gemm_with_unalignment(N, offset, datatype):
scratchpad = torch.randn(N*N + 2*offset, device="cuda", dtype=datatype) scratchpad = torch.randn(N * N + 2 * offset, device="cuda", dtype=datatype)
inp = torch.reshape(scratchpad[offset:-offset], (N, N)) inp = torch.reshape(scratchpad[offset:-offset], (N, N))
weight = torch.reshape(scratchpad[offset*2:], (N, N)) weight = torch.reshape(scratchpad[offset * 2 :], (N, N))
_, _, _ = gemm( _, _, _ = gemm(A=weight, B=inp, dtype=datatype, workspace=get_workspace())
A=weight,
B=inp,
dtype=datatype,
workspace=get_workspace())
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -954,38 +965,35 @@ def test_sanity_gemm_with_unalignment(N, offset, datatype): ...@@ -954,38 +965,35 @@ def test_sanity_gemm_with_unalignment(N, offset, datatype):
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_fp8_gemm_with_unalignment(N, datatype): def test_sanity_fp8_gemm_with_unalignment(N, datatype):
offset = 16 offset = 16
scratchpad = torch.randn(N*N + offset, device="cuda", dtype=datatype) scratchpad = torch.randn(N * N + offset, device="cuda", dtype=datatype)
fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT
fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT
nb_inp_scales, nb_weight_scales = 1, N nb_inp_scales, nb_weight_scales = 1, N
scale_factor = 1. scale_factor = 1.0
meta_inp = create_meta(scale_factor, nb_inp_scales) meta_inp = create_meta(scale_factor, nb_inp_scales)
meta_weight = create_meta(scale_factor, nb_weight_scales) meta_weight = create_meta(scale_factor, nb_weight_scales)
inp_type = tex.DType.kFloat8E4M3 inp_type = tex.DType.kFloat8E4M3
weights_type = tex.DType.kFloat8E4M3 weights_type = tex.DType.kFloat8E4M3
outp_type = datatype outp_type = datatype
scratchpad_fp8 = cast_to_fp8( scratchpad_fp8 = cast_to_fp8(scratchpad, meta_weight, fp8_tensor_inp, inp_type)
scratchpad,
meta_weight,
fp8_tensor_inp,
inp_type)
inp_fp8 = torch.reshape(scratchpad_fp8[:-offset], (N, N)) inp_fp8 = torch.reshape(scratchpad_fp8[:-offset], (N, N))
weight_fp8 = torch.reshape(scratchpad_fp8[offset:], (N, N)) weight_fp8 = torch.reshape(scratchpad_fp8[offset:], (N, N))
_, _ = fp8_gemm( _, _ = fp8_gemm(
weight_fp8, weight_fp8,
meta_weight.scale_inv, meta_weight.scale_inv,
fp8_tensor_weight, fp8_tensor_weight,
inp_type, inp_type,
inp_fp8, inp_fp8,
meta_inp.scale_inv, meta_inp.scale_inv,
fp8_tensor_inp, fp8_tensor_inp,
weights_type, weights_type,
outp_type, outp_type,
get_workspace(), get_workspace(),
bias=None, bias=None,
use_bias=False, use_bias=False,
use_split_accumulator=False) use_split_accumulator=False,
)
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -3,4 +3,5 @@ ...@@ -3,4 +3,5 @@
# See LICENSE for license information. # See LICENSE for license information.
import transformer_engine.pytorch import transformer_engine.pytorch
print("OK") print("OK")
...@@ -28,7 +28,7 @@ from transformer_engine.pytorch.module.base import TransformerEngineBaseModule ...@@ -28,7 +28,7 @@ from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
def init_meta(size: int=1): def init_meta(size: int = 1):
meta = tex.FP8TensorMeta() meta = tex.FP8TensorMeta()
meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") meta.scale = torch.ones(size, dtype=torch.float32, device="cuda")
meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda")
...@@ -65,22 +65,18 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd ...@@ -65,22 +65,18 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd
self.inp_type = tex.DType.kFloat8E4M3 self.inp_type = tex.DType.kFloat8E4M3
self.weights_type = tex.DType.kFloat8E4M3 self.weights_type = tex.DType.kFloat8E4M3
self.outp_type = precision self.outp_type = precision
def get_fp8_weights_scratchpad(self, is_first_microbatch): def get_fp8_weights_scratchpad(self, is_first_microbatch):
raise RuntimeError("Method get_fp8_weights_scratchpad is dummy and should not be invoked.") raise RuntimeError(
"Method get_fp8_weights_scratchpad is dummy and should not be invoked."
)
def forward(self, inp, weight): def forward(self, inp, weight):
inp_fp8 = cast_to_fp8( inp_fp8 = cast_to_fp8(inp, self.meta_inp, self.fp8_tensor_inp, self.inp_type)
inp,
self.meta_inp,
self.fp8_tensor_inp,
self.inp_type)
weight_fp8 = cast_to_fp8( weight_fp8 = cast_to_fp8(
weight, weight, self.meta_weight, self.fp8_tensor_weight, self.weights_type
self.meta_weight, )
self.fp8_tensor_weight,
self.weights_type)
ret = fp8_gemm( ret = fp8_gemm(
weight_fp8, weight_fp8,
...@@ -95,20 +91,33 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd ...@@ -95,20 +91,33 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd
get_workspace(), get_workspace(),
bias=self.bias, bias=self.bias,
use_bias=self.use_bias, use_bias=self.use_bias,
use_split_accumulator=False) use_split_accumulator=False,
)
return ret return ret
model_in = Test_TE_Export(precision, True) model_in = Test_TE_Export(precision, True)
with te.fp8_autocast(enabled=True): with te.fp8_autocast(enabled=True):
model_in.init_fp8_metadata() model_in.init_fp8_metadata()
# scaling fwd # scaling fwd
model_in.fp8_meta["scaling_fwd"].scale = torch.ones(3, dtype=torch.float32, device="cuda") * scale_fwd model_in.fp8_meta["scaling_fwd"].scale = (
model_in.fp8_meta["scaling_fwd"].scale_inv = torch.ones(3, dtype=torch.float32, device="cuda") / scale_fwd torch.ones(3, dtype=torch.float32, device="cuda") * scale_fwd
model_in.fp8_meta["scaling_fwd"].amax_history = torch.ones(3, dtype=torch.float32, device="cuda") * history_fwd )
model_in.fp8_meta["scaling_fwd"].scale_inv = (
torch.ones(3, dtype=torch.float32, device="cuda") / scale_fwd
)
model_in.fp8_meta["scaling_fwd"].amax_history = (
torch.ones(3, dtype=torch.float32, device="cuda") * history_fwd
)
# scaling bwd # scaling bwd
model_in.fp8_meta["scaling_bwd"].scale = torch.ones(2, dtype=torch.float32, device="cuda") * scale_bwd model_in.fp8_meta["scaling_bwd"].scale = (
model_in.fp8_meta["scaling_bwd"].scale_inv = torch.ones(2, dtype=torch.float32, device="cuda") / scale_bwd torch.ones(2, dtype=torch.float32, device="cuda") * scale_bwd
model_in.fp8_meta["scaling_bwd"].amax_history = torch.ones(2, dtype=torch.float32, device="cuda") * history_bwd )
model_in.fp8_meta["scaling_bwd"].scale_inv = (
torch.ones(2, dtype=torch.float32, device="cuda") / scale_bwd
)
model_in.fp8_meta["scaling_bwd"].amax_history = (
torch.ones(2, dtype=torch.float32, device="cuda") * history_bwd
)
torch.save(model_in.state_dict(), tmp_filename) torch.save(model_in.state_dict(), tmp_filename)
...@@ -117,13 +126,27 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd ...@@ -117,13 +126,27 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd
model_out.eval() model_out.eval()
# scaling fwd # scaling fwd
assert torch.allclose(model_in.fp8_meta["scaling_fwd"].scale, model_out.fp8_meta["scaling_fwd"].scale) assert torch.allclose(
assert torch.allclose(model_in.fp8_meta["scaling_fwd"].scale_inv, model_out.fp8_meta["scaling_fwd"].scale_inv) model_in.fp8_meta["scaling_fwd"].scale, model_out.fp8_meta["scaling_fwd"].scale
assert torch.allclose(model_in.fp8_meta["scaling_fwd"].amax_history, model_out.fp8_meta["scaling_fwd"].amax_history) )
assert torch.allclose(
model_in.fp8_meta["scaling_fwd"].scale_inv, model_out.fp8_meta["scaling_fwd"].scale_inv
)
assert torch.allclose(
model_in.fp8_meta["scaling_fwd"].amax_history,
model_out.fp8_meta["scaling_fwd"].amax_history,
)
# scaling bwd # scaling bwd
assert torch.allclose(model_in.fp8_meta["scaling_bwd"].scale, model_out.fp8_meta["scaling_bwd"].scale) assert torch.allclose(
assert torch.allclose(model_in.fp8_meta["scaling_bwd"].scale_inv, model_out.fp8_meta["scaling_bwd"].scale_inv) model_in.fp8_meta["scaling_bwd"].scale, model_out.fp8_meta["scaling_bwd"].scale
assert torch.allclose(model_in.fp8_meta["scaling_bwd"].amax_history, model_out.fp8_meta["scaling_bwd"].amax_history) )
assert torch.allclose(
model_in.fp8_meta["scaling_bwd"].scale_inv, model_out.fp8_meta["scaling_bwd"].scale_inv
)
assert torch.allclose(
model_in.fp8_meta["scaling_bwd"].amax_history,
model_out.fp8_meta["scaling_bwd"].amax_history,
)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
...@@ -132,7 +155,7 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd ...@@ -132,7 +155,7 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd
def test_fp8_model_checkpoint( def test_fp8_model_checkpoint(
save_fp8_model: bool, save_fp8_model: bool,
load_fp8_model: bool, load_fp8_model: bool,
dims: Iterable[int] = [32,32], dims: Iterable[int] = [32, 32],
dtype: torch.dtype = torch.float32, dtype: torch.dtype = torch.float32,
device: Union[torch.device, str] = "cuda", device: Union[torch.device, str] = "cuda",
): ):
...@@ -153,7 +176,7 @@ def test_fp8_model_checkpoint( ...@@ -153,7 +176,7 @@ def test_fp8_model_checkpoint(
with te.fp8_autocast(): with te.fp8_autocast():
y_ref = model(x.detach().clone()).detach().clone() y_ref = model(x.detach().clone()).detach().clone()
fp8_meta_ref = { "scaling_fwd": {}, "scaling_bwd": {} } fp8_meta_ref = {"scaling_fwd": {}, "scaling_bwd": {}}
with te.fp8_autocast(), torch.no_grad(): with te.fp8_autocast(), torch.no_grad():
fp8_meta_fwd = model.fp8_meta["scaling_fwd"] fp8_meta_fwd = model.fp8_meta["scaling_fwd"]
fp8_meta_bwd = model.fp8_meta["scaling_bwd"] fp8_meta_bwd = model.fp8_meta["scaling_bwd"]
...@@ -168,7 +191,7 @@ def test_fp8_model_checkpoint( ...@@ -168,7 +191,7 @@ def test_fp8_model_checkpoint(
fp8_meta_bwd.scale.copy_(fp8_meta_bwd_ref["scale"]) fp8_meta_bwd.scale.copy_(fp8_meta_bwd_ref["scale"])
fp8_meta_bwd.scale_inv.copy_(fp8_meta_bwd_ref["scale_inv"]) fp8_meta_bwd.scale_inv.copy_(fp8_meta_bwd_ref["scale_inv"])
del fp8_meta_fwd, fp8_meta_bwd del fp8_meta_fwd, fp8_meta_bwd
# [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ] # [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ]
# This line copies the fp8 scale_inv from the model metadata to the weight fp8 tensor. # This line copies the fp8 scale_inv from the model metadata to the weight fp8 tensor.
# The sole purpose of the following lines is to set the scale_inv of the weight tensor, which is the simplest method. # The sole purpose of the following lines is to set the scale_inv of the weight tensor, which is the simplest method.
...@@ -226,15 +249,14 @@ def test_fp8_model_checkpoint( ...@@ -226,15 +249,14 @@ def test_fp8_model_checkpoint(
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
torch.testing.assert_close(y, y_ref, **tols) torch.testing.assert_close(y, y_ref, **tols)
# [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ] # [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ]
# When save_fp8_model=True, we load a model with weights in high precision, # When save_fp8_model=True, we load a model with weights in high precision,
# which does not include _scale_inv, # which does not include _scale_inv,
# but has the fp8 scaling factor in the meta data. This scenario can occur # but has the fp8 scaling factor in the meta data. This scenario can occur
# when using te.fp8_autocast(enabled=False, calibrating=True). # when using te.fp8_autocast(enabled=False, calibrating=True).
# #
# In such cases, the default behavior of load_state_dict is incorrect - it loads tensors first, # In such cases, the default behavior of load_state_dict is incorrect - it loads tensors first,
# followed by the fp8 metadata. This results in an incorrect _scale_inv for the tensor. This behavior # followed by the fp8 metadata. This results in an incorrect _scale_inv for the tensor. This behavior
# is corrected by overriding the _load_state_dict method from PyTorch in TransformerEngineBaseModule, # is corrected by overriding the _load_state_dict method from PyTorch in TransformerEngineBaseModule,
# to load the fp8 metadata before loading tensors. # to load the fp8 metadata before loading tensors.
# #
...@@ -262,4 +284,6 @@ def test_fp8_model_checkpoint( ...@@ -262,4 +284,6 @@ def test_fp8_model_checkpoint(
# We need to ensure that the tensor's scale_inv parameter matches its meta data. # We need to ensure that the tensor's scale_inv parameter matches its meta data.
# This is crucial to avoid confusion about which value is correct. # This is crucial to avoid confusion about which value is correct.
meta_index = model.weight._fp8_meta_index meta_index = model.weight._fp8_meta_index
torch.testing.assert_close(model.weight._scale_inv.item(), fp8_meta_fwd_ref["scale_inv"][meta_index].item()) torch.testing.assert_close(
\ No newline at end of file model.weight._scale_inv.item(), fp8_meta_fwd_ref["scale_inv"][meta_index].item()
)
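The comments above explain why loading tensors before the fp8 metadata leaves the weight with a stale _scale_inv, and why TransformerEngineBaseModule loads the metadata first. A toy illustration of that ordering hazard, with made-up names rather than the real TE classes:

```python
class ToyFP8Module:
    """Toy stand-in (made-up names, not the TE API): an fp8 'weight' whose
    dequantization scale_inv is taken from the module's fp8 metadata at the
    moment the weight is loaded."""

    def __init__(self):
        self.meta_scale_inv = 1.0    # default metadata before loading
        self.weight_scale_inv = 1.0  # value cached on the weight tensor

    def load(self, ckpt, meta_first):
        keys = ["meta", "weight"] if meta_first else ["weight", "meta"]
        for key in keys:
            if key == "meta":
                self.meta_scale_inv = ckpt["meta_scale_inv"]
            else:
                # Attaching the weight uses whatever metadata the module holds right now.
                self.weight_scale_inv = self.meta_scale_inv

ckpt = {"meta_scale_inv": 0.125}  # checkpoint calibrated with a non-default scale

stale = ToyFP8Module()
stale.load(ckpt, meta_first=False)
fixed = ToyFP8Module()
fixed.load(ckpt, meta_first=True)
assert stale.weight_scale_inv == 1.0    # weight was loaded before the metadata arrived
assert fixed.weight_scale_inv == 0.125  # matches the checkpoint metadata
```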
...@@ -4,74 +4,59 @@ ...@@ -4,74 +4,59 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <transformer_engine/activation.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include "../util/vectorized_pointwise.h" #include <transformer_engine/activation.h>
#include "../common.h"
#include "../common.h"
#include "../util/vectorized_pointwise.h"
namespace transformer_engine { namespace transformer_engine {
template <typename ComputeType, typename Param, template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
ComputeType (*OP)(ComputeType, const Param&)> void act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) {
void act_fn(const Tensor &input,
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(input, "act_lu_input"); CheckInputTensor(input, "act_lu_input");
CheckOutputTensor(*output, "act_lu_output"); CheckOutputTensor(*output, "act_lu_output");
NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match."); NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
const size_t tot_elts = product(input.data.shape); const size_t tot_elts = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType, input.data.dtype, IType,
constexpr int nvec = 32 / sizeof(IType); TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
VectorizedUnaryKernelLauncher<nvec, Param, OP>( output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
reinterpret_cast<const IType*>(input.data.dptr), VectorizedUnaryKernelLauncher<nvec, Param, OP>(
reinterpret_cast<OType*>(output->data.dptr), reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<const ComputeType*>(output->scale.dptr), reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<ComputeType*>(output->amax.dptr), reinterpret_cast<const ComputeType *>(output->scale.dptr),
tot_elts, reinterpret_cast<ComputeType *>(output->amax.dptr), tot_elts, {},
{}, stream);); // NOLINT(*)
stream); ); // NOLINT(*)
); // NOLINT(*)
); // NOLINT(*)
} }
template <typename ComputeType, typename Param, template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
ComputeType (*OP)(ComputeType, const Param&)> void dact_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) {
void dact_fn(const Tensor &grad,
const Tensor &input,
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(input, "dact_lu_input"); CheckInputTensor(input, "dact_lu_input");
CheckInputTensor(grad, "dact_lu_input_grad"); CheckInputTensor(grad, "dact_lu_input_grad");
CheckOutputTensor(*output, "dact_lu_output"); CheckOutputTensor(*output, "dact_lu_output");
NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match."); NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
NVTE_CHECK(input.data.dtype == grad.data.dtype, NVTE_CHECK(input.data.dtype == grad.data.dtype, "Input and incoming gradient types must match.");
"Input and incoming gradient types must match.");
const size_t tot_elts = product(input.data.shape); const size_t tot_elts = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType, input.data.dtype, IType,
constexpr int nvec = 32 / sizeof(IType); TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
VectorizedUnaryGradKernelLauncher<nvec, Param, OP>( output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
reinterpret_cast<const IType*>(grad.data.dptr), VectorizedUnaryGradKernelLauncher<nvec, Param, OP>(
reinterpret_cast<const IType*>(input.data.dptr), reinterpret_cast<const IType *>(grad.data.dptr),
reinterpret_cast<OType*>(output->data.dptr), reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<const ComputeType*>(output->scale.dptr), reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<ComputeType*>(output->amax.dptr), reinterpret_cast<const ComputeType *>(output->scale.dptr),
tot_elts, reinterpret_cast<ComputeType *>(output->amax.dptr), tot_elts, {},
{}, stream);); // NOLINT(*)
stream); ); // NOLINT(*)
); // NOLINT(*)
); // NOLINT(*)
} }
template <typename ComputeType, typename Param, template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
ComputeType (*OP)(ComputeType, const Param&)> void gated_act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) {
void gated_act_fn(const Tensor &input,
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(input, "gated_act_input"); CheckInputTensor(input, "gated_act_input");
CheckOutputTensor(*output, "gated_act_output"); CheckOutputTensor(*output, "gated_act_output");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
...@@ -81,29 +66,23 @@ void gated_act_fn(const Tensor &input, ...@@ -81,29 +66,23 @@ void gated_act_fn(const Tensor &input,
NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2, NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2,
"Input shape[1] must be 2x larger than output shape[1]."); "Input shape[1] must be 2x larger than output shape[1].");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType, input.data.dtype, IType,
constexpr int nvec = 32 / sizeof(IType); TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
GatedActivationKernelLauncher<nvec, ComputeType, Param, OP>( output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
reinterpret_cast<const IType*>(input.data.dptr), GatedActivationKernelLauncher<nvec, ComputeType, Param, OP>(
reinterpret_cast<OType*>(output->data.dptr), reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<const ComputeType*>(output->scale.dptr), reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<ComputeType*>(output->amax.dptr), reinterpret_cast<const ComputeType *>(output->scale.dptr),
output->data.shape[0], reinterpret_cast<ComputeType *>(output->amax.dptr), output->data.shape[0],
output->data.shape[1], output->data.shape[1], {},
{}, stream);); // NOLINT(*)
stream); ); // NOLINT(*)
); // NOLINT(*)
); // NOLINT(*)
} }
template <typename ComputeType, typename Param, ComputeType (*OP1)(ComputeType, const Param &),
          ComputeType (*OP2)(ComputeType, const Param &)>
void dgated_act_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) {
  CheckInputTensor(grad, "dgated_act_grad");
  CheckInputTensor(input, "dgated_act_input");
  CheckOutputTensor(*output, "dgated_act_output");
@@ -114,23 +93,19 @@ void dgated_act_fn(const Tensor &grad,
             "Output shape[0] must be equal to grad shape[0].");
  NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2,
             "Output shape[1] must be 2x larger than grad shape[1].");
  NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.data.dtype, IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
          DGatedActivationKernelLauncher<nvec, ComputeType, Param, OP1, OP2>(
              reinterpret_cast<const IType *>(grad.data.dptr),
              reinterpret_cast<const IType *>(input.data.dptr),
              reinterpret_cast<OType *>(output->data.dptr), grad.data.shape[0], grad.data.shape[1],
              {},
              stream););  // NOLINT(*)
}
}  // namespace transformer_engine
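dgated_act_fn is the matching backward pass: OP1 is the forward activation, OP2 its derivative, grad is the H-wide upstream gradient, and the 2H-wide output packs the gradients of both halves. A sketch under the same assumptions (our naming; straight chain rule for y = OP1(x) * g):

#include <cstddef>

// Illustrative reference only:
//   d_x[r][c] = grad[r][c] * OP2(x[r][c]) * g[r][c]
//   d_g[r][c] = grad[r][c] * OP1(x[r][c])
template <typename ComputeType, ComputeType (*OP1)(ComputeType), ComputeType (*OP2)(ComputeType)>
void ref_dgated_act_fn(const ComputeType *grad, const ComputeType *in, ComputeType *out,
                       std::size_t rows, std::size_t H) {
  for (std::size_t r = 0; r < rows; ++r) {
    for (std::size_t c = 0; c < H; ++c) {
      const ComputeType x = in[r * 2 * H + c];
      const ComputeType g = in[r * 2 * H + H + c];
      const ComputeType dy = grad[r * H + c];
      out[r * 2 * H + c] = dy * OP2(x) * g;  // gradient w.r.t. the activation half
      out[r * 2 * H + H + c] = dy * OP1(x);  // gradient w.r.t. the gate half
    }
  }
}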
@@ -3,96 +3,69 @@
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "../util/math.h"
#include "./activation_template.h"

void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_gelu);
  using namespace transformer_engine;
  act_fn<fp32, Empty, gelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
                                        reinterpret_cast<Tensor*>(output), stream);
}
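For context: gelu<fp32, fp32> from util/math.h is, to the best of our knowledge, the tanh approximation of GELU rather than the exact erf form. A standalone equivalent for comparison (verify the coefficients against util/math.h before relying on this):

#include <cmath>

// Tanh-approximated GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
inline float ref_gelu(float x) {
  return 0.5f * x * (1.0f + tanhf(0.7978845608f * (x + 0.044715f * x * x * x)));
}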
void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_dgelu);
  using namespace transformer_engine;
  dact_fn<fp32, Empty, dgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
                                          *reinterpret_cast<const Tensor*>(input),
                                          reinterpret_cast<Tensor*>(output), stream);
}

void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_geglu);
  using namespace transformer_engine;
  gated_act_fn<fp32, Empty, gelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
                                              reinterpret_cast<Tensor*>(output), stream);
}

void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dgeglu);
  using namespace transformer_engine;
  dgated_act_fn<fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(
      *reinterpret_cast<const Tensor*>(grad), *reinterpret_cast<const Tensor*>(input),
      reinterpret_cast<Tensor*>(output), stream);
}

void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_qgelu);
  using namespace transformer_engine;
  act_fn<fp32, Empty, qgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
                                         reinterpret_cast<Tensor*>(output), stream);
}

void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dqgelu);
  using namespace transformer_engine;
  dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
                                           *reinterpret_cast<const Tensor*>(input),
                                           reinterpret_cast<Tensor*>(output), stream);
}

void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_qgeglu);
  using namespace transformer_engine;
  gated_act_fn<fp32, Empty, qgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
                                               reinterpret_cast<Tensor*>(output), stream);
}

void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                  cudaStream_t stream) {
  NVTE_API_CALL(nvte_dqgeglu);
  using namespace transformer_engine;
  dgated_act_fn<fp32, Empty, qgelu<fp32, fp32>, dqgelu<fp32, fp32>>(
      *reinterpret_cast<const Tensor*>(grad), *reinterpret_cast<const Tensor*>(input),
      reinterpret_cast<Tensor*>(output), stream);
}
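A hedged sketch of driving these entry points from host code. The tensor-construction helpers (nvte_create_tensor / nvte_destroy_tensor) are assumed from the TE 1.x transformer_engine.h headers and their exact signatures have changed across releases, so treat this strictly as illustrative:

#include <cuda_runtime.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/transformer_engine.h>

// Illustrative only: runs GEGLU over a [rows, 2*H] bf16 device buffer into a
// [rows, H] bf16 device buffer. nvte_create_tensor's signature is an
// assumption -- check the transformer_engine.h shipped with your version.
void run_geglu(void *in_dev, void *out_dev, size_t rows, size_t H, cudaStream_t stream) {
  size_t in_shape_data[] = {rows, 2 * H};
  size_t out_shape_data[] = {rows, H};
  NVTEShape in_shape{in_shape_data, 2};
  NVTEShape out_shape{out_shape_data, 2};
  // bf16 in/out needs no scale/amax; the FP8 path would pass real pointers here.
  NVTETensor in = nvte_create_tensor(in_dev, in_shape, kNVTEBFloat16,
                                     /*amax=*/nullptr, /*scale=*/nullptr, /*scale_inv=*/nullptr);
  NVTETensor out = nvte_create_tensor(out_dev, out_shape, kNVTEBFloat16,
                                      nullptr, nullptr, nullptr);
  nvte_geglu(in, out, stream);  // out[r][c] = gelu(in[r][c]) * in[r][c + H]
  nvte_destroy_tensor(in);
  nvte_destroy_tensor(out);
}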
@@ -4,96 +4,69 @@
 * See LICENSE for license information.
 ************************************************************************/

#include "../util/math.h"
#include "./activation_template.h"

void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_relu);
  using namespace transformer_engine;
  act_fn<fp32, Empty, relu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
                                        reinterpret_cast<Tensor*>(output), stream);
}

void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_drelu);
  using namespace transformer_engine;
  dact_fn<fp32, Empty, drelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
                                          *reinterpret_cast<const Tensor*>(input),
                                          reinterpret_cast<Tensor*>(output), stream);
}

void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_reglu);
  using namespace transformer_engine;
  gated_act_fn<fp32, Empty, relu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
                                              reinterpret_cast<Tensor*>(output), stream);
}

void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dreglu);
  using namespace transformer_engine;
  dgated_act_fn<fp32, Empty, relu<fp32, fp32>, drelu<fp32, fp32>>(
      *reinterpret_cast<const Tensor*>(grad), *reinterpret_cast<const Tensor*>(input),
      reinterpret_cast<Tensor*>(output), stream);
}

void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_srelu);
  using namespace transformer_engine;
  act_fn<fp32, Empty, srelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
                                         reinterpret_cast<Tensor*>(output), stream);
}

void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dsrelu);
  using namespace transformer_engine;
  dact_fn<fp32, Empty, dsrelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
                                           *reinterpret_cast<const Tensor*>(input),
                                           reinterpret_cast<Tensor*>(output), stream);
}

void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_sreglu);
  using namespace transformer_engine;
  gated_act_fn<fp32, Empty, srelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
                                               reinterpret_cast<Tensor*>(output), stream);
}

void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                  cudaStream_t stream) {
  NVTE_API_CALL(nvte_dsreglu);
  using namespace transformer_engine;
  dgated_act_fn<fp32, Empty, srelu<fp32, fp32>, dsrelu<fp32, fp32>>(
      *reinterpret_cast<const Tensor*>(grad), *reinterpret_cast<const Tensor*>(input),
      reinterpret_cast<Tensor*>(output), stream);
}
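This file follows the same pattern as gelu.cu with the relu family substituted. For reference, hedged the same way, these are the definitions we assume the wrapped ops have (srelu appearing to be squared ReLU in util/math.h; verify against the header):

// Illustrative reference definitions -- not TE code.
inline float ref_relu(float x)   { return x > 0.0f ? x : 0.0f; }
inline float ref_drelu(float x)  { return x > 0.0f ? 1.0f : 0.0f; }
inline float ref_srelu(float x)  { return x > 0.0f ? x * x : 0.0f; }      // squared ReLU
inline float ref_dsrelu(float x) { return x > 0.0f ? 2.0f * x : 0.0f; }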