Unverified commit 6aa1fcc8, authored by Kirthi Shankar Sivamani and committed by GitHub

[PyTorch] move mask types to fprop (#402)



* API change and some test fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* more test fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* ONNX fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixed fused attention tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* rm duplicate test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 94c57e4d
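The user-facing effect of the change is that the mask type is no longer fixed when a module is constructed but is supplied on each forward call. A minimal sketch of the new calling convention for `DotProductAttention`, assuming `transformer_engine.pytorch` is installed and a CUDA device is available; the head count, head dimension, sequence length and batch size below are illustrative only:

import torch
from transformer_engine.pytorch import DotProductAttention

# Illustrative sizes: 16 heads of dimension 64, sequence length 128, batch 2.
attn = DotProductAttention(16, 64, attention_dropout=0.0).cuda()

q, k, v = [
    torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
    for _ in range(3)
]

# Previously the mask type was given at construction time
# (DotProductAttention(..., attn_mask_type="causal")); it is now a per-call
# argument of forward, so one module instance can serve several mask types.
causal_out = attn(q, k, v, attn_mask_type="causal")
unmasked_out = attn(q, k, v, attn_mask_type="no_mask")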
@@ -77,10 +77,10 @@ def test_dot_product_attention(dtype, bs, model, ckpt_attn, bias_type):
    atol, rtol = (2.5e-2, 2.5e-2) if dtype == torch.bfloat16 else (5e-3, 5e-3)
    if bias_type == "no_bias":
        assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol)
        assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol)
    assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
    assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)

def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type):

@@ -94,18 +94,18 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type):
    inp = torch.randn(
        config.seq_len, bs, 3, config.num_attention_heads, config.head_dim,
        dtype=dtype).cuda()
    inp.requires_grad=True
    seqlens = torch.empty(bs, dtype=torch.int32).cuda()
    seqlens.fill_(config.seq_len)
    cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    op_grad = torch.randn(
        config.seq_len, bs, config.num_attention_heads * config.head_dim,
        dtype = dtype).cuda()
    if bias_type != "no_bias":
        bias = torch.randn(1, config.num_attention_heads, config.seq_len, config.seq_len,
            dtype=dtype).cuda()
    else:
        bias = None

@@ -113,24 +113,23 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type):
        DotProductAttention(
            config.num_attention_heads,
            config.head_dim,
            attention_dropout=config.dropout_p,
-           attn_mask_type = config.attn_mask_type,
            sequence_parallel=False,
            tp_size=1,
            get_rng_state_tracker=get_dummy_cuda_rng_tracker,
            tp_group=None,
            layer_number=1,
            attention_type="self"
        ).to(dtype=dtype).cuda()
    )

    q = inp[:, :,0,:,:]
    k = inp[:, :,1,:,:]
    v = inp[:, :,2,:,:]
-   op = block(q, k, v,
+   op = block(q, k, v, attn_mask_type=config.attn_mask_type,
               checkpoint_core_attention=ckpt_attn,
               core_attention_bias_type=bias_type,
               core_attention_bias=bias)
    op.backward(op_grad)
    return op, inp.grad
@@ -158,10 +157,10 @@ def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type):
    atol, rtol = (5e-1, 5e-2)
    if bias_type == "no_bias":
        assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol)
        assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol)
    assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
    assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)

def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type):

@@ -175,12 +174,12 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type):
    inp = torch.randn(
        config.seq_len, bs, config.num_attention_heads * config.head_dim,
        dtype=dtype).cuda()
    inp.requires_grad=True
    seqlens = torch.empty(bs, dtype=torch.int32).cuda()
    seqlens.fill_(config.seq_len)
    cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    sigma = 0.02
    init_method = init_method_normal(sigma)

@@ -192,7 +191,7 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type):
        rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
    if bias_type != "no_bias":
        bias = torch.randn(1, config.num_attention_heads, config.seq_len, config.seq_len,
            dtype=dtype).cuda()
    else:
        bias = None

@@ -201,43 +200,42 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type):
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=1e-5,
            hidden_dropout=0.0,
            attention_dropout=config.dropout_p,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            layer_number=layer_number,
            kv_channels=config.head_dim,
-           self_attn_mask_type = config.attn_mask_type,
            tp_group=None,
            tp_size=1,
            params_dtype=dtype,
            get_rng_state_tracker=None,
            fuse_wgrad_accumulation=False,
            seq_length=config.seq_len,
            micro_batch_size=bs,
            sequence_parallel=False,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            layer_type="encoder",
            drop_path_rate=drop_path_rates[layer_number - 1],
            set_parallel_mode=True,
            fuse_qkv_params=True,
            zero_centered_gamma=False,
            qkv_weight_interleaved=False,
            ub_tp_comm_overlap=False,
            bias=True,
        )
        .to(dtype=dtype)
        .cuda()
    )

    num_iters = 10
    for i in range(num_iters):
-       op = block(inp,
+       op = block(inp, self_attn_mask_type=config.attn_mask_type,
                   checkpoint_core_attention=ckpt_attn,
                   core_attention_bias_type=bias_type,
                   core_attention_bias=bias)
        loss = op.sum()
        loss.backward()
@@ -270,8 +268,8 @@ def test_transformer_layer_gqa(dtype, bs, model):
        dtype, bs, config, "UnfusedDotProductAttention", num_q_per_gqa_group)

    atol, rtol = 5e-1, 5e-2
    assert torch.allclose(flash_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
    assert torch.allclose(flash_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)

def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_group):

@@ -282,15 +280,15 @@ def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_group):
    inp = torch.randn(
        config.seq_len, bs, config.num_attention_heads * config.head_dim,
        dtype=dtype).cuda()
    inp.requires_grad=True
    seqlens = torch.empty(bs, dtype=torch.int32).cuda()
    seqlens.fill_(config.seq_len)
    cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    op_grad = torch.randn(
        config.seq_len, bs, config.num_attention_heads * config.head_dim,
        dtype=dtype).cuda()
    sigma = 0.02
    init_method = init_method_normal(sigma)

@@ -306,39 +304,38 @@ def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_group):
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            num_gqa_groups=config.num_attention_heads / num_querys_per_gqa_group,
            layernorm_epsilon=1e-5,
            hidden_dropout=0.0,
            attention_dropout=config.dropout_p,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            layer_number=layer_number,
            kv_channels=config.head_dim,
-           self_attn_mask_type = config.attn_mask_type,
            tp_group=None,
            tp_size=1,
            params_dtype=dtype,
            get_rng_state_tracker=None,
            fuse_wgrad_accumulation=False,
            seq_length=config.seq_len,
            micro_batch_size=bs,
            sequence_parallel=False,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            layer_type="encoder",
            drop_path_rate=drop_path_rates[layer_number - 1],
            set_parallel_mode=True,
            fuse_qkv_params=True,
            zero_centered_gamma=False,
            qkv_weight_interleaved=False,
            ub_tp_comm_overlap=False,
            bias=True,
        )
        .to(dtype=dtype)
        .cuda()
    )

-   op = block(inp)
+   op = block(inp, self_attn_mask_type=config.attn_mask_type)
    op.backward(op_grad)
    return op, inp.grad
@@ -365,8 +362,8 @@ def test_dpa_fp8(dtype, bs, model):
        dtype, bs, config, "UnfusedDotProductAttention")

    atol, rtol = (2.5e-2, 2.5e-2)
    assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
    assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)

def _run_dpa_fp8(dtype, bs, config, backend):

@@ -376,15 +373,15 @@ def _run_dpa_fp8(dtype, bs, config, backend):
    inp = 0.01 * torch.randn(
        bs * config.seq_len, config.num_attention_heads * config.head_dim,
        dtype=dtype).cuda()
    inp.requires_grad=True
    seqlens = torch.empty(bs, dtype=torch.int32).cuda()
    seqlens.fill_(config.seq_len)
    cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    op_grad = 0.01 * torch.randn(
        bs * config.seq_len, config.num_attention_heads * config.head_dim,
        dtype=dtype).cuda()
    torch.save(op_grad, 'op_grad.pt')

    fp8_recipe = recipe.DelayedScaling(

@@ -395,7 +392,7 @@ def _run_dpa_fp8(dtype, bs, config, backend):
        amax_compute_algo="most_recent",
    )

    dpa = DPA_FP8(config).to(dtype=torch.float16).cuda()
    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        op = dpa(inp, cu_seqlens, config.seq_len)
        op.backward(op_grad)

@@ -416,31 +413,30 @@ def _run_dpa_fp8_ref(dtype, bs, config, backend):
    inp = torch.load('qkv.pt').cuda()
    inp.requires_grad=True
    seqlens = torch.empty(bs, dtype=torch.int32).cuda()
    seqlens.fill_(config.seq_len)
    cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    op_grad = torch.load('op_grad.pt').cuda().view(bs, config.seq_len, -1).transpose(0,1)

    block = (
        DotProductAttention(
            config.num_attention_heads,
            config.head_dim,
            attention_dropout=config.dropout_p,
-           attn_mask_type = config.attn_mask_type,
            sequence_parallel=False,
            tp_size=1,
            get_rng_state_tracker=None,
            tp_group=None,
            layer_number=1,
            attention_type="self"
        ).to(dtype=dtype).cuda()
    )

    q = inp[:, :,0,:,:]
    k = inp[:, :,1,:,:]
    v = inp[:, :,2,:,:]
-   op = block(q, k, v)
+   op = block(q, k, v, attn_mask_type=config.attn_mask_type)
    op.backward(op_grad)
    torch.save(op,'ctx_ref.pt')
    torch.save(inp.grad,'dqkv_ref.pt')

@@ -533,8 +529,8 @@ class _dpa_fp8(torch.autograd.Function):
            workspace,
            bias=qkv_bias,
            use_bias=True,
            out_index=META_QKV,
            fp8_meta_tensor=fp8_meta["scaling_fwd"],
            use_split_accumulator=_2X_ACC_FPROP,
            D_dtype=fp8_dtype_forward,
        )

@@ -558,13 +554,13 @@ class _dpa_fp8(torch.autograd.Function):
            fp8_meta["scaling_fwd"].scale[META_O],
            fp8_meta["scaling_fwd"].amax_history[0][META_S],
            fp8_meta["scaling_fwd"].amax_history[0][META_O],
            attn_scale=None,
            dropout=p_dropout,
            fast_zero_fill=fast_zero_fill,
            qkv_layout="qkv_interleaved",
            attn_bias_type="no_bias",
            attn_mask_type="padding",
            rng_gen=None,
        )
        M, ZInv, philox_unpacked = aux_ctx_tensors
......
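The same pattern applies one level up: `TransformerLayer` now takes `self_attn_mask_type` in its forward call rather than in its constructor, as the updated tests above show. A minimal sketch under the same assumptions (`transformer_engine.pytorch` installed, CUDA device available); the layer sizes are illustrative and not taken from the test configurations:

import torch
from transformer_engine.pytorch import TransformerLayer

# Illustrative configuration only.
layer = TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    params_dtype=torch.bfloat16,
).cuda()

x = torch.randn(128, 2, 1024, dtype=torch.bfloat16, device="cuda")

# self_attn_mask_type used to be a constructor argument; it is now passed
# on every forward call.
y = layer(x, self_attn_mask_type="causal")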
@@ -376,8 +376,8 @@ class TorchMHA(nn.Module):
            batch_first=False,
        )

-   def forward(self, x, attn_mask=None):
-       output = self.mhsa(x, x, x, attn_mask=attn_mask, need_weights=False)
+   def forward(self, x, attention_mask=None):
+       output = self.mhsa(x, x, x, attn_mask=attention_mask, need_weights=False)
        if isinstance(output, tuple):
            output = output[0]
        return output

@@ -461,7 +461,7 @@ def _test_e2e_selective_recompute(block, bs, dtype, config, recompute=False):
        te_out = block(
            te_inp_hidden_states,
-           te_inp_attn_mask,
+           attention_mask=te_inp_attn_mask,
            checkpoint_core_attention=recompute,
        )
        loss = te_out.sum()

@@ -526,13 +526,13 @@ def _test_e2e_full_recompute(block, bs, dtype, config, recompute=False):
                get_dummy_cuda_rng_tracker,
                None,  # tp_group
                te_inp_hidden_states,
-               te_inp_attn_mask,
+               attention_mask=te_inp_attn_mask,
                checkpoint_core_attention=False,
            )
        else:
            te_out = block(
                te_inp_hidden_states,
-               te_inp_attn_mask,
+               attention_mask=te_inp_attn_mask,
                checkpoint_core_attention=False,
            )
        loss = te_out.sum()

@@ -766,7 +766,7 @@ def test_gpt_accuracy(dtype, bs, model):
    assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)

-def _test_mha_accuracy(block, bs, dtype, config, mask_type):
+def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
    reset_rng_states()

    inp_hidden_states = torch.randn(

@@ -775,7 +775,12 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type):
    inp_hidden_states.retain_grad()
    inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None

-   out = block(inp_hidden_states, inp_attn_mask)
+   forward_kwargs = {}
+   if te:
+       forward_kwargs["attn_mask_type"] = mask_type
+   forward_kwargs["attention_mask"] = inp_attn_mask
+   out = block(inp_hidden_states, **forward_kwargs)
    loss = out.sum()
    loss.backward()

@@ -801,7 +806,6 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
            fuse_qkv_params=True,
            qkv_weight_interleaved=False,
            input_layernorm=False,
-           attn_mask_type=mask_type,
        )
        .to(dtype=dtype)
        .cuda()

@@ -825,8 +829,8 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
    torch_mha.mhsa.out_proj.weight = Parameter(te_mha.proj.weight.clone())
    torch_mha.mhsa.out_proj.bias = Parameter(te_mha.proj.bias.clone())

-   te_outputs = _test_mha_accuracy(te_mha, bs, dtype, config, mask_type)
-   torch_outputs = _test_mha_accuracy(torch_mha, bs, dtype, config, mask_type)
+   te_outputs = _test_mha_accuracy(te_mha, bs, dtype, config, mask_type, te=True)
+   torch_outputs = _test_mha_accuracy(torch_mha, bs, dtype, config, mask_type, te=False)

    # Check output.
    if dtype == torch.float32:
......
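Because `torch.nn.MultiheadAttention` knows nothing about `attn_mask_type`, the updated `_test_mha_accuracy` above builds its forward kwargs conditionally. Reduced to a standalone sketch with hypothetical names (`block`, `hidden_states`, `mask`, `mask_type` are stand-ins, not helpers from the test suite):

# Route the mask-type kwarg only to Transformer Engine modules; a wrapper
# around torch.nn.MultiheadAttention would reject the unknown argument.
def run_block(block, hidden_states, mask, mask_type, is_te_module):
    forward_kwargs = {"attention_mask": mask}
    if is_te_module:
        forward_kwargs["attn_mask_type"] = mask_type
    return block(hidden_states, **forward_kwargs)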
@@ -783,7 +783,6 @@ def test_export_softmax(seed_default_rng, set_max_seq_len, softmax_fn, precision):
            self.fake_bf16_io = fake_bf16_io

            if self.softmax_fn == te.softmax.FusedScaleMaskSoftmax:
                self.fused_scaled_softmax = te.softmax.FusedScaleMaskSoftmax(
-                   attn_mask_type="causal",
                    mask_func=te.utils.attention_mask_func,
                    softmax_in_fp32=True,
                )

@@ -793,7 +792,7 @@ def test_export_softmax(seed_default_rng, set_max_seq_len, softmax_fn, precision):
                inp = inp.type(torch.bfloat16)

            if self.fused_scaled_softmax:
-               ret = self.fused_scaled_softmax(inp, mask, self.scale)
+               ret = self.fused_scaled_softmax(inp, mask, "causal", self.scale)
            else:
                if self.mask_inp:
                    ret = self.softmax_fn.apply(inp, mask, self.scale)

@@ -867,7 +866,6 @@ def test_softmax_mask_fn(seed_default_rng, precision):
            # even when is_in_onnx_export_mode()==False.
            os.environ["NVTE_MASKED_SOFTMAX_FUSION"] = "0"
            self.fused_scaled_softmax = te.softmax.FusedScaleMaskSoftmax(
-               attn_mask_type="causal",
                mask_func=te.utils.attention_mask_func,
                softmax_in_fp32=True,
            )

@@ -875,7 +873,7 @@ def test_softmax_mask_fn(seed_default_rng, precision):
        def forward(self, inp, mask):
            if self.fake_bf16_io:
                inp = inp.type(torch.bfloat16)
-           ret = self.fused_scaled_softmax(inp, mask, self.scale)
+           ret = self.fused_scaled_softmax(inp, mask, "causal", scale=self.scale)
            if self.fake_bf16_io:
                ret = ret.type(torch.float)
            return ret

@@ -1161,13 +1159,13 @@ def test_export_core_attention(
    query_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
    key_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
    value_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
-   input_names = ["query", "key", "value", "attention_mask"]
+   input_names = ["query", "key", "value", "attention_mask", "attn_mask_type"]
    attention_mask = None
    if use_mask:
        # Generate a random mask with 50% probability for 0 or 1.
        probs = 0.5 * torch.ones(qkv_size[1], qkv_size[2], qkv_size[0], qkv_size[0], device="cuda", dtype=precision)
        attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
-   inp = (query_layer, key_layer, value_layer, attention_mask)
+   inp = (query_layer, key_layer, value_layer, attention_mask, attn_mask_type)

    mask_str = get_attn_mask_str(use_mask, attn_mask_type)
    high_prec_str = dtype2str(precision)

@@ -1177,7 +1175,6 @@ def test_export_core_attention(
        num_attention_heads=num_attention_heads,
        kv_channels=kv_channels,
        attention_dropout=0.5,
-       attn_mask_type=attn_mask_type,
    ).to(device='cuda')

    do_export(model,
              inp,

@@ -1193,9 +1190,8 @@ def test_export_core_attention(
test_configs_multihead_attention = [
    #"use_mask, attn_mask_type"
-   (False, "causal"),   # calls ScaledUpperTriangMaskedSoftmax
+   (False, "no_mask"),  # calls ScaledUpperTriangMaskedSoftmax
    (True, "padding"),   # calls ScaledMaskedSoftmax
-   (False, "padding"),  # calls ScaledSoftmax
]
test_configs_attention_type = [
    #"input_layernorm, attention_type, fuse_qkv_params"
@@ -1269,7 +1265,6 @@ def test_export_multihead_attention(
    model = te.MultiheadAttention(
        *attention_args,
-       attn_mask_type=attn_mask_type,
        params_dtype=precision,
        return_layernorm_output=return_layernorm_output,
        input_layernorm=input_layernorm,

@@ -1278,8 +1273,8 @@ def test_export_multihead_attention(
        return_bias=True,
    ).to(device='cuda')

-   inp_context = (hidden_states_context, attention_mask, encoder_output)
-   input_names = ["hidden_states", "attention_mask", "encoder_output"]
+   inp_context = (hidden_states_context, attention_mask, encoder_output, attn_mask_type)
+   input_names = ["hidden_states", "attention_mask", "encoder_output", "attn_mask_type"]
    output_names=["attention_output", "attention_bias"]
    do_export(model, inp_context, fname, use_fp8, input_names=input_names, output_names=output_names,
              dynamic_axes={"hidden_states": {0: "seq", 1:"bs"},

@@ -1347,13 +1342,13 @@ def test_export_transformer_layer(
    num_attention_heads = 4

    input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
-   input_names = ["input", "attention_mask"]
+   input_names = ["input", "attention_mask", "self_attn_mask_type"]
    attention_mask = None
    if use_mask and attn_mask_type != "causal":
        # Generate a random mask with 50% probability for 0 or 1.
        probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision)
        attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
-   inp = (input_tensor, attention_mask)
+   inp = (input_tensor, attention_mask, attn_mask_type)

    fp8_str = "_fp8" if use_fp8 else ""
    fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else ""

@@ -1365,7 +1360,6 @@ def test_export_transformer_layer(
        hidden_size,
        ffn_hidden_size,
        num_attention_heads,
-       self_attn_mask_type=attn_mask_type,
        output_layernorm=output_layernorm,
        params_dtype=precision,
        fuse_qkv_params=fuse_qkv_params,

@@ -1547,17 +1541,16 @@ def test_export_gpt_generation(
        hidden_size,
        ffn_hidden_size,
        num_attention_heads,
-       self_attn_mask_type=attn_mask_type,
        output_layernorm=output_layernorm,
        params_dtype=precision,
        fuse_qkv_params=fuse_qkv_params,
        zero_centered_gamma=zero_centered_gamma).to(device='cuda')

    # "Context phase": use full input sequence length
-   input_names = ["input"]
+   input_names = ["input", "attention_mask", "self_attn_mask_type"]
    output_names = ["output"]
    input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
-   inp = (input_tensor,)
+   inp = (input_tensor, None, attn_mask_type)
    do_export(model, inp, fname, use_fp8,
              input_names=input_names, output_names=output_names,
              dynamic_axes={"input": {0: "seq", 1:"bs"},
......
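`FusedScaleMaskSoftmax` follows the same convention in these export tests: the mask type is no longer a constructor argument but the third positional argument of its forward. A minimal sketch, assuming a CUDA device and that either the fused kernels accept these shapes or the module falls back to its unfused path (tensor sizes are illustrative):

import torch
import transformer_engine.pytorch as te

softmax = te.softmax.FusedScaleMaskSoftmax(
    mask_func=te.utils.attention_mask_func,
    softmax_in_fp32=True,
)

# Attention scores shaped [batch, heads, query_len, key_len].
scores = torch.randn(2, 4, 32, 32, dtype=torch.float16, device="cuda")

# The call signature is now (input, mask, attn_mask_type, scale); the
# "causal" mask type needs no explicit mask tensor.
probs = softmax(scores, None, "causal", None)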
@@ -176,7 +176,7 @@ def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad):
    use_fp8 = fp8_recipe is not None
    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
-           te_out = block(te_inp_hidden_states, te_inp_attn_mask)
+           te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
        loss = te_out.sum()

    loss.backward()

@@ -217,7 +217,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, bs, dtype, config, fp8_recipe, skip_wgrad):
    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
-       te_out = block(te_inp_hidden_states, te_inp_attn_mask)
+       te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()

@@ -253,7 +253,7 @@ def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad):
    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
-       te_out = block(te_inp_hidden_states, te_inp_attn_mask)
+       te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()

@@ -282,7 +282,9 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad):
    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(
-           te_inp_hidden_states, te_inp_attn_mask, encoder_output=te_inp_hidden_states
+           te_inp_hidden_states,
+           attention_mask=te_inp_attn_mask,
+           encoder_output=te_inp_hidden_states
        )
    loss = te_out.sum()
    loss.backward()
......
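The keyword-based calls above (`attention_mask=...`) combine with the forward-time mask type on `MultiheadAttention` as well. A minimal sketch under the same assumptions (`transformer_engine.pytorch` installed, CUDA device available; sizes are illustrative):

import torch
from transformer_engine.pytorch import MultiheadAttention

# Illustrative sizes: hidden size 1024, 16 attention heads.
mha = MultiheadAttention(1024, 16, params_dtype=torch.bfloat16).cuda()

x = torch.randn(128, 2, 1024, dtype=torch.bfloat16, device="cuda")

# attn_mask_type is now a forward-time argument here too; with "causal"
# any attention_mask that is passed is ignored.
out = mha(x, attn_mask_type="causal")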
@@ -196,23 +196,15 @@ class UnfusedDotProductAttention(torch.nn.Module):
        norm_factor: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
-       attn_mask_type: str = "causal",
        layer_number: Optional[int] = None,
    ) -> None:
        super().__init__()

-       assert (
-           attn_mask_type in AttnMaskTypes
-       ), f"attn_mask_type {attn_mask_type} not supported"
-
        self.norm_factor = norm_factor
        self.attention_dropout_ctx = attention_dropout_ctx
        self.layer_number = layer_number

-       self.scale_mask_softmax = FusedScaleMaskSoftmax(
-           attn_mask_type,
-           attention_mask_func,
-       )
+       self.scale_mask_softmax = FusedScaleMaskSoftmax(attention_mask_func)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but

@@ -228,11 +220,17 @@ class UnfusedDotProductAttention(torch.nn.Module):
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
+       attn_mask_type: str = "causal",
        attention_mask: Optional[torch.Tensor] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """core attention fprop"""
+       assert (
+           attn_mask_type in AttnMaskTypes
+       ), f"attn_mask_type {attn_mask_type} not supported"
+
        batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
        apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16

@@ -321,7 +319,8 @@ class UnfusedDotProductAttention(torch.nn.Module):
        # attention scores and attention mask [b, np, sq, sk]
        softmax_scale = self.layer_number if apply_qk_layer_scaling else None
-       attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, softmax_scale)
+       attention_probs = self.scale_mask_softmax(
+           attention_scores, attention_mask, attn_mask_type, softmax_scale)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
@@ -464,7 +463,6 @@ class FlashAttention(torch.nn.Module):
        norm_factor: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
-       attn_mask_type: str = "causal",
        deterministic: bool = False,
    ) -> None:
        super().__init__()

@@ -473,7 +471,6 @@ class FlashAttention(torch.nn.Module):
            _flash_attn_version >= _flash_attn_version_required
        ), f"FlashAttention minimum version {_flash_attn_version_required} is required."

-       self.attn_causal_mask = attn_mask_type == "causal"
        self.norm_factor = norm_factor
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_dropout = attention_dropout

@@ -484,6 +481,7 @@ class FlashAttention(torch.nn.Module):
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
+       attn_mask_type: str = "causal",
    ) -> torch.Tensor:
        """flash-attn fprop"""

@@ -531,7 +529,7 @@ class FlashAttention(torch.nn.Module):
        output = flash_attn_forward_func(
            query_layer, key_layer, value_layer, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
            self.attention_dropout if self.training else 0.0,
-           softmax_scale=1.0/self.norm_factor, causal=self.attn_causal_mask,
+           softmax_scale=1.0/self.norm_factor, causal=attn_mask_type=="causal",
            **fa_optional_forward_kwargs
        )
@@ -703,7 +701,6 @@ class FusedAttention(torch.nn.Module):
        norm_factor: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
-       attn_mask_type: str = "causal",
        attention_type: str = "self",
    ) -> None:
        super().__init__()

@@ -711,7 +708,6 @@ class FusedAttention(torch.nn.Module):
        self.norm_factor = norm_factor
        self.attention_dropout = attention_dropout
        self.attention_dropout_ctx = attention_dropout_ctx
-       self.attn_mask_type = attn_mask_type
        self.attention_type = attention_type
        self.use_FAv2_bwd = (os.getenv("NVTE_FUSED_ATTN_USE_FAv2_BWD", "1") == "1"
                             and _flash_attn_2_available

@@ -722,6 +718,7 @@ class FusedAttention(torch.nn.Module):
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
+       attn_mask_type: str = "causal",
        fused_attention_backend:
            tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
        core_attention_bias_type: str = "no_bias",

@@ -797,7 +794,7 @@ class FusedAttention(torch.nn.Module):
                fast_zero_fill,
                qkv_layout,
                core_attention_bias_type,
-               self.attn_mask_type,
+               attn_mask_type,
                None,  # rng_gen
                fused_attention_backend,
                use_FAv2_bwd

@@ -858,7 +855,7 @@ class FusedAttention(torch.nn.Module):
                fast_zero_fill,
                qkv_layout,
                core_attention_bias_type,
-               self.attn_mask_type,
+               attn_mask_type,
                None,  # rng_gen
                fused_attention_backend,
                use_FAv2_bwd
@@ -886,6 +883,11 @@ class DotProductAttention(torch.nn.Module):
    and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
    to disable`flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.

+   .. warning::
+
+       Argument :attr:`attn_mask_type` has been moved to the `forward` method and
+       is deprecated. It will be fully removed in future releases.
+
    Parameters
    ----------
    num_attention_heads : int

@@ -902,8 +904,6 @@ class DotProductAttention(torch.nn.Module):
                     is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.0
                      dropout probability for the dropout op during multi-head attention.
-   attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
-                  type of attention mask passed into softmax operation.
    layer_number: int, default = `None`
                 layer number of the current `DotProductAttention` when multiple such modules
                 are concatenated, for instance in consecutive transformer blocks.

@@ -924,7 +924,7 @@ class DotProductAttention(torch.nn.Module):
        kv_channels: int,
        num_gqa_groups: Optional[int] = None,
        attention_dropout: float = 0.0,
-       attn_mask_type: str = "causal",
+       attn_mask_type: Optional[str] = None,
        sequence_parallel: bool = False,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,

@@ -934,6 +934,14 @@ class DotProductAttention(torch.nn.Module):
    ) -> None:
        super().__init__()

+       if attn_mask_type is not None:
+           warnings.warn(
+               "Argument :attr:`attn_mask_type` has been moved to the `forward` method and"
+               "is deprecated. It will be fully removed in future releases.",
+               category=DeprecationWarning,
+           )
+       self.attn_mask_type = attn_mask_type
+
        self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.tp_group = tp_group
        self.get_rng_state_tracker = get_rng_state_tracker

@@ -978,10 +986,8 @@ class DotProductAttention(torch.nn.Module):
        attn_kwargs = {
            "attention_dropout": attention_dropout,
            "attention_dropout_ctx": attention_dropout_ctx,
-           "attn_mask_type": attn_mask_type,
        }
        self.attention_type = attention_type
-       self.attn_mask_type = attn_mask_type
        self.attention_dropout = attention_dropout

        if self.use_flash_attention:
@@ -1025,6 +1031,7 @@ class DotProductAttention(torch.nn.Module):
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
+       attn_mask_type: str = "causal",
        checkpoint_core_attention: bool = False,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,

@@ -1067,6 +1074,8 @@ class DotProductAttention(torch.nn.Module):
                     Value tensor.
        attention_mask : Optional[torch.Tensor], default = `None`
                        Boolean tensor used to mask out softmax input when not using flash-attn.
+       attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
+                      type of attention mask passed into softmax operation.
        checkpoint_core_attention : bool, default = `False`
                                   If true, forward activations for attention are recomputed
                                   during the backward pass in order to save memory that would

@@ -1080,6 +1089,15 @@ class DotProductAttention(torch.nn.Module):
                        Whether to use the fast path to set output tensors to 0 or not.
        """

+       if self.attn_mask_type is not None:
+           warnings.warn(
+               "Argument :attr:`attn_mask_type` has been moved to the `forward` method and"
+               "is deprecated. It will be fully removed in future releases.",
+               category=DeprecationWarning,
+           )
+           # Keep previous functionality for current users.
+           attn_mask_type = self.attn_mask_type
+
        assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
                and value_layer.shape[-2] == self.num_gqa_groups_per_partition
                ), f"Keys and values must have {self.num_gqa_groups} heads!"

@@ -1102,7 +1120,7 @@ class DotProductAttention(torch.nn.Module):
        if not _flash_attn_2_available and self.num_gqa_groups != self.num_attention_heads:
            use_flash_attention = False

-       if self.attn_mask_type == "padding" and attention_mask is not None:
+       if attn_mask_type == "padding" and attention_mask is not None:
            use_flash_attention = False
            use_fused_attention = False

@@ -1121,7 +1139,7 @@ class DotProductAttention(torch.nn.Module):
                TE_DType[key_layer.dtype],
                QKVLayout[qkv_layout],
                AttnBiasType[core_attention_bias_type],
-               AttnMaskType[self.attn_mask_type],
+               AttnMaskType[attn_mask_type],
                self.attention_dropout,
                query_layer.shape[0], key_layer.shape[0],
                query_layer.shape[-1])

@@ -1144,8 +1162,10 @@ class DotProductAttention(torch.nn.Module):
                return self._checkpointed_attention_forward(self.flash_attention,
                                                            query_layer,
                                                            key_layer,
-                                                           value_layer)
-           return self.flash_attention(query_layer, key_layer, value_layer)
+                                                           value_layer,
+                                                           attn_mask_type=attn_mask_type)
+           return self.flash_attention(
+               query_layer, key_layer, value_layer, attn_mask_type=attn_mask_type)

        if use_fused_attention:
            if checkpoint_core_attention:

@@ -1153,15 +1173,17 @@ class DotProductAttention(torch.nn.Module):
                    query_layer,
                    key_layer,
                    value_layer,
+                   attn_mask_type=attn_mask_type,
                    fused_attention_backend=fused_attention_backend,
                    core_attention_bias_type=core_attention_bias_type,
                    core_attention_bias=core_attention_bias,
                    fast_zero_fill=fast_zero_fill)
            return self.fused_attention(query_layer, key_layer, value_layer,
+                                       attn_mask_type=attn_mask_type,
                                        fused_attention_backend=fused_attention_backend,
                                        core_attention_bias_type=core_attention_bias_type,
                                        core_attention_bias=core_attention_bias,
                                        fast_zero_fill=fast_zero_fill)

        if checkpoint_core_attention:
            return self._checkpointed_attention_forward(

@@ -1169,16 +1191,18 @@ class DotProductAttention(torch.nn.Module):
                query_layer,
                key_layer,
                value_layer,
+               attn_mask_type=attn_mask_type,
                attention_mask=attention_mask,
                core_attention_bias_type=core_attention_bias_type,
                core_attention_bias=core_attention_bias,
            )
        return self.unfused_attention(query_layer,
                                      key_layer,
                                      value_layer,
+                                     attn_mask_type=attn_mask_type,
                                      attention_mask=attention_mask,
                                      core_attention_bias_type=core_attention_bias_type,
                                      core_attention_bias=core_attention_bias,
        )
@@ -1190,7 +1214,12 @@ class MultiheadAttention(torch.nn.Module):
    .. note::

        Argument :attr:`attention_mask` will be ignored in the `forward` call when
-       :attr:`self_attn_mask_type` is set to `"causal"`.
+       :attr:`attn_mask_type` is set to `"causal"`.

+   .. warning::
+
+       Argument :attr:`attn_mask_type` has been moved to the `forward` method and
+       is deprecated. It will be fully removed in future releases.
+
    Parameters
    ----------

@@ -1217,8 +1246,6 @@ class MultiheadAttention(torch.nn.Module):
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
-   attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
-                  type of attention mask passed into softmax operation.
    num_gqa_groups : int, default = `None`
                    number of GQA groups in the transformer layer.
                    Grouped Query Attention is described in

@@ -1309,7 +1336,7 @@ class MultiheadAttention(torch.nn.Module):
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
-       attn_mask_type: str = "causal",
+       attn_mask_type: Optional[str] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        num_gqa_groups: Optional[int] = None,

@@ -1334,6 +1361,15 @@ class MultiheadAttention(torch.nn.Module):
        device: Union[torch.device, str] = "cuda",
    ) -> None:
        super().__init__()

+       if attn_mask_type is not None:
+           warnings.warn(
+               "Argument :attr:`attn_mask_type` has been moved to the `forward` method and"
+               "is deprecated. It will be fully removed in future releases.",
+               category=DeprecationWarning,
+           )
+       self.attn_mask_type = attn_mask_type
+
        self.layer_number = layer_number
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type

@@ -1341,7 +1377,6 @@ class MultiheadAttention(torch.nn.Module):
        self.tp_group = tp_group
        self.return_layernorm_output = return_layernorm_output
        self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
-       self.attn_mask_type = attn_mask_type
        self.num_attention_heads = num_attention_heads
        self.return_bias = return_bias

@@ -1467,7 +1502,6 @@ class MultiheadAttention(torch.nn.Module):
            attention_dropout=attention_dropout,
            tp_size=tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
-           attn_mask_type=attn_mask_type,
            sequence_parallel=sequence_parallel,
            tp_group=tp_group,
            layer_number=self.layer_number,
...@@ -1508,6 +1542,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -1508,6 +1542,7 @@ class MultiheadAttention(torch.nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
encoder_output: Optional[torch.Tensor] = None, encoder_output: Optional[torch.Tensor] = None,
attn_mask_type: str = "causal",
is_first_microbatch: Optional[bool] = None, is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: bool = False, checkpoint_core_attention: bool = False,
inference_params: Optional[Any] = None, inference_params: Optional[Any] = None,
...@@ -1521,7 +1556,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -1521,7 +1556,7 @@ class MultiheadAttention(torch.nn.Module):
.. note:: .. note::
Argument :attr:`attention_mask` will be ignored when :attr:`self_attn_mask_type` Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type`
is set to `"causal"`. is set to `"causal"`.
Parameters Parameters
...@@ -1530,6 +1565,8 @@ class MultiheadAttention(torch.nn.Module): ...@@ -1530,6 +1565,8 @@ class MultiheadAttention(torch.nn.Module):
Input tensor. Input tensor.
attention_mask : Optional[torch.Tensor], default = `None` attention_mask : Optional[torch.Tensor], default = `None`
Boolean tensor used to mask out self-attention softmax input. Boolean tensor used to mask out self-attention softmax input.
attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
type of attention mask passed into softmax operation.
encoder_output : Optional[torch.Tensor], default = `None` encoder_output : Optional[torch.Tensor], default = `None`
Output of the encoder block to be fed into the decoder block if using Output of the encoder block to be fed into the decoder block if using
`layer_type="decoder"`. `layer_type="decoder"`.
...@@ -1563,7 +1600,16 @@ class MultiheadAttention(torch.nn.Module): ...@@ -1563,7 +1600,16 @@ class MultiheadAttention(torch.nn.Module):
""" """
# hidden_states: [sq, b, h] # hidden_states: [sq, b, h]
if self.attn_mask_type == "padding" and attention_mask is not None: if self.attn_mask_type is not None:
warnings.warn(
"Argument :attr:`attn_mask_type` has been moved to the `forward` method and"
"is deprecated. It will be fully removed in future releases.",
category=DeprecationWarning,
)
# Keep previous functionality for current users.
attn_mask_type = self.attn_mask_type
if attn_mask_type == "padding" and attention_mask is not None:
assert ( assert (
attention_mask.dtype == torch.bool attention_mask.dtype == torch.bool
), "Attention mask must be a boolean tensor" ), "Attention mask must be a boolean tensor"
...@@ -1768,7 +1814,8 @@ class MultiheadAttention(torch.nn.Module): ...@@ -1768,7 +1814,8 @@ class MultiheadAttention(torch.nn.Module):
query_layer, query_layer,
key_layer, key_layer,
value_layer, value_layer,
attention_mask, attention_mask=attention_mask,
attn_mask_type=attn_mask_type,
checkpoint_core_attention=checkpoint_core_attention, checkpoint_core_attention=checkpoint_core_attention,
core_attention_bias_type=core_attention_bias_type, core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias, core_attention_bias=core_attention_bias,
......
...@@ -215,19 +215,16 @@ class FusedScaleMaskSoftmax(nn.Module): ...@@ -215,19 +215,16 @@ class FusedScaleMaskSoftmax(nn.Module):
fused operation: scaling + mask + softmax fused operation: scaling + mask + softmax
Arguments: Arguments:
attn_mask_type: attention mask type (pad or causal)
mask_func: mask function to be applied. mask_func: mask function to be applied.
softmax_in_fp32: if true, softmax is performed at fp32 precision. softmax_in_fp32: if true, softmax is performed at fp32 precision.
""" """
def __init__( def __init__(
self, self,
attn_mask_type: str,
mask_func: Callable, mask_func: Callable,
softmax_in_fp32: bool = True, softmax_in_fp32: bool = True,
) -> None: ) -> None:
super().__init__() super().__init__()
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = bool( self.scaled_masked_softmax_fusion = bool(
int(os.getenv("NVTE_MASKED_SOFTMAX_FUSION", "1")) int(os.getenv("NVTE_MASKED_SOFTMAX_FUSION", "1"))
) )
...@@ -249,6 +246,7 @@ class FusedScaleMaskSoftmax(nn.Module): ...@@ -249,6 +246,7 @@ class FusedScaleMaskSoftmax(nn.Module):
self, self,
inp: torch.Tensor, inp: torch.Tensor,
mask: torch.Tensor, mask: torch.Tensor,
attn_mask_type: str,
scale: Optional[float] = None, scale: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""FusedScaleMaskSoftmax fprop""" """FusedScaleMaskSoftmax fprop"""
...@@ -257,6 +255,7 @@ class FusedScaleMaskSoftmax(nn.Module): ...@@ -257,6 +255,7 @@ class FusedScaleMaskSoftmax(nn.Module):
self.input_in_fp16 = inp.dtype == torch.float16 self.input_in_fp16 = inp.dtype == torch.float16
self.input_in_bf16 = inp.dtype == torch.bfloat16 self.input_in_bf16 = inp.dtype == torch.bfloat16
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
assert ( assert (
scale is None or self.softmax_in_fp32 scale is None or self.softmax_in_fp32
......
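The internal FusedScaleMaskSoftmax helper follows the same pattern: the mask type is now an argument to its forward pass, matching the signatures in the hunks above. A minimal sketch — the import path and mask_func are assumptions for illustration:

    import torch
    from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax

    def mask_func(scores, mask):
        # Hypothetical masking function: suppress masked positions before softmax.
        return scores.masked_fill(mask, -10000.0)

    softmax = FusedScaleMaskSoftmax(mask_func)           # no attn_mask_type here any more
    scores = torch.randn(2, 16, 128, 128, dtype=torch.float16, device="cuda")
    mask = torch.zeros(2, 1, 128, 128, dtype=torch.bool, device="cuda")
    probs = softmax(scores, mask, "padding", scale=1.0)  # mask type chosen per call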
...@@ -73,10 +73,10 @@ class TransformerLayer(torch.nn.Module): ...@@ -73,10 +73,10 @@ class TransformerLayer(torch.nn.Module):
Arguments :attr:`attention_softmax_in_fp32` and :attr:`apply_query_key_layer_scaling` Arguments :attr:`attention_softmax_in_fp32` and :attr:`apply_query_key_layer_scaling`
are deprecated and will be fully removed in future releases. are deprecated and will be fully removed in future releases.
.. note:: .. warning::
Argument :attr:`attention_mask` will be ignored in the `forward` call when Argument :attr:`self_attn_mask_type` has been moved to the `forward` method and
:attr:`self_attn_mask_type` is set to `"causal"`. is deprecated. It will be fully removed in future releases.
Parameters Parameters
---------- ----------
...@@ -127,8 +127,6 @@ class TransformerLayer(torch.nn.Module): ...@@ -127,8 +127,6 @@ class TransformerLayer(torch.nn.Module):
kv_channels: int, default = `None` kv_channels: int, default = `None`
number of key-value channels. defaults to number of key-value channels. defaults to
:attr:`hidden_size` / :attr:`num_attention_heads` if `None`. :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
self_attn_mask_type: {'causal', 'padding'}, default = `causal`
type of attention mask passed into softmax operation.
zero_centered_gamma : bool, default = 'False' zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to the LayerNorm formula changes to
...@@ -214,7 +212,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -214,7 +212,7 @@ class TransformerLayer(torch.nn.Module):
output_layer_init_method: Optional[Callable] = None, output_layer_init_method: Optional[Callable] = None,
layer_number: Optional[int] = None, layer_number: Optional[int] = None,
kv_channels: Optional[int] = None, kv_channels: Optional[int] = None,
self_attn_mask_type: str = "causal", self_attn_mask_type: Optional[str] = None,
tp_group: Optional[dist_group_type] = None, tp_group: Optional[dist_group_type] = None,
tp_size: int = 1, tp_size: int = 1,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
...@@ -241,6 +239,13 @@ class TransformerLayer(torch.nn.Module): ...@@ -241,6 +239,13 @@ class TransformerLayer(torch.nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
if self_attn_mask_type is not None:
warnings.warn(
"Argument :attr:`self_attn_mask_type` has been moved to the `forward` method and"
"is deprecated. It will be fully removed in future releases.",
category=DeprecationWarning,
)
warnings.warn( warnings.warn(
"Arguments `attention_softmax_in_fp32` and `apply_query_key_layer_scaling`" "Arguments `attention_softmax_in_fp32` and `apply_query_key_layer_scaling`"
"are deprecated and will be fully removed in future releases.", "are deprecated and will be fully removed in future releases.",
...@@ -252,6 +257,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -252,6 +257,7 @@ class TransformerLayer(torch.nn.Module):
tex.userbuf_comm_available() tex.userbuf_comm_available()
), "Userbuffer communication backend not available." ), "Userbuffer communication backend not available."
self.self_attn_mask_type = self_attn_mask_type
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
ub_tp_comm_overlap = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_OVERLAP", "1"))) ub_tp_comm_overlap = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_OVERLAP", "1")))
ub_bulk_wgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_WGRAD", "1"))) ub_bulk_wgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_WGRAD", "1")))
...@@ -265,10 +271,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -265,10 +271,7 @@ class TransformerLayer(torch.nn.Module):
self.apply_residual_connection_post_layernorm = ( self.apply_residual_connection_post_layernorm = (
apply_residual_connection_post_layernorm apply_residual_connection_post_layernorm
) )
self.self_attn_mask_type = self_attn_mask_type
assert (
self_attn_mask_type in AttnMaskTypes
), f"self_attn_mask_type {self_attn_mask_type} not supported"
assert layer_type in LayerTypes, f"layer_type {layer_type} not supported" assert layer_type in LayerTypes, f"layer_type {layer_type} not supported"
if not fuse_qkv_params: if not fuse_qkv_params:
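Since the AttnMaskTypes check now lives in forward, an unsupported mask type surfaces at call time rather than when the layer is built. A hedged sketch of that behaviour (layer sizes illustrative):

    import torch
    from transformer_engine.pytorch import TransformerLayer

    layer = TransformerLayer(hidden_size=512, ffn_hidden_size=2048, num_attention_heads=8)
    x = torch.randn(64, 1, 512, dtype=torch.float16, device="cuda")

    try:
        layer(x, self_attn_mask_type="not_a_mask_type")
    except AssertionError as err:
        print(err)  # e.g. "self_attn_mask_type not_a_mask_type not supported"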
...@@ -326,7 +329,6 @@ class TransformerLayer(torch.nn.Module): ...@@ -326,7 +329,6 @@ class TransformerLayer(torch.nn.Module):
self.self_attention = MultiheadAttention( self.self_attention = MultiheadAttention(
*attention_args, *attention_args,
**common_attention_kwargs, **common_attention_kwargs,
attn_mask_type=self_attn_mask_type,
input_layernorm=not output_layernorm, input_layernorm=not output_layernorm,
attention_type="self", attention_type="self",
bias=bias, bias=bias,
...@@ -429,6 +431,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -429,6 +431,7 @@ class TransformerLayer(torch.nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
self_attn_mask_type: str = "causal",
encoder_output: Optional[torch.Tensor] = None, encoder_output: Optional[torch.Tensor] = None,
enc_dec_attn_mask: Optional[torch.Tensor] = None, enc_dec_attn_mask: Optional[torch.Tensor] = None,
is_first_microbatch: Optional[bool] = None, is_first_microbatch: Optional[bool] = None,
...@@ -453,6 +456,8 @@ class TransformerLayer(torch.nn.Module): ...@@ -453,6 +456,8 @@ class TransformerLayer(torch.nn.Module):
Input tensor. Input tensor.
attention_mask : Optional[torch.Tensor], default = `None` attention_mask : Optional[torch.Tensor], default = `None`
Boolean tensor used to mask out self-attention softmax input. Boolean tensor used to mask out self-attention softmax input.
self_attn_mask_type: {'causal', 'padding'}, default = `causal`
type of attention mask passed into softmax operation.
encoder_output : Optional[torch.Tensor], default = `None` encoder_output : Optional[torch.Tensor], default = `None`
Output of the encoder block to be fed into the decoder block if using Output of the encoder block to be fed into the decoder block if using
`layer_type="decoder"`. `layer_type="decoder"`.
...@@ -488,6 +493,19 @@ class TransformerLayer(torch.nn.Module): ...@@ -488,6 +493,19 @@ class TransformerLayer(torch.nn.Module):
Whether to set output tensors to 0 or not before use. Whether to set output tensors to 0 or not before use.
""" """
if self.self_attn_mask_type is not None:
warnings.warn(
"Argument :attr:`self_attn_mask_type` has been moved to the `forward` method and"
"is deprecated. It will be fully removed in future releases.",
category=DeprecationWarning,
)
# Keep previous functionality for current users.
self_attn_mask_type = self.self_attn_mask_type
assert (
self_attn_mask_type in AttnMaskTypes
), f"self_attn_mask_type {self_attn_mask_type} not supported"
hidden_states = hidden_states.contiguous() hidden_states = hidden_states.contiguous()
if self.sequence_parallel and self.seq_length is not None: if self.sequence_parallel and self.seq_length is not None:
...@@ -495,7 +513,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -495,7 +513,7 @@ class TransformerLayer(torch.nn.Module):
hidden_states.shape[0] == self.seq_length // self.tp_size hidden_states.shape[0] == self.seq_length // self.tp_size
), "Sequence dimension must be split across TP group when using sequence parallel." ), "Sequence dimension must be split across TP group when using sequence parallel."
if self.self_attn_mask_type != "causal" and attention_mask is not None: if self_attn_mask_type != "causal" and attention_mask is not None:
assert ( assert (
attention_mask.dtype == torch.bool attention_mask.dtype == torch.bool
), "Attention mask must be a boolean tensor" ), "Attention mask must be a boolean tensor"
...@@ -509,7 +527,8 @@ class TransformerLayer(torch.nn.Module): ...@@ -509,7 +527,8 @@ class TransformerLayer(torch.nn.Module):
# Self attention. # Self attention.
self_attention_outputs = self.self_attention( self_attention_outputs = self.self_attention(
hidden_states, hidden_states,
attention_mask, attention_mask=attention_mask,
attn_mask_type=self_attn_mask_type,
inference_params=inference_params, inference_params=inference_params,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention, checkpoint_core_attention=checkpoint_core_attention,
...@@ -556,7 +575,8 @@ class TransformerLayer(torch.nn.Module): ...@@ -556,7 +575,8 @@ class TransformerLayer(torch.nn.Module):
if self.layer_type == "decoder": if self.layer_type == "decoder":
inter_attention_outputs = self.inter_attention( inter_attention_outputs = self.inter_attention(
bda_output, bda_output,
enc_dec_attn_mask, attention_mask=enc_dec_attn_mask,
attn_mask_type=self_attn_mask_type,
encoder_output=encoder_output, encoder_output=encoder_output,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention, checkpoint_core_attention=checkpoint_core_attention,
......
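Note in the last hunk that for layer_type="decoder" the same per-call mask type is also forwarded to the inter-attention block. A sketch of a decoder layer under that assumption (shapes illustrative):

    import torch
    from transformer_engine.pytorch import TransformerLayer

    dec = TransformerLayer(hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16,
                           layer_type="decoder", params_dtype=torch.float16)
    x = torch.randn(128, 2, 1024, dtype=torch.float16, device="cuda")
    enc_out = torch.randn(128, 2, 1024, dtype=torch.float16, device="cuda")

    # With "causal", the attention masks may be omitted.
    out = dec(x, self_attn_mask_type="causal", encoder_output=enc_out)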