Unverified commit 67bc399d authored by Charlene Yang, committed by GitHub

[PyTorch/Jax] Fix attention mask definition and sliding window for decoder (#818)



* fix inconsistency for attn mask; now True means participating in attn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix sliding window window_size for decoder+padding combination
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert paddle changes regarding mask
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert softmax to 1-mask; 0-keep
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* enforce 1-mask out; 0-keep rule for jax masks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix jax lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert pytorch mask changes; some kept in tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert to jax fused attn on main
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* invert mask logic for get_cu_seqlens/_and_indices in the PyTorch implementation and mask generation in unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* temporarily disable update_weight_scale_inv
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* enforce window_size for decoder
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add docstring for mask definition: 1 - mask out; 0 - keep
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add aux_ctx_tensors to save_for_backward
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak make_decoder_mask and make_mask in jax tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* skip dBias for shapes other than 1HSS; otherwise dq/dk/dv NaNs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* expand attn_biases from list to variables in save_for_backward
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix use of variable before assignment in jax dact_lu
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove window size definition for decoder
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add change notes in README for padding mask in PyTorch
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak padding mask notes in README
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* expand list to tensors for save_for_backward
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
parent 430d5d5a
......@@ -184,10 +184,35 @@ Compiling with FlashAttention-2
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Transformer Engine release v0.11.0 adds support for FlashAttention-2 in PyTorch for improved performance.
It is a known issue that FlashAttention-2 compilation is resource-intensive and requires a large amount of RAM (see `bug <https://github.com/Dao-AILab/flash-attention/issues/358>`_), which may lead to out of memory errors during the installation of Transformer Engine. Please try setting **MAX_JOBS=1** in the environment to circumvent the issue. If the errors persist, install a supported version of FlashAttention-1 (v1.0.6 to v1.0.9).
It is a known issue that FlashAttention-2 compilation is resource-intensive and requires a large amount of RAM (see `bug <https://github.com/Dao-AILab/flash-attention/issues/358>`_), which may lead to out of memory errors during the installation of Transformer Engine. Please try setting **MAX_JOBS=1** in the environment to circumvent the issue.
Note that NGC PyTorch 23.08+ containers include FlashAttention-2.
Breaking Changes
================
v1.7: Padding mask definition for PyTorch
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
In an effort to unify the definition and usage of the attention mask across all three frameworks in Transformer Engine, the padding mask definition in our PyTorch implementation has changed: `True` now means masking out the corresponding position, where it previously meant including that position in attention. Since v1.7, all attention mask types follow the same definition: `True` means masking out the corresponding position and `False` means including that position in the attention calculation.
An example of this change:
.. code-block:: bash
# for a batch of 3 sequences where `a`s, `b`s and `c`s are the useful tokens
# and `0`s are the padding tokens,
[a, a, a, 0, 0,
b, b, 0, 0, 0,
c, c, c, c, 0]
# the padding mask for this batch before v1.7 is,
[ True, True, True, False, False,
True, True, False, False, False,
True, True, True, True, False]
# and for v1.7 onwards it should be,
[False, False, False, True, True,
False, False, True, True, True,
False, False, False, False, True]
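For reference, a minimal sketch of building such a mask in PyTorch from per-sequence lengths (the names `seq_lens` and `max_seqlen` are illustrative, not part of any API):

.. code-block:: python

    import torch

    # hypothetical per-sequence lengths for the batch above, padded to length 5
    seq_lens = torch.tensor([3, 2, 4])
    max_seqlen = 5

    # positions at or beyond the sequence length are padding and must be masked out
    positions = torch.arange(max_seqlen).unsqueeze(0)        # [1, max_seqlen]
    padding_mask = positions >= seq_lens.unsqueeze(1)        # True = mask out (v1.7+)

    # reshape to [batch_size, 1, 1, max_seqlen], the layout expected for 'padding' masks
    padding_mask = padding_mask.view(-1, 1, 1, max_seqlen)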
FP8 Convergence
===============
......
......@@ -66,7 +66,7 @@ def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: Array
if mask is not None:
if mask.ndim != logits.ndim:
mask = jnp.expand_dims(mask, axis=-3)
logits = jnp.where(mask, logits, jnp.finfo(dtype).min)
logits = jnp.where(mask, jnp.finfo(dtype).min, logits)
softmax_out = jax.nn.softmax(logits).astype(dtype)
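As a toy illustration of this convention (the values are made up; only the `jnp.where` pattern mirrors the code above), a `True` entry pushes the corresponding logit to the dtype minimum so it contributes essentially zero probability after the softmax:

.. code-block:: python

    import jax
    import jax.numpy as jnp

    logits = jnp.array([1.0, 2.0, 3.0])
    mask = jnp.array([False, False, True])     # True = mask out the last position

    masked_logits = jnp.where(mask, jnp.finfo(jnp.float32).min, logits)
    probs = jax.nn.softmax(masked_logits)      # last entry is ~0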
......@@ -90,24 +90,34 @@ def is_causal_mask(mask: AttnMaskType):
def make_decoder_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
"""
Create padded causal mask
Create inverse padded causal mask where `True` means allowing the corresponding
position to participate in attention and `False` means masking out that position.
"""
q_idxs = jnp.broadcast_to(jnp.arange(q_tokens.shape[-1], dtype=jnp.int32), q_tokens.shape)
kv_idxs = jnp.broadcast_to(jnp.arange(kv_tokens.shape[-1], dtype=jnp.int32), kv_tokens.shape)
causal_mask = make_attention_mask(q_idxs, kv_idxs, jnp.greater_equal)
padding_mask = make_attention_mask(q_tokens > 0, kv_tokens > 0)
return combine_masks(causal_mask, padding_mask)
inv_causal_mask = make_attention_mask(q_idxs, kv_idxs, jnp.greater_equal)
inv_padding_mask = make_attention_mask(q_tokens > 0, kv_tokens > 0)
return combine_masks(inv_causal_mask, inv_padding_mask)
def make_mask(q_token: ArrayLike, kv_token: ArrayLike, attn_mask_type: AttnMaskType) -> Array:
"""
Create attention mask based on mask type. A `True` value in the mask means
masking out the corresponding position and a `False` value means allowing
that position to participate in attention.
"""
if is_causal_mask(attn_mask_type):
inv_mask = make_decoder_mask(q_token, kv_token)
else:
inv_mask = make_attention_mask(q_token > 0, kv_token > 0)
mask = jnp.logical_not(inv_mask)
return mask
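A rough usage sketch of these helpers (a hypothetical toy batch; the flax utilities `make_attention_mask` and `combine_masks` are assumed to be the ones imported by this test file):

.. code-block:: python

    import jax.numpy as jnp
    from flax.linen import combine_masks, make_attention_mask

    # toy decoder input: three real tokens followed by two padding tokens (id 0)
    tokens = jnp.array([[5, 8, 2, 0, 0]])
    idxs = jnp.broadcast_to(jnp.arange(tokens.shape[-1]), tokens.shape)

    inv_causal = make_attention_mask(idxs, idxs, jnp.greater_equal)   # 1 = keep
    inv_padding = make_attention_mask(tokens > 0, tokens > 0)         # 1 = keep
    mask = jnp.logical_not(combine_masks(inv_causal, inv_padding))    # True = mask out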
def jax_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
"""
JAX native dot product attention implementation
"""
attn_mask_type = kwargs['attn_mask_type']
if is_causal_mask(attn_mask_type):
mask = make_decoder_mask(q_token, kv_token)
else:
mask = make_attention_mask(q_token > 0, kv_token > 0)
mask = make_mask(q_token, kv_token, attn_mask_type)
output = general_dot_product_attention(query,
key,
......@@ -127,13 +137,7 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng
TE customcall dot product attention implementation
"""
attn_mask_type = kwargs['attn_mask_type']
if is_causal_mask(attn_mask_type):
mask = make_decoder_mask(q_token, kv_token)
else:
mask = make_attention_mask(q_token > 0, kv_token > 0)
# mask invert
mask = jnp.logical_not(mask)
mask = make_mask(q_token, kv_token, attn_mask_type)
qkv_layout = kwargs.pop('qkv_layout')
match qkv_layout:
......@@ -298,6 +302,8 @@ class FusedAttnRunner:
"""
self._setup_inputs()
if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape != BiasShape.BIAS_1HSS:
pytest.skip("Bias gradient calculation is only supported for 1HSS bias shape.")
def grad_func(func, *args, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the gradient
......
......@@ -635,7 +635,7 @@ class MultiHeadAttention(nn.Module):
# position should only attend to those key positions that have already
# been generated and cached, not the remaining zero elements.
mask = combine_masks(
mask,
jnp.logical_not(mask),
jnp.broadcast_to(
jnp.arange(length) <= cur_index,
# (1, 1, length) represent (head dim, query length, key length)
......
......@@ -544,7 +544,7 @@ def _run_dot_product_attention(
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
for i in range(config.batch_size):
attention_mask_q = torch.cat([attention_mask_q,
torch.Tensor([True]*seqlens_q[i] + [False]*(config.max_seqlen_q-seqlens_q[i]))
torch.Tensor([False]*seqlens_q[i] + [True]*(config.max_seqlen_q-seqlens_q[i]))
.to(dtype=torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
attention_mask = attention_mask_q.to(device="cuda")
if config.attn_type == 'cross':
......@@ -552,19 +552,18 @@ def _run_dot_product_attention(
attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool)
for i in range(config.batch_size):
attention_mask_q = torch.cat([attention_mask_q,
torch.Tensor([True]*seqlens_q[i] + [False]*(config.max_seqlen_q-seqlens_q[i]))
torch.Tensor([False]*seqlens_q[i] + [True]*(config.max_seqlen_q-seqlens_q[i]))
.to(dtype=torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
attention_mask_kv = torch.cat([attention_mask_kv, torch.Tensor(
[True]*seqlens_kv[i] + [False]*(config.max_seqlen_kv-seqlens_kv[i]))
[False]*seqlens_kv[i] + [True]*(config.max_seqlen_kv-seqlens_kv[i]))
.to(dtype=torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
attention_mask = (
attention_mask_q.to(device="cuda"), attention_mask_kv.to(device="cuda"))
window_size = None
if swa:
window_size, attention_mask = get_swa(config.max_seqlen_q, config.max_seqlen_kv)
elif "causal" in config.attn_mask_type:
window_size, attention_mask = (-1, 0), None
else:
window_size, attention_mask = None, None
alibi_slopes = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
......@@ -858,7 +857,7 @@ def _run_transformer_layer(
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
for i in range(config.batch_size):
attention_mask_q = torch.cat([attention_mask_q,
torch.Tensor([True]*seqlens_q[i] + [False]*(config.max_seqlen_q-seqlens_q[i]))
torch.Tensor([False]*seqlens_q[i] + [True]*(config.max_seqlen_q-seqlens_q[i]))
.to(torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
attention_mask = attention_mask_q.to(device="cuda")
......@@ -944,7 +943,7 @@ def _run_transformer_layer(
model_configs_fp8_vs_f16 = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"fp8_9 ": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_9" : ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_10": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_11": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_12": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"),
......@@ -1143,24 +1142,49 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd):
dtype, config, False, qkv_layout)
tols = dict(atol=5e-1, rtol=5e-2)
rmse_tol = 0.1
bwd_names = ['dq', 'dk', 'dv']
fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16)
fwd_range = max(fused_attn_fwd_fp8.max().item(),
fused_attn_fwd_f16.max().item()) - min(fused_attn_fwd_fp8.min().item(),
fused_attn_fwd_f16.min().item())
if _NVTE_DEBUG:
print('[test_dpa_fp8_vs_f16]: ', tols)
print()
print('========== {:^25s} =========='.format('forward output'))
print('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format(
fused_attn_fwd_fp8.min().item(),fused_attn_fwd_fp8.max().item()))
print('fused_attn_fwd_f16 min {:.6f} max {:.6f}'.format(
fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()))
print('fused_attn_fwd RMSE: {:.6f}'.format(
_rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16)))
print('fused_attn_fwd RMSE: {:.6f}'.format(fwd_rmse))
try:
torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols)
except Exception as e:
print(e)
print()
assert(fwd_rmse < rmse_tol * fwd_range
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range)
for i,_ in enumerate(fused_attn_bwd_f16):
bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i])
bwd_range = max(fused_attn_bwd_fp8[i].max().item(),
fused_attn_bwd_f16[i].max().item()) - min(fused_attn_bwd_fp8[i].min().item(),
fused_attn_bwd_f16[i].min().item())
if _NVTE_DEBUG:
print('fused_attn_bwd_fp8 min {:.6f} max {:.6f}'.format(
print()
print('========== {:^25s} =========='.format(bwd_names[i]))
print('fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}'.format(i,
fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()))
print('fused_attn_bwd_f16 min {:.6f} max {:.6f}'.format(
print('fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}'.format(i,
fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()))
print('fused_attn_bwd RMSE: {:.6f}'.format(
_rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i])))
print('fused_attn_bwd RMSE[{}]: {:.6f}'.format(i, bwd_rmse))
try:
torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols)
except Exception as e:
print(e)
print()
assert(bwd_rmse < rmse_tol * bwd_range
), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range)
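The RMSE comparison above is range-relative; a sketch of the logic, assuming an `_rmse` helper equivalent to the one defined earlier in this test file:

.. code-block:: python

    import torch

    def _rmse(a, b):
        # assumed implementation of the test's RMSE helper
        return torch.sqrt(torch.mean((a.float() - b.float()) ** 2)).item()

    # the FP8 result is accepted if its RMSE against the F16 reference stays below
    # rmse_tol (10%) of the combined value range of the two tensors being compared:
    #     _rmse(fp8, f16) < rmse_tol * (max(fp8, f16) - min(fp8, f16))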
def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
......@@ -1231,7 +1255,7 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
layout = layout.replace('h', 'hg')
layout = layout.replace('t', 'tg')
tensor_shape = [dim_to_num[j] for j in layout.split('_')]
tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda")
tensor_count = 1
split_dim = 0
for dim, l in enumerate(layout.split('_')):
......@@ -1252,7 +1276,7 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
qkv_format_kv = qkv_format_kv.replace('s', 'sq')
out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split('_')]
out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
out_grad = 0.1 * torch.randn(out_grad_shape_new, dtype=dtype, device="cuda")
out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda")
with fp8_autocast(enabled=fp8_dpa, fp8_recipe=fp8_recipe):
out = dpa(inp[0], inp[1], inp[2],
......@@ -1359,7 +1383,7 @@ def _run_custom_mha_fp8(dtype, config, backend):
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
inp = 0.0001 * torch.randint(0, 100,
inp = 0.0001 * torch.randint(-100, 100,
(config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim),
dtype=dtype, device="cuda", requires_grad=True)
seqlens = torch.full([config.batch_size], config.max_seqlen_q,
......
......@@ -224,7 +224,7 @@ def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
the samples in a batch.
"""
mask = mask.squeeze(1).squeeze(1)
reduced_mask = mask.sum(dim=1)
reduced_mask = mask.logical_not().sum(dim=1)
cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
zero = torch.zeros(1, dtype=torch.int32, device="cuda")
cu_seqlens = torch.cat((zero, cu_seqlens))
......@@ -242,13 +242,13 @@ def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.
mask = mask.squeeze(1).squeeze(1)
bs, seqlen = mask.shape
reduced_mask = mask.sum(dim=1)
reduced_mask = mask.logical_not().sum(dim=1)
cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
zero = torch.zeros(1, dtype=torch.int32, device="cuda")
cu_seqlens = torch.cat((zero, cu_seqlens))
mask = mask.reshape(-1)
indices = mask.nonzero()
indices = mask.logical_not().nonzero()
indices = indices.unsqueeze(-1)
num_nonzeros = indices.shape[0]
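A minimal sketch of what this computes under the new mask convention (toy values; `True` marks padding):

.. code-block:: python

    import torch

    # toy padding mask in [batch, 1, 1, seqlen] layout; True = padding (mask out)
    mask = torch.tensor([[[[False, False, False, True, True]]],
                         [[[False, False, True, True, True]]]])

    mask = mask.squeeze(1).squeeze(1)             # [batch, seqlen]
    seqlens = mask.logical_not().sum(dim=1)       # real tokens per sequence: [3, 2]
    cu_seqlens = torch.cat((torch.zeros(1, dtype=torch.int32),
                            seqlens.cumsum(dim=0).to(torch.int32)))
    # cu_seqlens == tensor([0, 3, 5], dtype=torch.int32)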
......@@ -408,7 +408,7 @@ class PackTensors(torch.autograd.Function):
*tensors: Tuple[torch.Tensor, ...]
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported."
ctx.indices = indices
ctx.save_for_backward(indices)
ctx.dim0 = tensors[0].shape[0]
if len(tensors) == 1:
return pack_tensor(indices, *tensors)
......@@ -418,11 +418,12 @@ class PackTensors(torch.autograd.Function):
@staticmethod
def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]):
(indices,) = ctx.saved_tensors
if len(grad_outputs) == 1:
return None, unpack_tensor(ctx.indices, ctx.dim0, *grad_outputs)
return None, unpack_tensor(indices, ctx.dim0, *grad_outputs)
if len(grad_outputs) == 2:
return None, *unpack_2_tensors(ctx.indices, ctx.dim0, *grad_outputs)
return None, *unpack_3_tensors(ctx.indices, ctx.dim0, *grad_outputs)
return None, *unpack_2_tensors(indices, ctx.dim0, *grad_outputs)
return None, *unpack_3_tensors(indices, ctx.dim0, *grad_outputs)
class UnpackTensor(torch.autograd.Function):
......@@ -436,12 +437,13 @@ class UnpackTensor(torch.autograd.Function):
dim0: int,
tensor: torch.Tensor,
) -> torch.Tensor:
ctx.indices = indices
ctx.save_for_backward(indices)
return unpack_tensor(indices, dim0, tensor)
@staticmethod
def backward(ctx, grad_output):
return None, None, pack_tensor(ctx.indices, grad_output)
(indices,) = ctx.saved_tensors
return None, None, pack_tensor(indices, grad_output)
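The change above moves the index tensor from a plain `ctx` attribute into `ctx.save_for_backward`, so it goes through autograd's saved-tensor machinery. A toy function (not TE code) illustrating the same pattern:

.. code-block:: python

    import torch

    class GatherRows(torch.autograd.Function):
        """Toy example of the save_for_backward pattern used above."""

        @staticmethod
        def forward(ctx, indices, tensor):
            # tensors needed by backward go through save_for_backward ...
            ctx.save_for_backward(indices)
            # ... while plain Python values can be stored directly on ctx
            ctx.dim0 = tensor.shape[0]
            return tensor[indices]

        @staticmethod
        def backward(ctx, grad_output):
            (indices,) = ctx.saved_tensors
            grad = torch.zeros(ctx.dim0, *grad_output.shape[1:],
                               dtype=grad_output.dtype, device=grad_output.device)
            grad[indices] = grad_output
            return None, grad

    # out = GatherRows.apply(torch.tensor([0, 2]), torch.randn(4, 8, requires_grad=True))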
def flash_attn_p2p_communicate(rank, send_tensor, send_dst,
......@@ -868,8 +870,8 @@ class AttnFuncWithCP(torch.autograd.Function):
else:
out = out.view(-1, *out.shape[-2:])
ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k)
ctx.rng_states = rng_states
ctx.save_for_backward(q, kv, out, softmax_lse,
cu_seqlens_q, cu_seqlens_k, *rng_states, *attn_biases)
ctx.cp_group = cp_group
ctx.cp_global_ranks = cp_global_ranks
ctx.dropout_p = dropout_p
......@@ -880,16 +882,17 @@ class AttnFuncWithCP(torch.autograd.Function):
ctx.qkv_format = qkv_format
ctx.attn_bias_type = attn_bias_type
ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape
ctx.attn_biases = attn_biases
ctx.deterministic = deterministic
ctx.use_fused_attention = use_fused_attention
return out
@staticmethod
def backward(ctx, dout):
q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k) = ctx.saved_tensors[:6]
cp_size = get_distributed_world_size(ctx.cp_group)
rng_states = ctx.saved_tensors[6:6+cp_size]
attn_biases = ctx.saved_tensors[6+cp_size:6+cp_size*2]
rank = get_distributed_rank(ctx.cp_group)
send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size]
recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size]
......@@ -897,12 +900,12 @@ class AttnFuncWithCP(torch.autograd.Function):
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
if ctx.attn_biases[0] is not None:
if attn_biases[0] is not None:
# [b, np, sq, 2*cp, sk//(2*cp)]
attn_dbias = torch.zeros(
*ctx.attn_bias_shape,
dtype=ctx.attn_biases[0].dtype,
device=ctx.attn_biases[0].device
dtype=attn_biases[0].dtype,
device=attn_biases[0].device
)
# [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
attn_dbias_ = attn_dbias.view(
......@@ -985,9 +988,9 @@ class AttnFuncWithCP(torch.autograd.Function):
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
out_ = out.view(-1, *out.shape[-3:])
dout_ = dout.view(-1, *dout.shape[-3:])
aux_ctx_tensors = [softmax_lse, ctx.rng_states[cp_size-i-1]]
aux_ctx_tensors = [softmax_lse, rng_states[cp_size-i-1]]
if attn_dbias is not None:
aux_ctx_tensors += [ctx.attn_biases[cp_size-i-1]]
aux_ctx_tensors += [attn_biases[cp_size-i-1]]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_k,
cu_seqlens_q, cu_seqlens_k,
......@@ -1017,7 +1020,7 @@ class AttnFuncWithCP(torch.autograd.Function):
dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k,
ctx.max_seqlen_q, ctx.max_seqlen_k,
ctx.dropout_p, ctx.softmax_scale, True,
rng_state=ctx.rng_states[cp_size-i-1],
rng_state=rng_states[cp_size-i-1],
**fa_optional_backward_kwargs
)
elif i >= (cp_size-rank-1):
......@@ -1038,9 +1041,9 @@ class AttnFuncWithCP(torch.autograd.Function):
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
out_ = out.view(-1, *out.shape[-3:])
dout_ = dout.view(-1, *dout.shape[-3:])
aux_ctx_tensors = [softmax_lse, ctx.rng_states[cp_size-i-1]]
aux_ctx_tensors = [softmax_lse, rng_states[cp_size-i-1]]
if attn_dbias is not None:
aux_ctx_tensors += [ctx.attn_biases[cp_size-i-1]]
aux_ctx_tensors += [attn_biases[cp_size-i-1]]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_k//2,
cu_seqlens_q, cu_seqlens_k//2,
......@@ -1074,7 +1077,7 @@ class AttnFuncWithCP(torch.autograd.Function):
dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k//2,
ctx.max_seqlen_q, ctx.max_seqlen_k//2,
ctx.dropout_p, ctx.softmax_scale, False,
rng_state=ctx.rng_states[cp_size-i-1],
rng_state=rng_states[cp_size-i-1],
**fa_optional_backward_kwargs
)
else:
......@@ -1095,9 +1098,9 @@ class AttnFuncWithCP(torch.autograd.Function):
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
out_ = out[1].contiguous()
dout_ = dout[1].contiguous()
aux_ctx_tensors = [softmax_lse_, ctx.rng_states[cp_size-i-1]]
aux_ctx_tensors = [softmax_lse_, rng_states[cp_size-i-1]]
if attn_dbias is not None:
aux_ctx_tensors += [ctx.attn_biases[cp_size-i-1]]
aux_ctx_tensors += [attn_biases[cp_size-i-1]]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
ctx.max_seqlen_q//2, ctx.max_seqlen_k,
cu_seqlens_q//2, cu_seqlens_k,
......@@ -1135,14 +1138,14 @@ class AttnFuncWithCP(torch.autograd.Function):
dq_, dkv_[0], dkv_[1], cu_seqlens_q//2, cu_seqlens_k,
ctx.max_seqlen_q//2, ctx.max_seqlen_k,
ctx.dropout_p, ctx.softmax_scale, False,
rng_state=ctx.rng_states[cp_size-i-1],
rng_state=rng_states[cp_size-i-1],
**fa_optional_backward_kwargs
)
else:
if ctx.use_fused_attention:
aux_ctx_tensors = [softmax_lse, ctx.rng_states[cp_size-i-1]]
aux_ctx_tensors = [softmax_lse, rng_states[cp_size-i-1]]
if attn_dbias is not None:
aux_ctx_tensors += [ctx.attn_biases[cp_size-i-1]]
aux_ctx_tensors += [attn_biases[cp_size-i-1]]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_k,
cu_seqlens_q, cu_seqlens_k,
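Because `save_for_backward` takes individual tensors, the lists `rng_states` and `attn_biases` are unpacked into it in forward and recovered by slicing `ctx.saved_tensors` in backward. A minimal sketch of that list-handling pattern (toy code, not TE):

.. code-block:: python

    import torch

    class SaveListExample(torch.autograd.Function):
        """Toy sketch of saving a variable-length list of tensors."""

        @staticmethod
        def forward(ctx, x, extras):            # extras: a Python list of tensors
            ctx.num_extras = len(extras)
            # tensors kept in a plain list on ctx would bypass autograd's
            # saved-tensor handling, so unpack them into save_for_backward
            ctx.save_for_backward(x, *extras)
            return 2 * x

        @staticmethod
        def backward(ctx, grad_output):
            x = ctx.saved_tensors[0]
            extras = ctx.saved_tensors[1:1 + ctx.num_extras]   # recover by position
            assert len(extras) == ctx.num_extras and x is not None
            return 2 * grad_output, None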
......@@ -2300,9 +2303,8 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None)
ctx.save_for_backward(*qkvo_tensors, cu_seqlens, *fp8_tensors)
ctx.save_for_backward(*qkvo_tensors, cu_seqlens, *fp8_tensors, *aux_ctx_tensors)
ctx.fp8_meta = fp8_meta
ctx.aux_ctx_tensors = aux_ctx_tensors
ctx.max_seqlen = max_seqlen
ctx.qkv_dtype = qkv_dtype
ctx.attn_scale = attn_scale
......@@ -2326,12 +2328,12 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
d_out = d_out._data
d_out = d_out.contiguous()
(qkv, out, cu_seqlens,
qkv_fp8, out_fp8, fwd_scales, fwd_scale_invs) = ctx.saved_tensors
if not ctx.aux_ctx_tensors[0].is_contiguous():
ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
(qkv, out, cu_seqlens, qkv_fp8, out_fp8,
fwd_scales, fwd_scale_invs, *aux_ctx_tensors) = ctx.saved_tensors
if not aux_ctx_tensors[0].is_contiguous():
aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
if ctx.use_FAv2_bwd:
softmax_lse, rng_state = ctx.aux_ctx_tensors
softmax_lse, rng_state = aux_ctx_tensors
dqkv = torch.empty_like(qkv)
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
d_out, q, k, v, out = [maybe_contiguous(x)
......@@ -2363,7 +2365,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
dqkv_fp8, *rest = fused_attn_bwd_qkvpacked(
ctx.max_seqlen, cu_seqlens,
qkv_fp8, out_fp8, d_out_fp8,
fp8_dtype_forward, fp8_dtype_backward, ctx.aux_ctx_tensors,
fp8_dtype_forward, fp8_dtype_backward, aux_ctx_tensors,
ctx.fused_attention_backend,
fwd_scale_invs[META_QKV], # d_scale_qkv,
fwd_scale_invs[META_S], # d_scale_s,
......@@ -2398,7 +2400,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
d_out = d_out_f8tensor.from_float8(qkv.dtype)
dqkv, *rest = fused_attn_bwd_qkvpacked(
ctx.max_seqlen, cu_seqlens, qkv, out, d_out,
ctx.qkv_dtype, ctx.qkv_dtype, ctx.aux_ctx_tensors,
ctx.qkv_dtype, ctx.qkv_dtype, aux_ctx_tensors,
ctx.fused_attention_backend,
None, None, None, None, None, None, None, None, None, None,
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
......@@ -2501,9 +2503,9 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None)
ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, *fp8_tensors)
ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv,
*fp8_tensors, *aux_ctx_tensors)
ctx.fp8_meta = fp8_meta
ctx.aux_ctx_tensors = aux_ctx_tensors
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_kv = max_seqlen_kv
ctx.qkv_dtype = qkv_dtype
......@@ -2528,12 +2530,12 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
d_out = d_out._data
d_out = d_out.contiguous()
(q, kv, out, cu_seqlens_q, cu_seqlens_kv,
q_fp8, kv_fp8, out_fp8, fwd_scales, fwd_scale_invs) = ctx.saved_tensors
if not ctx.aux_ctx_tensors[0].is_contiguous():
ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
(q, kv, out, cu_seqlens_q, cu_seqlens_kv, q_fp8, kv_fp8, out_fp8,
fwd_scales, fwd_scale_invs, *aux_ctx_tensors) = ctx.saved_tensors
if not aux_ctx_tensors[0].is_contiguous():
aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
if ctx.use_FAv2_bwd:
softmax_lse, rng_state = ctx.aux_ctx_tensors
softmax_lse, rng_state = aux_ctx_tensors
dq = torch.empty_like(q)
dkv = torch.empty_like(kv)
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
......@@ -2567,7 +2569,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
dq_fp8, dkv_fp8, *rest = fused_attn_bwd_kvpacked(
ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q_fp8, kv_fp8, out_fp8, d_out_fp8,
fp8_dtype_forward, fp8_dtype_backward, ctx.aux_ctx_tensors,
fp8_dtype_forward, fp8_dtype_backward, aux_ctx_tensors,
ctx.fused_attention_backend,
fwd_scale_invs[META_QKV], # d_scale_qkv,
fwd_scale_invs[META_S], # d_scale_s,
......@@ -2614,7 +2616,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
dq, dkv, *rest = fused_attn_bwd_kvpacked(
ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, kv, out, d_out,
ctx.qkv_dtype, ctx.qkv_dtype, ctx.aux_ctx_tensors,
ctx.qkv_dtype, ctx.qkv_dtype, aux_ctx_tensors,
ctx.fused_attention_backend,
None, None, None, None, None, None, None, None, None, None,
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
......@@ -2773,9 +2775,9 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, *fp8_tensors)
ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv,
*fp8_tensors, *aux_ctx_tensors)
ctx.fp8_meta = fp8_meta
ctx.aux_ctx_tensors = aux_ctx_tensors
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_kv = max_seqlen_kv
ctx.qkv_dtype = qkv_dtype
......@@ -2800,12 +2802,12 @@ class FusedAttnFunc(torch.autograd.Function):
d_out = d_out._data
d_out = d_out.contiguous()
(q, k, v, out, cu_seqlens_q, cu_seqlens_kv,
q_fp8, k_fp8, v_fp8, out_fp8, fwd_scales, fwd_scale_invs) = ctx.saved_tensors
if not ctx.aux_ctx_tensors[0].is_contiguous():
ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
(q, k, v, out, cu_seqlens_q, cu_seqlens_kv, q_fp8, k_fp8, v_fp8, out_fp8,
fwd_scales, fwd_scale_invs, *aux_ctx_tensors) = ctx.saved_tensors
if not aux_ctx_tensors[0].is_contiguous():
aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
if ctx.use_FAv2_bwd:
softmax_lse, rng_state = ctx.aux_ctx_tensors
softmax_lse, rng_state = aux_ctx_tensors
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
......@@ -2840,7 +2842,7 @@ class FusedAttnFunc(torch.autograd.Function):
dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q_fp8, k_fp8, v_fp8, out_fp8, d_out_fp8,
fp8_dtype_forward, fp8_dtype_backward, ctx.aux_ctx_tensors,
fp8_dtype_forward, fp8_dtype_backward, aux_ctx_tensors,
ctx.fused_attention_backend,
fwd_scale_invs[META_QKV], # d_scale_qkv,
fwd_scale_invs[META_S], # d_scale_s,
......@@ -2923,7 +2925,7 @@ class FusedAttnFunc(torch.autograd.Function):
dq, dk, dv, *rest = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, k, v, out, d_out,
ctx.qkv_dtype, ctx.qkv_dtype, ctx.aux_ctx_tensors,
ctx.qkv_dtype, ctx.qkv_dtype, aux_ctx_tensors,
ctx.fused_attention_backend,
None, None, None, None, None, None, None, None, None, None,
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
......@@ -3493,7 +3495,9 @@ class DotProductAttention(torch.nn.Module):
a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
for cross-attention. For the 'arbitrary' mask type, it should be in a shape that is
broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value
means the corresponding position is masked out and a `False` value means that
position is allowed to participate in attention.
qkv_format: str, default = `None`
If provided, overrides :attr:`qkv_format` from initialization.
cu_seqlens_q: Optional[torch.Tensor], default = `None`
......@@ -4383,7 +4387,9 @@ class MultiheadAttention(torch.nn.Module):
a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
for cross-attention. For the 'arbitrary' mask type, it should be in a shape that is
broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value
means the corresponding position is masked out and a `False` value means that
position is allowed to participate in attention.
attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'},
default = `None`
type of attention mask passed into softmax operation.
......
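A hedged usage sketch of passing a padding mask with the new convention to `DotProductAttention` (shapes and arguments follow the docstrings above, but treat the exact call as illustrative rather than a verified recipe):

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    seqlen, batch, heads, head_dim = 128, 2, 16, 64     # illustrative sizes
    q = torch.randn(seqlen, batch, heads, head_dim, dtype=torch.bfloat16, device="cuda")
    k, v = torch.randn_like(q), torch.randn_like(q)

    # True = mask out (v1.7 convention); here the second half of each sequence is padding
    padding_mask = torch.zeros(batch, 1, 1, seqlen, dtype=torch.bool, device="cuda")
    padding_mask[..., seqlen // 2:] = True

    dpa = te.DotProductAttention(num_attention_heads=heads, kv_channels=head_dim)
    out = dpa(q, k, v, attention_mask=padding_mask, attn_mask_type="padding")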
......@@ -542,6 +542,8 @@ class TransformerLayer(torch.nn.Module):
It should be in [batch_size, 1, 1, seqlen_q] for 'padding' mask,
and broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
for 'arbitrary'. It should be 'None' for 'causal' and 'no_mask'.
A `True` value means the corresponding position is masked out and
a `False` value means that position is allowed to participate in attention.
self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
default = `causal`
Type of attention mask passed into softmax operation.
......@@ -555,7 +557,9 @@ class TransformerLayer(torch.nn.Module):
using `layer_type="decoder"`. It should be a tuple of two masks in
[batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for 'padding' mask.
It should be broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
for 'arbitrary' mask. It should be 'None' for 'causal' and 'no_mask'.
for 'arbitrary' mask. It should be 'None' for 'causal' and 'no_mask'. A `True` value
means the corresponding position is masked out and a `False` value means that
position is allowed to participate in attention.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
......@@ -655,7 +659,6 @@ class TransformerLayer(torch.nn.Module):
inter_attention_outputs = self.inter_attention(
hidden_states,
attention_mask=enc_dec_attn_mask,
window_size=window_size,
encoder_output=encoder_output,
is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention,
......
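Similarly, a rough sketch of a decoder `TransformerLayer` call with a tuple of padding masks for cross-attention (argument names follow the docstring above; constructor sizes and defaults are illustrative and should be checked against the class documentation):

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096,
                                num_attention_heads=16, layer_type="decoder")

    seq_q, seq_kv, batch = 128, 256, 2
    hidden = torch.randn(seq_q, batch, 1024, device="cuda")
    enc_out = torch.randn(seq_kv, batch, 1024, device="cuda")

    # tuple of padding masks, True = mask out: one for queries, one for keys/values
    mask_q = torch.zeros(batch, 1, 1, seq_q, dtype=torch.bool, device="cuda")
    mask_kv = torch.zeros(batch, 1, 1, seq_kv, dtype=torch.bool, device="cuda")

    out = layer(hidden, encoder_output=enc_out, enc_dec_attn_mask=(mask_q, mask_kv))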