"...git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "ab3a5a8259922ce312d01be39d29e27666968039"
Unverified commit 6aa1fcc8 authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] move mask types to fprop (#402)



* API change and some test fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* more test fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* ONNX fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixed fused attention tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* rm duplicate test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 94c57e4d
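In short: attn_mask_type (and self_attn_mask_type on TransformerLayer) is no longer fixed at module construction; it is passed to each forward call, so one module can serve causal and padded batches. A minimal sketch of the new call pattern, assuming transformer_engine.pytorch.DotProductAttention and the [seq, batch, heads, head_dim] layout used in the tests below (shapes and the padding-mask layout are illustrative, not taken from this commit):

import torch
from transformer_engine.pytorch import DotProductAttention

# Old (now deprecated): mask type frozen at construction
#   attn = DotProductAttention(16, 64, attn_mask_type="causal"); out = attn(q, k, v)
# New: mask type is a per-call forward argument
attn = DotProductAttention(16, 64, attention_dropout=0.0)

q = torch.randn(128, 2, 16, 64, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

out_causal = attn(q, k, v, attn_mask_type="causal")
pad_mask = torch.zeros(2, 1, 128, 128, dtype=torch.bool, device="cuda")  # assumed [b, 1, sq, sk]
out_padded = attn(q, k, v, attn_mask_type="padding", attention_mask=pad_mask)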
......@@ -77,10 +77,10 @@ def test_dot_product_attention(dtype, bs, model, ckpt_attn, bias_type):
atol, rtol = (2.5e-2, 2.5e-2) if dtype == torch.bfloat16 else (5e-3, 5e-3)
if bias_type == "no_bias":
assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol)
assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol)
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type):
......@@ -94,18 +94,18 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type)
inp = torch.randn(
config.seq_len, bs, 3, config.num_attention_heads, config.head_dim,
dtype = dtype).cuda()
dtype=dtype).cuda()
inp.requires_grad=True
seqlens = torch.empty(bs, dtype = torch.int32).cuda()
seqlens = torch.empty(bs, dtype=torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
op_grad = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()
if bias_type != "no_bias":
bias = torch.randn(1, config.num_attention_heads, config.seq_len, config.seq_len,
dtype = dtype).cuda()
dtype=dtype).cuda()
else:
bias = None
......@@ -113,24 +113,23 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type)
DotProductAttention(
config.num_attention_heads,
config.head_dim,
attention_dropout = config.dropout_p,
attn_mask_type = config.attn_mask_type,
sequence_parallel = False,
tp_size = 1,
get_rng_state_tracker = get_dummy_cuda_rng_tracker,
tp_group = None,
layer_number = 1,
attention_type = "self"
).to(dtype = dtype).cuda()
attention_dropout=config.dropout_p,
sequence_parallel=False,
tp_size=1,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
tp_group=None,
layer_number=1,
attention_type="self"
).to(dtype=dtype).cuda()
)
q = inp[:, :,0,:,:]
k = inp[:, :,1,:,:]
v = inp[:, :,2,:,:]
op = block(q, k, v,
checkpoint_core_attention = ckpt_attn,
core_attention_bias_type = bias_type,
core_attention_bias = bias)
op = block(q, k, v, attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=ckpt_attn,
core_attention_bias_type=bias_type,
core_attention_bias=bias)
op.backward(op_grad)
return op, inp.grad
......@@ -158,10 +157,10 @@ def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type):
atol, rtol = (5e-1, 5e-2)
if bias_type == "no_bias":
assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol)
assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol)
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type):
......@@ -175,12 +174,12 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type):
inp = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()
dtype=dtype).cuda()
inp.requires_grad=True
seqlens = torch.empty(bs, dtype = torch.int32).cuda()
seqlens = torch.empty(bs, dtype=torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
sigma = 0.02
init_method = init_method_normal(sigma)
......@@ -192,7 +191,7 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type):
rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
if bias_type != "no_bias":
bias = torch.randn(1, config.num_attention_heads, config.seq_len, config.seq_len,
dtype = dtype).cuda()
dtype=dtype).cuda()
else:
bias = None
......@@ -201,43 +200,42 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type):
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon = 1e-5,
hidden_dropout = 0.0,
attention_dropout = config.dropout_p,
init_method = init_method,
output_layer_init_method = output_layer_init_method,
layer_number = layer_number,
kv_channels = config.head_dim,
self_attn_mask_type = config.attn_mask_type,
tp_group = None,
tp_size = 1,
params_dtype = dtype,
get_rng_state_tracker = None,
fuse_wgrad_accumulation = False,
seq_length = config.seq_len,
micro_batch_size = bs,
sequence_parallel = False,
apply_residual_connection_post_layernorm = False,
output_layernorm = False,
layer_type = "encoder",
drop_path_rate = drop_path_rates[layer_number - 1],
set_parallel_mode = True,
fuse_qkv_params = True,
zero_centered_gamma = False,
qkv_weight_interleaved = False,
ub_tp_comm_overlap = False,
bias = True,
layernorm_epsilon=1e-5,
hidden_dropout=0.0,
attention_dropout=config.dropout_p,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
layer_number=layer_number,
kv_channels=config.head_dim,
tp_group=None,
tp_size=1,
params_dtype=dtype,
get_rng_state_tracker=None,
fuse_wgrad_accumulation=False,
seq_length=config.seq_len,
micro_batch_size=bs,
sequence_parallel=False,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
layer_type="encoder",
drop_path_rate=drop_path_rates[layer_number - 1],
set_parallel_mode=True,
fuse_qkv_params=True,
zero_centered_gamma=False,
qkv_weight_interleaved=False,
ub_tp_comm_overlap=False,
bias=True,
)
.to(dtype = dtype)
.to(dtype=dtype)
.cuda()
)
num_iters = 10
for i in range(num_iters):
op = block(inp,
checkpoint_core_attention = ckpt_attn,
core_attention_bias_type = bias_type,
core_attention_bias = bias)
op = block(inp, self_attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=ckpt_attn,
core_attention_bias_type=bias_type,
core_attention_bias=bias)
loss = op.sum()
loss.backward()
......@@ -270,8 +268,8 @@ def test_transformer_layer_gqa(dtype, bs, model):
dtype, bs, config, "UnfusedDotProductAttention", num_q_per_gqa_group)
atol, rtol = 5e-1, 5e-2
assert torch.allclose(flash_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(flash_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)
assert torch.allclose(flash_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
assert torch.allclose(flash_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_group):
......@@ -282,15 +280,15 @@ def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_gr
inp = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()
dtype=dtype).cuda()
inp.requires_grad=True
seqlens = torch.empty(bs, dtype = torch.int32).cuda()
seqlens = torch.empty(bs, dtype=torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
op_grad = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()
dtype=dtype).cuda()
sigma = 0.02
init_method = init_method_normal(sigma)
......@@ -306,39 +304,38 @@ def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_gr
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
num_gqa_groups = config.num_attention_heads / num_querys_per_gqa_group,
layernorm_epsilon = 1e-5,
hidden_dropout = 0.0,
attention_dropout = config.dropout_p,
init_method = init_method,
output_layer_init_method = output_layer_init_method,
layer_number = layer_number,
kv_channels = config.head_dim,
self_attn_mask_type = config.attn_mask_type,
tp_group = None,
tp_size = 1,
params_dtype = dtype,
get_rng_state_tracker = None,
fuse_wgrad_accumulation = False,
seq_length = config.seq_len,
micro_batch_size = bs,
sequence_parallel = False,
apply_residual_connection_post_layernorm = False,
output_layernorm = False,
layer_type = "encoder",
drop_path_rate = drop_path_rates[layer_number - 1],
set_parallel_mode = True,
fuse_qkv_params = True,
zero_centered_gamma = False,
qkv_weight_interleaved = False,
ub_tp_comm_overlap = False,
bias = True,
num_gqa_groups=config.num_attention_heads / num_querys_per_gqa_group,
layernorm_epsilon=1e-5,
hidden_dropout=0.0,
attention_dropout=config.dropout_p,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
layer_number=layer_number,
kv_channels=config.head_dim,
tp_group=None,
tp_size=1,
params_dtype=dtype,
get_rng_state_tracker=None,
fuse_wgrad_accumulation=False,
seq_length=config.seq_len,
micro_batch_size=bs,
sequence_parallel=False,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
layer_type="encoder",
drop_path_rate=drop_path_rates[layer_number - 1],
set_parallel_mode=True,
fuse_qkv_params=True,
zero_centered_gamma=False,
qkv_weight_interleaved=False,
ub_tp_comm_overlap=False,
bias=True,
)
.to(dtype = dtype)
.to(dtype=dtype)
.cuda()
)
op = block(inp)
op = block(inp, self_attn_mask_type=config.attn_mask_type)
op.backward(op_grad)
return op, inp.grad
......@@ -365,8 +362,8 @@ def test_dpa_fp8(dtype, bs, model):
dtype, bs, config, "UnfusedDotProductAttention")
atol, rtol = (2.5e-2, 2.5e-2)
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
def _run_dpa_fp8(dtype, bs, config, backend):
......@@ -376,15 +373,15 @@ def _run_dpa_fp8(dtype, bs, config, backend):
inp = 0.01 * torch.randn(
bs * config.seq_len, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()
dtype=dtype).cuda()
inp.requires_grad=True
seqlens = torch.empty(bs, dtype = torch.int32).cuda()
seqlens = torch.empty(bs, dtype=torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
op_grad = 0.01 * torch.randn(
bs * config.seq_len, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()
dtype=dtype).cuda()
torch.save(op_grad, 'op_grad.pt')
fp8_recipe = recipe.DelayedScaling(
......@@ -395,7 +392,7 @@ def _run_dpa_fp8(dtype, bs, config, backend):
amax_compute_algo="most_recent",
)
dpa = DPA_FP8(config).to(dtype = torch.float16).cuda()
dpa = DPA_FP8(config).to(dtype=torch.float16).cuda()
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
op = dpa(inp, cu_seqlens, config.seq_len)
op.backward(op_grad)
......@@ -416,31 +413,30 @@ def _run_dpa_fp8_ref(dtype, bs, config, backend):
inp = torch.load('qkv.pt').cuda()
inp.requires_grad=True
seqlens = torch.empty(bs, dtype = torch.int32).cuda()
seqlens = torch.empty(bs, dtype=torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
op_grad = torch.load('op_grad.pt').cuda().view(bs, config.seq_len, -1).transpose(0,1)
block = (
DotProductAttention(
config.num_attention_heads,
config.head_dim,
attention_dropout = config.dropout_p,
attn_mask_type = config.attn_mask_type,
sequence_parallel = False,
tp_size = 1,
get_rng_state_tracker = None,
tp_group = None,
layer_number = 1,
attention_type = "self"
).to(dtype = dtype).cuda()
attention_dropout=config.dropout_p,
sequence_parallel=False,
tp_size=1,
get_rng_state_tracker=None,
tp_group=None,
layer_number=1,
attention_type="self"
).to(dtype=dtype).cuda()
)
q = inp[:, :,0,:,:]
k = inp[:, :,1,:,:]
v = inp[:, :,2,:,:]
op = block(q, k, v)
op = block(q, k, v, attn_mask_type=config.attn_mask_type)
op.backward(op_grad)
torch.save(op,'ctx_ref.pt')
torch.save(inp.grad,'dqkv_ref.pt')
......@@ -533,8 +529,8 @@ class _dpa_fp8(torch.autograd.Function):
workspace,
bias=qkv_bias,
use_bias=True,
out_index = META_QKV,
fp8_meta_tensor = fp8_meta["scaling_fwd"],
out_index=META_QKV,
fp8_meta_tensor=fp8_meta["scaling_fwd"],
use_split_accumulator=_2X_ACC_FPROP,
D_dtype=fp8_dtype_forward,
)
......@@ -558,13 +554,13 @@ class _dpa_fp8(torch.autograd.Function):
fp8_meta["scaling_fwd"].scale[META_O],
fp8_meta["scaling_fwd"].amax_history[0][META_S],
fp8_meta["scaling_fwd"].amax_history[0][META_O],
attn_scale = None,
dropout = p_dropout,
fast_zero_fill = fast_zero_fill,
qkv_layout = "qkv_interleaved",
attn_bias_type = "no_bias",
attn_mask_type = "padding",
rng_gen = None,
attn_scale=None,
dropout=p_dropout,
fast_zero_fill=fast_zero_fill,
qkv_layout="qkv_interleaved",
attn_bias_type="no_bias",
attn_mask_type="padding",
rng_gen=None,
)
M, ZInv, philox_unpacked = aux_ctx_tensors
......
......@@ -376,8 +376,8 @@ class TorchMHA(nn.Module):
batch_first=False,
)
def forward(self, x, attn_mask=None):
output = self.mhsa(x, x, x, attn_mask=attn_mask, need_weights=False)
def forward(self, x, attention_mask=None):
output = self.mhsa(x, x, x, attn_mask=attention_mask, need_weights=False)
if isinstance(output, tuple):
output = output[0]
return output
......@@ -461,7 +461,7 @@ def _test_e2e_selective_recompute(block, bs, dtype, config, recompute=False):
te_out = block(
te_inp_hidden_states,
te_inp_attn_mask,
attention_mask=te_inp_attn_mask,
checkpoint_core_attention=recompute,
)
loss = te_out.sum()
......@@ -526,13 +526,13 @@ def _test_e2e_full_recompute(block, bs, dtype, config, recompute=False):
get_dummy_cuda_rng_tracker,
None, # tp_group
te_inp_hidden_states,
te_inp_attn_mask,
attention_mask=te_inp_attn_mask,
checkpoint_core_attention=False,
)
else:
te_out = block(
te_inp_hidden_states,
te_inp_attn_mask,
attention_mask=te_inp_attn_mask,
checkpoint_core_attention=False,
)
loss = te_out.sum()
......@@ -766,7 +766,7 @@ def test_gpt_accuracy(dtype, bs, model):
assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)
def _test_mha_accuracy(block, bs, dtype, config, mask_type):
def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
reset_rng_states()
inp_hidden_states = torch.randn(
......@@ -775,7 +775,12 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type):
inp_hidden_states.retain_grad()
inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None
out = block(inp_hidden_states, inp_attn_mask)
forward_kwargs = {}
if te:
forward_kwargs["attn_mask_type"] = mask_type
forward_kwargs["attention_mask"] = inp_attn_mask
out = block(inp_hidden_states, **forward_kwargs)
loss = out.sum()
loss.backward()
......@@ -801,7 +806,6 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
fuse_qkv_params=True,
qkv_weight_interleaved=False,
input_layernorm=False,
attn_mask_type=mask_type,
)
.to(dtype=dtype)
.cuda()
......@@ -825,8 +829,8 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
torch_mha.mhsa.out_proj.weight = Parameter(te_mha.proj.weight.clone())
torch_mha.mhsa.out_proj.bias = Parameter(te_mha.proj.bias.clone())
te_outputs = _test_mha_accuracy(te_mha, bs, dtype, config, mask_type)
torch_outputs = _test_mha_accuracy(torch_mha, bs, dtype, config, mask_type)
te_outputs = _test_mha_accuracy(te_mha, bs, dtype, config, mask_type, te=True)
torch_outputs = _test_mha_accuracy(torch_mha, bs, dtype, config, mask_type, te=False)
# Check output.
if dtype == torch.float32:
......
......@@ -783,7 +783,6 @@ def test_export_softmax(seed_default_rng, set_max_seq_len, softmax_fn, precision
self.fake_bf16_io = fake_bf16_io
if self.softmax_fn == te.softmax.FusedScaleMaskSoftmax:
self.fused_scaled_softmax = te.softmax.FusedScaleMaskSoftmax(
attn_mask_type="causal",
mask_func=te.utils.attention_mask_func,
softmax_in_fp32=True,
)
......@@ -793,7 +792,7 @@ def test_export_softmax(seed_default_rng, set_max_seq_len, softmax_fn, precision
inp = inp.type(torch.bfloat16)
if self.fused_scaled_softmax:
ret = self.fused_scaled_softmax(inp, mask, self.scale)
ret = self.fused_scaled_softmax(inp, mask, "causal", self.scale)
else:
if self.mask_inp:
ret = self.softmax_fn.apply(inp, mask, self.scale)
......@@ -867,7 +866,6 @@ def test_softmax_mask_fn(seed_default_rng, precision):
# even when is_in_onnx_export_mode()==False.
os.environ["NVTE_MASKED_SOFTMAX_FUSION"] = "0"
self.fused_scaled_softmax = te.softmax.FusedScaleMaskSoftmax(
attn_mask_type="causal",
mask_func=te.utils.attention_mask_func,
softmax_in_fp32=True,
)
......@@ -875,7 +873,7 @@ def test_softmax_mask_fn(seed_default_rng, precision):
def forward(self, inp, mask):
if self.fake_bf16_io:
inp = inp.type(torch.bfloat16)
ret = self.fused_scaled_softmax(inp, mask, self.scale)
ret = self.fused_scaled_softmax(inp, mask, "causal", scale=self.scale)
if self.fake_bf16_io:
ret = ret.type(torch.float)
return ret
......@@ -1161,13 +1159,13 @@ def test_export_core_attention(
query_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
key_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
value_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
input_names = ["query", "key", "value", "attention_mask"]
input_names = ["query", "key", "value", "attention_mask", "attn_mask_type"]
attention_mask = None
if use_mask:
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(qkv_size[1], qkv_size[2], qkv_size[0], qkv_size[0], device="cuda", dtype=precision)
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
inp = (query_layer, key_layer, value_layer, attention_mask)
inp = (query_layer, key_layer, value_layer, attention_mask, attn_mask_type)
mask_str = get_attn_mask_str(use_mask, attn_mask_type)
high_prec_str = dtype2str(precision)
......@@ -1177,7 +1175,6 @@ def test_export_core_attention(
num_attention_heads=num_attention_heads,
kv_channels=kv_channels,
attention_dropout=0.5,
attn_mask_type=attn_mask_type,
).to(device='cuda')
do_export(model,
inp,
......@@ -1193,9 +1190,8 @@ def test_export_core_attention(
test_configs_multihead_attention = [
#"use_mask, attn_mask_type"
(False, "causal"), # calls ScaledUpperTriangMaskedSoftmax
(False, "no_mask"), # calls ScaledUpperTriangMaskedSoftmax
(True, "padding"), # calls ScaledMaskedSoftmax
(False, "padding"), # calls ScaledSoftmax
]
test_configs_attention_type = [
#"input_layernorm, attention_type, fuse_qkv_params"
......@@ -1269,7 +1265,6 @@ def test_export_multihead_attention(
model = te.MultiheadAttention(
*attention_args,
attn_mask_type=attn_mask_type,
params_dtype=precision,
return_layernorm_output=return_layernorm_output,
input_layernorm=input_layernorm,
......@@ -1278,8 +1273,8 @@ def test_export_multihead_attention(
return_bias=True,
).to(device='cuda')
inp_context = (hidden_states_context, attention_mask, encoder_output)
input_names = ["hidden_states", "attention_mask", "encoder_output"]
inp_context = (hidden_states_context, attention_mask, encoder_output, attn_mask_type)
input_names = ["hidden_states", "attention_mask", "encoder_output", "attn_mask_type"]
output_names=["attention_output", "attention_bias"]
do_export(model, inp_context, fname, use_fp8, input_names=input_names, output_names=output_names,
dynamic_axes={"hidden_states": {0: "seq", 1:"bs"},
......@@ -1347,13 +1342,13 @@ def test_export_transformer_layer(
num_attention_heads = 4
input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
input_names = ["input", "attention_mask"]
input_names = ["input", "attention_mask", "self_attn_mask_type"]
attention_mask = None
if use_mask and attn_mask_type != "causal":
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision)
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
inp = (input_tensor, attention_mask)
inp = (input_tensor, attention_mask, attn_mask_type)
fp8_str = "_fp8" if use_fp8 else ""
fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else ""
......@@ -1365,7 +1360,6 @@ def test_export_transformer_layer(
hidden_size,
ffn_hidden_size,
num_attention_heads,
self_attn_mask_type=attn_mask_type,
output_layernorm=output_layernorm,
params_dtype=precision,
fuse_qkv_params=fuse_qkv_params,
......@@ -1547,17 +1541,16 @@ def test_export_gpt_generation(
hidden_size,
ffn_hidden_size,
num_attention_heads,
self_attn_mask_type=attn_mask_type,
output_layernorm=output_layernorm,
params_dtype=precision,
fuse_qkv_params=fuse_qkv_params,
zero_centered_gamma=zero_centered_gamma).to(device='cuda')
# "Context phase": use full input sequence length
input_names = ["input"]
input_names = ["input", "attention_mask", "self_attn_mask_type"]
output_names = ["output"]
input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
inp = (input_tensor,)
inp = (input_tensor, None, attn_mask_type)
do_export(model, inp, fname, use_fp8,
input_names=input_names, output_names=output_names,
dynamic_axes={"input": {0: "seq", 1:"bs"},
......
......@@ -176,7 +176,7 @@ def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad):
use_fp8 = fp8_recipe is not None
with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
te_out = block(te_inp_hidden_states, te_inp_attn_mask)
te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
loss = te_out.sum()
loss.backward()
......@@ -217,7 +217,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, bs, dtype, config, fp8_
use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
te_out = block(te_inp_hidden_states, te_inp_attn_mask)
te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
......@@ -253,7 +253,7 @@ def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad):
use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
te_out = block(te_inp_hidden_states, te_inp_attn_mask)
te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
......@@ -282,7 +282,9 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad):
use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
te_out = block(
te_inp_hidden_states, te_inp_attn_mask, encoder_output=te_inp_hidden_states
te_inp_hidden_states,
attention_mask=te_inp_attn_mask,
encoder_output=te_inp_hidden_states
)
loss = te_out.sum()
loss.backward()
......
......@@ -196,23 +196,15 @@ class UnfusedDotProductAttention(torch.nn.Module):
norm_factor: float,
attention_dropout: float = 0.0,
attention_dropout_ctx: Optional[Callable] = nullcontext,
attn_mask_type: str = "causal",
layer_number: Optional[int] = None,
) -> None:
super().__init__()
assert (
attn_mask_type in AttnMaskTypes
), f"attn_mask_type {attn_mask_type} not supported"
self.norm_factor = norm_factor
self.attention_dropout_ctx = attention_dropout_ctx
self.layer_number = layer_number
self.scale_mask_softmax = FusedScaleMaskSoftmax(
attn_mask_type,
attention_mask_func,
)
self.scale_mask_softmax = FusedScaleMaskSoftmax(attention_mask_func)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
......@@ -228,11 +220,17 @@ class UnfusedDotProductAttention(torch.nn.Module):
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attn_mask_type: str = "causal",
attention_mask: Optional[torch.Tensor] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""core attention fprop"""
assert (
attn_mask_type in AttnMaskTypes
), f"attn_mask_type {attn_mask_type} not supported"
batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16
......@@ -321,7 +319,8 @@ class UnfusedDotProductAttention(torch.nn.Module):
# attention scores and attention mask [b, np, sq, sk]
softmax_scale = self.layer_number if apply_qk_layer_scaling else None
attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, softmax_scale)
attention_probs = self.scale_mask_softmax(
attention_scores, attention_mask, attn_mask_type, softmax_scale)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
......@@ -464,7 +463,6 @@ class FlashAttention(torch.nn.Module):
norm_factor: float,
attention_dropout: float = 0.0,
attention_dropout_ctx: Optional[Callable] = nullcontext,
attn_mask_type: str = "causal",
deterministic: bool = False,
) -> None:
super().__init__()
......@@ -473,7 +471,6 @@ class FlashAttention(torch.nn.Module):
_flash_attn_version >= _flash_attn_version_required
), f"FlashAttention minimum version {_flash_attn_version_required} is required."
self.attn_causal_mask = attn_mask_type == "causal"
self.norm_factor = norm_factor
self.attention_dropout_ctx = attention_dropout_ctx
self.attention_dropout = attention_dropout
......@@ -484,6 +481,7 @@ class FlashAttention(torch.nn.Module):
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attn_mask_type: str = "causal",
) -> torch.Tensor:
"""flash-attn fprop"""
......@@ -531,7 +529,7 @@ class FlashAttention(torch.nn.Module):
output = flash_attn_forward_func(
query_layer, key_layer, value_layer, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
self.attention_dropout if self.training else 0.0,
softmax_scale=1.0/self.norm_factor, causal=self.attn_causal_mask,
softmax_scale=1.0/self.norm_factor, causal=attn_mask_type=="causal",
**fa_optional_forward_kwargs
)
......@@ -703,7 +701,6 @@ class FusedAttention(torch.nn.Module):
norm_factor: float,
attention_dropout: float = 0.0,
attention_dropout_ctx: Optional[Callable] = nullcontext,
attn_mask_type: str = "causal",
attention_type: str = "self",
) -> None:
super().__init__()
......@@ -711,7 +708,6 @@ class FusedAttention(torch.nn.Module):
self.norm_factor = norm_factor
self.attention_dropout = attention_dropout
self.attention_dropout_ctx = attention_dropout_ctx
self.attn_mask_type = attn_mask_type
self.attention_type = attention_type
self.use_FAv2_bwd = (os.getenv("NVTE_FUSED_ATTN_USE_FAv2_BWD", "1") == "1"
and _flash_attn_2_available
......@@ -722,6 +718,7 @@ class FusedAttention(torch.nn.Module):
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attn_mask_type: str = "causal",
fused_attention_backend:
tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
core_attention_bias_type: str = "no_bias",
......@@ -797,7 +794,7 @@ class FusedAttention(torch.nn.Module):
fast_zero_fill,
qkv_layout,
core_attention_bias_type,
self.attn_mask_type,
attn_mask_type,
None, # rng_gen
fused_attention_backend,
use_FAv2_bwd
......@@ -858,7 +855,7 @@ class FusedAttention(torch.nn.Module):
fast_zero_fill,
qkv_layout,
core_attention_bias_type,
self.attn_mask_type,
attn_mask_type,
None, # rng_gen
fused_attention_backend,
use_FAv2_bwd
......@@ -886,6 +883,11 @@ class DotProductAttention(torch.nn.Module):
and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
to disable`flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
.. warning::
Argument :attr:`attn_mask_type` has been moved to the `forward` method and
is deprecated. It will be fully removed in future releases.
Parameters
----------
num_attention_heads : int
......@@ -902,8 +904,6 @@ class DotProductAttention(torch.nn.Module):
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
attention_dropout: float, default = 0.0
dropout probability for the dropout op during multi-head attention.
attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
type of attention mask passed into softmax operation.
layer_number: int, default = `None`
layer number of the current `DotProductAttention` when multiple such modules
are concatenated, for instance in consecutive transformer blocks.
......@@ -924,7 +924,7 @@ class DotProductAttention(torch.nn.Module):
kv_channels: int,
num_gqa_groups: Optional[int] = None,
attention_dropout: float = 0.0,
attn_mask_type: str = "causal",
attn_mask_type: Optional[str] = None,
sequence_parallel: bool = False,
tp_size: int = 1,
get_rng_state_tracker: Optional[Callable] = None,
......@@ -934,6 +934,14 @@ class DotProductAttention(torch.nn.Module):
) -> None:
super().__init__()
if attn_mask_type is not None:
warnings.warn(
"Argument :attr:`attn_mask_type` has been moved to the `forward` method and"
"is deprecated. It will be fully removed in future releases.",
category=DeprecationWarning,
)
self.attn_mask_type = attn_mask_type
self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
self.tp_group = tp_group
self.get_rng_state_tracker = get_rng_state_tracker
......@@ -978,10 +986,8 @@ class DotProductAttention(torch.nn.Module):
attn_kwargs = {
"attention_dropout": attention_dropout,
"attention_dropout_ctx": attention_dropout_ctx,
"attn_mask_type": attn_mask_type,
}
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
self.attention_dropout = attention_dropout
if self.use_flash_attention:
......@@ -1025,6 +1031,7 @@ class DotProductAttention(torch.nn.Module):
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
attn_mask_type: str = "causal",
checkpoint_core_attention: bool = False,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
......@@ -1067,6 +1074,8 @@ class DotProductAttention(torch.nn.Module):
Value tensor.
attention_mask : Optional[torch.Tensor], default = `None`
Boolean tensor used to mask out softmax input when not using flash-attn.
attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
type of attention mask passed into softmax operation.
checkpoint_core_attention : bool, default = `False`
If true, forward activations for attention are recomputed
during the backward pass in order to save memory that would
......@@ -1080,6 +1089,15 @@ class DotProductAttention(torch.nn.Module):
Whether to use the fast path to set output tensors to 0 or not.
"""
if self.attn_mask_type is not None:
warnings.warn(
"Argument :attr:`attn_mask_type` has been moved to the `forward` method and"
"is deprecated. It will be fully removed in future releases.",
category=DeprecationWarning,
)
# Keep previous functionality for current users.
attn_mask_type = self.attn_mask_type
assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
and value_layer.shape[-2] == self.num_gqa_groups_per_partition
), f"Keys and values must have {self.num_gqa_groups} heads!"
......@@ -1102,7 +1120,7 @@ class DotProductAttention(torch.nn.Module):
if not _flash_attn_2_available and self.num_gqa_groups != self.num_attention_heads:
use_flash_attention = False
if self.attn_mask_type == "padding" and attention_mask is not None:
if attn_mask_type == "padding" and attention_mask is not None:
use_flash_attention = False
use_fused_attention = False
......@@ -1121,7 +1139,7 @@ class DotProductAttention(torch.nn.Module):
TE_DType[key_layer.dtype],
QKVLayout[qkv_layout],
AttnBiasType[core_attention_bias_type],
AttnMaskType[self.attn_mask_type],
AttnMaskType[attn_mask_type],
self.attention_dropout,
query_layer.shape[0], key_layer.shape[0],
query_layer.shape[-1])
......@@ -1144,8 +1162,10 @@ class DotProductAttention(torch.nn.Module):
return self._checkpointed_attention_forward(self.flash_attention,
query_layer,
key_layer,
value_layer)
return self.flash_attention(query_layer, key_layer, value_layer)
value_layer,
attn_mask_type=attn_mask_type)
return self.flash_attention(
query_layer, key_layer, value_layer, attn_mask_type=attn_mask_type)
if use_fused_attention:
if checkpoint_core_attention:
......@@ -1153,15 +1173,17 @@ class DotProductAttention(torch.nn.Module):
query_layer,
key_layer,
value_layer,
fused_attention_backend = fused_attention_backend,
core_attention_bias_type = core_attention_bias_type,
core_attention_bias = core_attention_bias,
fast_zero_fill = fast_zero_fill)
attn_mask_type=attn_mask_type,
fused_attention_backend=fused_attention_backend,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
fast_zero_fill=fast_zero_fill)
return self.fused_attention(query_layer, key_layer, value_layer,
fused_attention_backend = fused_attention_backend,
core_attention_bias_type = core_attention_bias_type,
core_attention_bias = core_attention_bias,
fast_zero_fill = fast_zero_fill)
attn_mask_type=attn_mask_type,
fused_attention_backend=fused_attention_backend,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
fast_zero_fill=fast_zero_fill)
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
......@@ -1169,16 +1191,18 @@ class DotProductAttention(torch.nn.Module):
query_layer,
key_layer,
value_layer,
attention_mask = attention_mask,
core_attention_bias_type = core_attention_bias_type,
core_attention_bias = core_attention_bias,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
)
return self.unfused_attention(query_layer,
key_layer,
value_layer,
attention_mask = attention_mask,
core_attention_bias_type = core_attention_bias_type,
core_attention_bias = core_attention_bias,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
)
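The old constructor argument keeps working during the deprecation window: as the hunk above shows, a stored attn_mask_type overrides the forward default and emits a DeprecationWarning. A hedged sketch of that back-compat path (tensor shapes assumed, not from this commit):

import torch
from transformer_engine.pytorch import DotProductAttention

q = torch.randn(128, 2, 16, 64, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
pad_mask = torch.zeros(2, 1, 128, 128, dtype=torch.bool, device="cuda")

# Still accepted, but warns; the stored "padding" overrides the
# forward default of "causal".
attn = DotProductAttention(16, 64, attn_mask_type="padding")
out = attn(q, k, v, attention_mask=pad_mask)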
......@@ -1190,7 +1214,12 @@ class MultiheadAttention(torch.nn.Module):
.. note::
Argument :attr:`attention_mask` will be ignored in the `forward` call when
:attr:`self_attn_mask_type` is set to `"causal"`.
:attr:`attn_mask_type` is set to `"causal"`.
.. warning::
Argument :attr:`attn_mask_type` has been moved to the `forward` method and
is deprecated. It will be fully removed in future releases.
Parameters
----------
......@@ -1217,8 +1246,6 @@ class MultiheadAttention(torch.nn.Module):
layer_number: int, default = `None`
layer number of the current `TransformerLayer` when multiple such modules are
concatenated to form a transformer block.
attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
type of attention mask passed into softmax operation.
num_gqa_groups : int, default = `None`
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
......@@ -1309,7 +1336,7 @@ class MultiheadAttention(torch.nn.Module):
init_method: Optional[Callable] = None,
output_layer_init_method: Optional[Callable] = None,
layer_number: Optional[int] = None,
attn_mask_type: str = "causal",
attn_mask_type: Optional[str] = None,
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
num_gqa_groups: Optional[int] = None,
......@@ -1334,6 +1361,15 @@ class MultiheadAttention(torch.nn.Module):
device: Union[torch.device, str] = "cuda",
) -> None:
super().__init__()
if attn_mask_type is not None:
warnings.warn(
"Argument :attr:`attn_mask_type` has been moved to the `forward` method and"
"is deprecated. It will be fully removed in future releases.",
category=DeprecationWarning,
)
self.attn_mask_type = attn_mask_type
self.layer_number = layer_number
self.input_layernorm = input_layernorm
self.attention_type = attention_type
......@@ -1341,7 +1377,6 @@ class MultiheadAttention(torch.nn.Module):
self.tp_group = tp_group
self.return_layernorm_output = return_layernorm_output
self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.attn_mask_type = attn_mask_type
self.num_attention_heads = num_attention_heads
self.return_bias = return_bias
......@@ -1467,7 +1502,6 @@ class MultiheadAttention(torch.nn.Module):
attention_dropout=attention_dropout,
tp_size=tp_size,
get_rng_state_tracker=get_rng_state_tracker,
attn_mask_type=attn_mask_type,
sequence_parallel=sequence_parallel,
tp_group=tp_group,
layer_number=self.layer_number,
......@@ -1508,6 +1542,7 @@ class MultiheadAttention(torch.nn.Module):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_output: Optional[torch.Tensor] = None,
attn_mask_type: str = "causal",
is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: bool = False,
inference_params: Optional[Any] = None,
......@@ -1521,7 +1556,7 @@ class MultiheadAttention(torch.nn.Module):
.. note::
Argument :attr:`attention_mask` will be ignored when :attr:`self_attn_mask_type`
Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type`
is set to `"causal"`.
Parameters
......@@ -1530,6 +1565,8 @@ class MultiheadAttention(torch.nn.Module):
Input tensor.
attention_mask : Optional[torch.Tensor], default = `None`
Boolean tensor used to mask out self-attention softmax input.
attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
type of attention mask passed into softmax operation.
encoder_output : Optional[torch.Tensor], default = `None`
Output of the encoder block to be fed into the decoder block if using
`layer_type="decoder"`.
......@@ -1563,7 +1600,16 @@ class MultiheadAttention(torch.nn.Module):
"""
# hidden_states: [sq, b, h]
if self.attn_mask_type == "padding" and attention_mask is not None:
if self.attn_mask_type is not None:
warnings.warn(
"Argument :attr:`attn_mask_type` has been moved to the `forward` method and"
"is deprecated. It will be fully removed in future releases.",
category=DeprecationWarning,
)
# Keep previous functionality for current users.
attn_mask_type = self.attn_mask_type
if attn_mask_type == "padding" and attention_mask is not None:
assert (
attention_mask.dtype == torch.bool
), "Attention mask must be a boolean tensor"
......@@ -1768,7 +1814,8 @@ class MultiheadAttention(torch.nn.Module):
query_layer,
key_layer,
value_layer,
attention_mask,
attention_mask=attention_mask,
attn_mask_type=attn_mask_type,
checkpoint_core_attention=checkpoint_core_attention,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
......
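One level up, MultiheadAttention now takes attn_mask_type in forward and threads it through to the core attention, as shown above. A minimal sketch, assuming the (hidden_size, num_attention_heads) constructor of transformer_engine.pytorch.MultiheadAttention; shapes are illustrative:

import torch
from transformer_engine.pytorch import MultiheadAttention

mha = MultiheadAttention(1024, 16, params_dtype=torch.float16).cuda()
x = torch.randn(128, 2, 1024, dtype=torch.float16, device="cuda")  # [seq, batch, hidden]

out_causal = mha(x, attn_mask_type="causal")  # attention_mask is ignored for causal
pad_mask = torch.zeros(2, 1, 128, 128, dtype=torch.bool, device="cuda")
out_padded = mha(x, attention_mask=pad_mask, attn_mask_type="padding")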
......@@ -215,19 +215,16 @@ class FusedScaleMaskSoftmax(nn.Module):
fused operation: scaling + mask + softmax
Arguments:
attn_mask_type: attention mask type (pad or causal)
mask_func: mask function to be applied.
softmax_in_fp32: if true, softmax in performed at fp32 precision.
"""
def __init__(
self,
attn_mask_type: str,
mask_func: Callable,
softmax_in_fp32: bool = True,
) -> None:
super().__init__()
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = bool(
int(os.getenv("NVTE_MASKED_SOFTMAX_FUSION", "1"))
)
......@@ -249,6 +246,7 @@ class FusedScaleMaskSoftmax(nn.Module):
self,
inp: torch.Tensor,
mask: torch.Tensor,
attn_mask_type: str,
scale: Optional[float] = None,
) -> torch.Tensor:
"""FusedScaleMaskSoftmax fprop"""
......@@ -257,6 +255,7 @@ class FusedScaleMaskSoftmax(nn.Module):
self.input_in_fp16 = inp.dtype == torch.float16
self.input_in_bf16 = inp.dtype == torch.bfloat16
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
assert (
scale is None or self.softmax_in_fp32
......
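Internally, FusedScaleMaskSoftmax loses its attn_mask_type constructor argument and instead receives the mask type as the third positional argument of forward, matching the call sites patched above. A hedged sketch, assuming the te.softmax / te.utils paths used in the ONNX tests:

import torch
import transformer_engine.pytorch as te

softmax = te.softmax.FusedScaleMaskSoftmax(
    mask_func=te.utils.attention_mask_func,
    softmax_in_fp32=True,
)
scores = torch.randn(2, 16, 128, 128, dtype=torch.float16, device="cuda")  # [b, np, sq, sk]

probs_causal = softmax(scores, None, "causal", None)  # (inp, mask, attn_mask_type, scale)
mask = torch.zeros(2, 1, 128, 128, dtype=torch.bool, device="cuda")
probs_padded = softmax(scores, mask, "padding", None)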
......@@ -73,10 +73,10 @@ class TransformerLayer(torch.nn.Module):
Arguments :attr:`attention_softmax_in_fp32` and :attr:`apply_query_key_layer_scaling`
are deprecated and will be fully removed in future releases.
.. note::
.. warning::
Argument :attr:`attention_mask` will be ignored in the `forward` call when
:attr:`self_attn_mask_type` is set to `"causal"`.
Argument :attr:`self_attn_mask_type` has been moved to the `forward` method and
is deprecated. It will be fully removed in future releases.
Parameters
----------
......@@ -127,8 +127,6 @@ class TransformerLayer(torch.nn.Module):
kv_channels: int, default = `None`
number of key-value channels. defaults to
:attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
self_attn_mask_type: {'causal', 'padding'}, default = `causal`
type of attention mask passed into softmax operation.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
......@@ -214,7 +212,7 @@ class TransformerLayer(torch.nn.Module):
output_layer_init_method: Optional[Callable] = None,
layer_number: Optional[int] = None,
kv_channels: Optional[int] = None,
self_attn_mask_type: str = "causal",
self_attn_mask_type: Optional[str] = None,
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
params_dtype: Optional[torch.dtype] = None,
......@@ -241,6 +239,13 @@ class TransformerLayer(torch.nn.Module):
) -> None:
super().__init__()
if self_attn_mask_type is not None:
warnings.warn(
"Argument :attr:`self_attn_mask_type` has been moved to the `forward` method and"
"is deprecated. It will be fully removed in future releases.",
category=DeprecationWarning,
)
warnings.warn(
"Arguments `attention_softmax_in_fp32` and `apply_query_key_layer_scaling`"
"are deprecated and will be fully removed in future releases.",
......@@ -252,6 +257,7 @@ class TransformerLayer(torch.nn.Module):
tex.userbuf_comm_available()
), "Userbuffer communication backend not available."
self.self_attn_mask_type = self_attn_mask_type
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
ub_tp_comm_overlap = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_OVERLAP", "1")))
ub_bulk_wgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_WGRAD", "1")))
......@@ -265,10 +271,7 @@ class TransformerLayer(torch.nn.Module):
self.apply_residual_connection_post_layernorm = (
apply_residual_connection_post_layernorm
)
self.self_attn_mask_type = self_attn_mask_type
assert (
self_attn_mask_type in AttnMaskTypes
), f"self_attn_mask_type {self_attn_mask_type} not supported"
assert layer_type in LayerTypes, f"layer_type {layer_type} not supported"
if not fuse_qkv_params:
......@@ -326,7 +329,6 @@ class TransformerLayer(torch.nn.Module):
self.self_attention = MultiheadAttention(
*attention_args,
**common_attention_kwargs,
attn_mask_type=self_attn_mask_type,
input_layernorm=not output_layernorm,
attention_type="self",
bias=bias,
......@@ -429,6 +431,7 @@ class TransformerLayer(torch.nn.Module):
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
self_attn_mask_type: str = "causal",
encoder_output: Optional[torch.Tensor] = None,
enc_dec_attn_mask: Optional[torch.Tensor] = None,
is_first_microbatch: Optional[bool] = None,
......@@ -453,6 +456,8 @@ class TransformerLayer(torch.nn.Module):
Input tensor.
attention_mask : Optional[torch.Tensor], default = `None`
Boolean tensor used to mask out self-attention softmax input.
self_attn_mask_type: {'causal', 'padding'}, default = `causal`
type of attention mask passed into softmax operation.
encoder_output : Optional[torch.Tensor], default = `None`
Output of the encoder block to be fed into the decoder block if using
`layer_type="decoder"`.
......@@ -488,6 +493,19 @@ class TransformerLayer(torch.nn.Module):
Whether to set output tensors to 0 or not before use.
"""
if self.self_attn_mask_type is not None:
warnings.warn(
"Argument :attr:`self_attn_mask_type` has been moved to the `forward` method and"
"is deprecated. It will be fully removed in future releases.",
category=DeprecationWarning,
)
# Keep previous functionality for current users.
self_attn_mask_type = self.self_attn_mask_type
assert (
self_attn_mask_type in AttnMaskTypes
), f"self_attn_mask_type {self_attn_mask_type} not supported"
hidden_states = hidden_states.contiguous()
if self.sequence_parallel and self.seq_length is not None:
......@@ -495,7 +513,7 @@ class TransformerLayer(torch.nn.Module):
hidden_states.shape[0] == self.seq_length // self.tp_size
), "Sequence dimension must be split across TP group when using sequence parallel."
if self.self_attn_mask_type != "causal" and attention_mask is not None:
if self_attn_mask_type != "causal" and attention_mask is not None:
assert (
attention_mask.dtype == torch.bool
), "Attention mask must be a boolean tensor"
......@@ -509,7 +527,8 @@ class TransformerLayer(torch.nn.Module):
# Self attention.
self_attention_outputs = self.self_attention(
hidden_states,
attention_mask,
attention_mask=attention_mask,
attn_mask_type=self_attn_mask_type,
inference_params=inference_params,
is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention,
......@@ -556,7 +575,8 @@ class TransformerLayer(torch.nn.Module):
if self.layer_type == "decoder":
inter_attention_outputs = self.inter_attention(
bda_output,
enc_dec_attn_mask,
attention_mask=enc_dec_attn_mask,
attn_mask_type=self_attn_mask_type,
encoder_output=encoder_output,
is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention,
......
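At the top level, TransformerLayer accepts self_attn_mask_type per forward call (and, for decoder layers, forwards it to inter-attention as well), keeping the constructor argument only for the deprecation path. A minimal sketch, assuming transformer_engine.pytorch.TransformerLayer; sizes are illustrative:

import torch
from transformer_engine.pytorch import TransformerLayer

layer = TransformerLayer(1024, 4096, 16, params_dtype=torch.float16).cuda()
x = torch.randn(128, 2, 1024, dtype=torch.float16, device="cuda", requires_grad=True)

out = layer(x, self_attn_mask_type="causal")  # mask type chosen at call time
out.sum().backward()

pad_mask = torch.zeros(2, 1, 128, 128, dtype=torch.bool, device="cuda")
out = layer(x, attention_mask=pad_mask, self_attn_mask_type="padding")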