Unverified Commit ecd4f808 authored by cyanguwa, committed by GitHub

Add mixed use of cuDNN fprop and flash-attn v2 bprop (#349)



* Add support for cuDNN fprop and FAv2 bprop
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* skip activation recompute tests if FAv2
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* restrict the use of FAv2 bprop to H100 only
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* move use_FAv2_bwd check to init
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove skipifs for FAv2 in test numerics
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix typos and wording for deterministic checks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Remove variables related to FAv2 skipifs
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 3cc2c1d2
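
This change lets the fused-attention path run cuDNN's kernel in the forward pass while delegating the backward pass to FlashAttention v2's `varlen_bwd` kernel. The mix is opt-out via the `NVTE_FUSED_ATTN_USE_FAv2_BWD` environment variable and is only taken when flash-attn v2 is installed, the GPU is an H100 (compute capability 9.0), the arbitrary-seqlen cuDNN backend is selected, and no attention bias is used. A minimal sketch of the gate, condensed from the diff below (the helper name `fav2_bwd_enabled` is hypothetical):

```python
import os

# Condensed from the diff below; `fav2_bwd_enabled` is a hypothetical helper.
def fav2_bwd_enabled(flash_attn_2_available: bool, compute_capability: float) -> bool:
    return (os.getenv("NVTE_FUSED_ATTN_USE_FAv2_BWD", "1") == "1"  # opt-out switch
            and flash_attn_2_available                             # flash-attn >= 2.0
            and compute_capability == 9.0)                         # H100 (sm90) only

# To force the pure cuDNN path, export the variable before building the model:
os.environ["NVTE_FUSED_ATTN_USE_FAv2_BWD"] = "0"
```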
@@ -50,8 +50,9 @@ _flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
 if _flash_attn_2_available:
     from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module
+    from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd # pylint: disable=no-name-in-module
 else:
-    from flash_attn.flash_attn_interface import flash_attn_unpadded_func as flash_attn_forward_func # pylint: disable=no-name-in-module
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_func as flash_attn_forward_func # pylint: disable=no-name-in-module,ungrouped-imports
 
 __all__ = ["DotProductAttention"]
@@ -124,6 +125,64 @@ class _SplitLastDim(torch.autograd.Function):
         return torch.cat(grad_outputs, dim = -1), None
 
+
+class _CombineQKV(torch.autograd.Function):
+    """"""
+
+    @staticmethod
+    def forward(ctx,
+                query_layer: torch.Tensor,
+                key_layer: torch.Tensor, # pylint: disable=unused-argument
+                value_layer: torch.Tensor, # pylint: disable=unused-argument
+                dim: int,
+    ) -> torch.Tensor:
+        mixed_layer = torch.Tensor().to(device=query_layer.device,
+                                        dtype=query_layer.dtype)
+        new_shape = list(query_layer.shape)
+        new_shape[dim] = new_shape[dim] * 3
+        mixed_layer.set_(query_layer.untyped_storage(),
+                         query_layer.storage_offset(),
+                         new_shape,
+                         query_layer.stride())
+        ctx.dim = dim
+        return mixed_layer
+
+    @staticmethod
+    def backward(ctx,
+                 *grad_outputs,
+    ) -> Tuple[torch.Tensor, ...]:
+        assert len(grad_outputs) > 0, "No gradients received for backprop!"
+        tensors = split_tensor_along_dim(grad_outputs[0], ctx.dim, 3)
+        return tensors[0], tensors[1], tensors[2], None
+
+
+class _CombineKV(torch.autograd.Function):
+    """"""
+
+    @staticmethod
+    def forward(ctx,
+                key_layer: torch.Tensor,
+                value_layer: torch.Tensor, # pylint: disable=unused-argument
+                dim: int,
+    ) -> torch.Tensor:
+        mixed_layer = torch.Tensor().to(device=key_layer.device,
+                                        dtype=key_layer.dtype)
+        new_shape = list(key_layer.shape)
+        new_shape[dim] = new_shape[dim] * 2
+        mixed_layer.set_(key_layer.untyped_storage(),
+                         key_layer.storage_offset(),
+                         new_shape,
+                         key_layer.stride())
+        ctx.dim = dim
+        return mixed_layer
+
+    @staticmethod
+    def backward(ctx,
+                 *grad_outputs,
+    ) -> Tuple[torch.Tensor, ...]:
+        assert len(grad_outputs) > 0, "No gradients received for backprop!"
+        tensors = split_tensor_along_dim(grad_outputs[0], ctx.dim, 2)
+        return tensors[0], tensors[1], None
+
+
 class UnfusedDotProductAttention(torch.nn.Module):
     """Parallel attention w/o QKV and Proj Gemms
@@ -301,7 +360,7 @@ class _PrepareQKVForFA(torch.autograd.Function):
         return dq, dk, dv
 
-def _check_if_interleaved_qkv(q, k, v):
+def _check_qkv_layout(q, k, v):
     data_ptr = q.untyped_storage().data_ptr()
     check_ptrs = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
     if not check_ptrs:
@@ -320,9 +379,18 @@ def _check_if_interleaved_qkv(q, k, v):
     last_dim_size = shape[-1]
     check_offsets = all(i * last_dim_size == x.storage_offset()
                         for i, x in enumerate([q, k, v]))
-    return check_offsets
+    if check_offsets:
+        return "sbh3d"
 
-def _check_if_interleaved_kv(k, v):
+    last_dims_size = shape[-1] * shape[-2]
+    check_offsets = all(i * last_dims_size == x.storage_offset()
+                        for i, x in enumerate([q, k, v]))
+    if check_offsets:
+        return "sb3hd"
+    return "other"
+
+def _check_kv_layout(k, v):
     data_ptr = k.untyped_storage().data_ptr()
     check_ptrs = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])
     if not check_ptrs:
@@ -341,8 +409,16 @@ def _check_if_interleaved_kv(k, v):
     last_dim_size = shape[-1]
     check_offsets = all(i * last_dim_size == x.storage_offset()
                         for i, x in enumerate([k, v]))
-    return check_offsets
+    if check_offsets:
+        return "sbh2d"
 
+    last_dims_size = shape[-1] * shape[-2]
+    check_offsets = all(i * last_dims_size == x.storage_offset()
+                        for i, x in enumerate([k, v]))
+    if check_offsets:
+        return "sb2hd"
+    return "other"
+
 class FlashAttention(torch.nn.Module):
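
The boolean `_check_if_interleaved_*` helpers become layout classifiers: they inspect storage pointers, strides, and storage offsets and return `"sbh3d"` or `"sb3hd"` (`"sbh2d"` or `"sb2hd"` for KV), falling back to `"other"`. The two QKV packings differ only in the offset step between the three views, as this sketch with hypothetical sizes shows:

```python
import torch

# Standalone illustration of _check_qkv_layout's offset test: in "sbh3d"
# the per-tensor storage offsets step by d, in "sb3hd" by h*d.
s, b, h, d = 4, 2, 8, 64

qkv = torch.randn(s, b, h, 3, d)                      # "sbh3d" packing
q, k, v = (qkv[..., i, :] for i in range(3))
assert [x.storage_offset() for x in (q, k, v)] == [0, d, 2 * d]

qkv = torch.randn(s, b, 3, h, d)                      # "sb3hd" packing
q, k, v = (qkv[:, :, i] for i in range(3))
assert [x.storage_offset() for x in (q, k, v)] == [0, h * d, 2 * h * d]
```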
@@ -391,7 +467,7 @@ class FlashAttention(torch.nn.Module):
         if (query_layer.shape[-1] == 128 and
             query_layer.shape[0] * query_layer.shape[1] >= 512 and
-            _check_if_interleaved_qkv(query_layer, key_layer, value_layer)):
+            _check_qkv_layout(query_layer, key_layer, value_layer) == "sbh3d"):
             query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(query_layer,
                                                                          key_layer,
                                                                          value_layer)
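
With the string-valued classifier, FlashAttention's fast path is taken only for the `"sbh3d"` packing, on top of the existing head-dim-128 and minimum-size conditions. A hypothetical standalone predicate equivalent to the check above:

```python
import torch

# Hypothetical restatement of the fast-path condition above.
def takes_fa_fast_path(q: torch.Tensor, layout: str) -> bool:
    return (q.shape[-1] == 128                  # head dim must be 128
            and q.shape[0] * q.shape[1] >= 512  # enough total sequences
            and layout == "sbh3d")              # interleaved per-head packing
```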
@@ -436,7 +512,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
     @staticmethod
     def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias, attn_scale,
                 dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
-                rng_gen, fused_attention_backend):
+                rng_gen, fused_attention_backend, use_FAv2_bwd):
         out, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
             is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype,
             fused_attention_backend, attn_bias,
@@ -455,19 +531,34 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
         ctx.attn_bias_type = attn_bias_type
         ctx.attn_mask_type = attn_mask_type
         ctx.fused_attention_backend = fused_attention_backend
+        ctx.use_FAv2_bwd = use_FAv2_bwd
 
         return out
 
     @staticmethod
     def backward(ctx, d_out):
         qkv, out, cu_seqlens = ctx.saved_tensors
-        dqkv, *rest = fused_attn_bwd_qkvpacked(
-            ctx.max_seqlen, cu_seqlens, qkv, out, d_out,
-            ctx.qkv_dtype, ctx.aux_ctx_tensors,
-            ctx.fused_attention_backend,
-            None, None, None, None, None, None, None, None, None,
-            ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
-            ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
+        if ctx.use_FAv2_bwd:
+            softmax_lse, rng_state = ctx.aux_ctx_tensors
+            dqkv = torch.empty_like(qkv)
+            maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+            d_out, q, k, v, out = [maybe_contiguous(x)
+                for x in (d_out, qkv[:,0], qkv[:,1], qkv[:,2], out)]
+            flash_attn_cuda_bwd(
+                d_out, q, k, v, out, softmax_lse, dqkv[:,0], dqkv[:,1], dqkv[:,2],
+                cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen,
+                ctx.dropout_p, ctx.attn_scale, False,
+                ctx.attn_mask_type == "causal", None, rng_state
+            )
+            dqkv = dqkv[..., :d_out.shape[-1]]
+        else:
+            dqkv, *rest = fused_attn_bwd_qkvpacked(
+                ctx.max_seqlen, cu_seqlens, qkv, out, d_out,
+                ctx.qkv_dtype, ctx.aux_ctx_tensors,
+                ctx.fused_attention_backend,
+                None, None, None, None, None, None, None, None, None,
+                ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
+                ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
 
         # if no_bias, return dqkv
         if ctx.attn_bias_type == "no_bias":
@@ -486,7 +577,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
     def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                 q, kv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
                 qkv_layout, attn_bias_type, attn_mask_type,
-                rng_gen, fused_attention_backend):
+                rng_gen, fused_attention_backend, use_FAv2_bwd):
         out, aux_ctx_tensors = fused_attn_fwd_kvpacked(
             is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
             q, kv, qkv_dtype, fused_attention_backend, attn_bias,
@@ -506,20 +597,37 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
         ctx.attn_bias_type = attn_bias_type
         ctx.attn_mask_type = attn_mask_type
         ctx.fused_attention_backend = fused_attention_backend
+        ctx.use_FAv2_bwd = use_FAv2_bwd
 
         return out
 
     @staticmethod
     def backward(ctx, d_out):
         q, kv, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
-        dq, dkv, *rest = fused_attn_bwd_kvpacked(
-            ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
-            q, kv, out, d_out,
-            ctx.qkv_dtype, ctx.aux_ctx_tensors,
-            ctx.fused_attention_backend,
-            None, None, None, None, None, None, None, None, None,
-            ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
-            ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
+        if ctx.use_FAv2_bwd:
+            softmax_lse, rng_state = ctx.aux_ctx_tensors
+            dq = torch.empty_like(q)
+            dkv = torch.empty_like(kv)
+            maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+            d_out, q, k, v, out = [maybe_contiguous(x)
+                for x in (d_out, q, kv[:,0], kv[:,1], out)]
+            flash_attn_cuda_bwd(
+                d_out, q, k, v, out, softmax_lse, dq, dkv[:,0], dkv[:,1],
+                cu_seqlens_q, cu_seqlens_kv, ctx.max_seqlen_q, ctx.max_seqlen_kv,
+                ctx.dropout_p, ctx.attn_scale, False,
+                ctx.attn_mask_type == "causal", None, rng_state
+            )
+            dq = dq[..., :d_out.shape[-1]]
+            dkv = dkv[..., :d_out.shape[-1]]
+        else:
+            dq, dkv, *rest = fused_attn_bwd_kvpacked(
+                ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
+                q, kv, out, d_out,
+                ctx.qkv_dtype, ctx.aux_ctx_tensors,
+                ctx.fused_attention_backend,
+                None, None, None, None, None, None, None, None, None,
+                ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
+                ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
 
         # if no_bias, return dqkv
         if ctx.attn_bias_type == "no_bias":
@@ -572,6 +680,9 @@ class FusedAttention(torch.nn.Module):
         self.attention_dropout_ctx = attention_dropout_ctx
         self.attn_mask_type = attn_mask_type
         self.attention_type = attention_type
+        self.use_FAv2_bwd = (os.getenv("NVTE_FUSED_ATTN_USE_FAv2_BWD", "1") == "1"
+                             and _flash_attn_2_available
+                             and get_device_compute_capability() == 9.0)
 
     def forward(
         self,
@@ -601,26 +712,27 @@ class FusedAttention(torch.nn.Module):
             max_seqlen_kv = seqlen_kv
 
         if self.attention_type == "self":
-            if _check_if_interleaved_qkv(query_layer, key_layer, value_layer):
-                query_layer = query_layer.unsqueeze(3)
-                key_layer = key_layer.unsqueeze(3)
-                value_layer = value_layer.unsqueeze(3)
+            qkv_layout = _check_qkv_layout(query_layer, key_layer, value_layer)
+            if qkv_layout == "sbh3d":
+                mixed_layer = _CombineQKV.apply(query_layer, key_layer, value_layer, 3)
                 # [s, b, h, 3, d]
-                mixed_layer = torch.cat([query_layer, key_layer, value_layer], dim = 3)
+                mixed_layer = mixed_layer.view(
+                    *mixed_layer.shape[0:3], 3, query_layer.shape[-1])
                 # [b, s, 3, h, d]
                 mixed_layer = mixed_layer.transpose(2, 3).transpose(0, 1).contiguous()
-            else:
-                query_layer = query_layer.unsqueeze(2)
-                key_layer = key_layer.unsqueeze(2)
-                value_layer = value_layer.unsqueeze(2)
+            elif qkv_layout == "sb3hd":
+                mixed_layer = _CombineQKV.apply(query_layer, key_layer, value_layer, 2)
                 # [s, b, 3, h, d]
-                mixed_layer = torch.cat([query_layer, key_layer, value_layer], dim = 2)
+                mixed_layer = mixed_layer.view(
+                    *mixed_layer.shape[0:2], 3, *query_layer.shape[2:])
                 # [b, s, 3, h, d]
                 mixed_layer = mixed_layer.transpose(0, 1).contiguous()
+            else:
+                raise Exception("FusedAttention only supports qkv layout sbh3d or sb3hd!")
 
             # [total_seqs, 3, h, d]
             mixed_layer = mixed_layer.view(
-                mixed_layer.shape[0] * mixed_layer.shape[1], *mixed_layer.shape[2:]).contiguous()
+                mixed_layer.shape[0] * mixed_layer.shape[1], *mixed_layer.shape[2:])
             qkv_layout = "qkv_interleaved"
             max_seqlen = seqlen_q
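
For the `"sbh3d"` case the combined tensor is reshaped into the `[total_seqs, 3, h, d]` packing that `fused_attn_fwd_qkvpacked` expects. The shape flow, as a minimal sketch with hypothetical sizes (emulating `_CombineQKV`'s zero-copy view with a plain `.view`):

```python
import torch

# Minimal sketch with hypothetical sizes; _CombineQKV's zero-copy view is
# emulated here with a plain .view on the packed buffer.
s, b, h, d = 4, 2, 8, 64
qkv = torch.randn(s, b, h, 3, d)                      # "sbh3d" packing

mixed = qkv.view(s, b, h, 3 * d)                      # what _CombineQKV returns
mixed = mixed.view(*mixed.shape[0:3], 3, d)           # [s, b, h, 3, d]
mixed = mixed.transpose(2, 3).transpose(0, 1).contiguous()  # [b, s, 3, h, d]
mixed = mixed.view(mixed.shape[0] * mixed.shape[1], *mixed.shape[2:])
assert mixed.shape == (b * s, 3, h, d)                # "qkv_interleaved" packing
```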
@@ -630,6 +742,10 @@ class FusedAttention(torch.nn.Module):
                 step=seqlen_q,
                 dtype=torch.int32,
                 device=query_layer.device)
+            use_FAv2_bwd = (self.use_FAv2_bwd
+                            and (fused_attention_backend
+                                 == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen)
+                            and core_attention_bias_type == "no_bias")
             with self.attention_dropout_ctx():
                 output = FusedAttnFunc_qkvpacked.apply(
@@ -647,31 +763,36 @@ class FusedAttention(torch.nn.Module):
                     self.attn_mask_type,
                     None, # rng_gen
                     fused_attention_backend,
+                    use_FAv2_bwd
                 )
             output = output.view(batch_size, seqlen_q, -1).transpose(0, 1).contiguous()
 
         if self.attention_type == "cross":
-            if _check_if_interleaved_kv(key_layer, value_layer):
+            kv_layout = _check_kv_layout(key_layer, value_layer)
+            if kv_layout == "sbh2d":
+                key_value = _CombineKV.apply(key_layer, value_layer, 3)
                 # [s, b, h, 2, d]
-                key_layer = key_layer.unsqueeze(3)
-                value_layer = value_layer.unsqueeze(3)
-                key_value = torch.cat([key_layer, value_layer], dim = 3)
+                key_value = key_value.view(
+                    *key_value.shape[0:3], 2, key_layer.shape[-1])
                 # [b, s, 2, h, d]
                 key_value = key_value.transpose(2, 3).transpose(0, 1).contiguous()
-            else:
+            elif kv_layout == "sb2hd":
+                key_value = _CombineKV.apply(key_layer, value_layer, 2)
                 # [s, b, 2, h, d]
-                key_layer = key_layer.unsqueeze(2)
-                value_layer = value_layer.unsqueeze(2)
-                key_value = torch.cat([key_layer, value_layer], dim = 2)
+                key_value = key_value.view(
+                    *key_value.shape[0:2], 2, *key_layer.shape[2:])
                 # [b, s, 2, h, d]
                 key_value = key_value.transpose(0, 1).contiguous()
+            else:
+                raise Exception("FusedAttention only supports kv layout sbh2d or sb2hd!")
 
-            # [total_seqs, 2, h, d]
+            # [total_seqs, h, d]
             query_layer = query_layer.transpose(0, 1).contiguous()
             query_layer = query_layer.view(
                 query_layer.shape[0] * query_layer.shape[1], *query_layer.shape[2:])
             # [total_seqs, 2, h, d]
             key_value = key_value.view([key_value.shape[0] * key_value.shape[1]]
-                + key_value.shape[2:]).contiguous()
+                + key_value.shape[2:])
             qkv_layout = "kv_interleaved"
 
             cu_seqlens_q = torch.arange(
@@ -703,6 +824,7 @@ class FusedAttention(torch.nn.Module):
                     self.attn_mask_type,
                     None, # rng_gen
                     fused_attention_backend,
+                    use_FAv2_bwd
                 )
             output = (outputs[0].view(batch_size, seqlen_q, -1).transpose(0, 1).contiguous(),
@@ -807,8 +929,8 @@ class DotProductAttention(torch.nn.Module):
             self.use_flash_attention = False
             warnings.warn(
                 "Disabling usage of FlashAttention since version 2 does not support deterministic "
-                "exection. In order to use FA with deterministic behavior, install FlashAttention"
-                "version 1."
+                "execution. In order to use FA with deterministic behavior, please install "
+                "FlashAttention version 1."
             )
 
         self.use_fused_attention = (
@@ -969,8 +1091,8 @@ class DotProductAttention(torch.nn.Module):
                 and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]):
             use_fused_attention = False
             warnings.warn(
-                "Disabling usage of FusedAttention since the FusedAttention"
-                "backend does not support deterministic exection."
+                "Disabling usage of FusedAttention since this FusedAttention "
+                "backend does not support deterministic execution."
             )
 
         if use_flash_attention:
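
Taken together: on an H100 with flash-attn v2 installed and the environment variable left at its default, a regular forward/backward pass through `DotProductAttention` exercises the new combination, cuDNN fprop plus FAv2 bprop, while requesting deterministic execution disables both FlashAttention v2 and the arbitrary-seqlen fused backend, per the warnings above. A hypothetical end-to-end use (module name and arguments follow the public API; actual backend selection still depends on the runtime environment):

```python
import os
import torch
from transformer_engine.pytorch import DotProductAttention

os.environ["NVTE_FUSED_ATTN_USE_FAv2_BWD"] = "1"  # default; shown for clarity

attn = DotProductAttention(num_attention_heads=16, kv_channels=64)
# [s, b, h, d] inputs, matching the layouts handled above.
q = torch.randn(512, 2, 16, 64, dtype=torch.bfloat16,
                device="cuda", requires_grad=True)
k, v = torch.randn_like(q), torch.randn_like(q)

out = attn(q, k, v)     # cuDNN fused forward (backend permitting)
out.sum().backward()    # FlashAttention-v2 varlen backward (when gated in)
```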