Unverified Commit 7f2f7dd2 authored by cyanguwa, committed by GitHub

Fix flash-attn checks and RoPE DPA (#506)



* fix condition checks related to FA head_dim
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* force q,k,v contiguous when RoPE is in use
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Expand FA version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 71e51eae
@@ -284,7 +284,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
     # Framework-specific requirements
     if "pytorch" in frameworks():
-        add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=2.0.4"])
+        add_unique(install_reqs, ["torch", "flash-attn>=1.0.6,<=2.3.3,!=2.0.9,!=2.1.0"])
         add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
     if "jax" in frameworks():
         if not found_pybind11():
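For reference, the runtime counterpart of this requirement pin is the flash-attn version probe in `attention.py` that sets `_flash_attn_2_available`. Below is a minimal sketch of such a guard; the helper name `_check_flash_attn_version` and its structure are illustrative assumptions, not TransformerEngine's actual API.

```python
# Minimal sketch of a flash-attn version guard matching the new pin
# (>=1.0.6, <=2.3.3, excluding the 2.0.9 and 2.1.0 releases).
# Helper name and structure are illustrative, not TransformerEngine's API.
from importlib.metadata import version

from packaging.version import Version

_flash_attn_version = Version(version("flash-attn"))

# flash-attn 2.x relaxes the head_dim limits checked later in attention.py.
_flash_attn_2_available = _flash_attn_version >= Version("2.0.0")


def _check_flash_attn_version() -> bool:
    """Return True if the installed flash-attn falls inside the supported range."""
    if _flash_attn_version in (Version("2.0.9"), Version("2.1.0")):
        return False  # releases excluded by the requirement string
    return Version("1.0.6") <= _flash_attn_version <= Version("2.3.3")
```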
@@ -1142,60 +1142,72 @@ def _get_qkv_layout(
     check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v])
     assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!"
-    data_ptr = q.untyped_storage().data_ptr()
-    check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
-    data_ptr = k.untyped_storage().data_ptr()
-    check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])
-    stride = q.stride()
-    check_strides_qkv = all(stride == x.stride() for x in [q, k, v])
-    stride = k.stride()
-    check_strides_kv = all(stride == x.stride() for x in [k, v])
-    shape = q.shape
-    check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
-    shape = k.shape
-    check_shapes_kv = all(shape == x.shape for x in [k, v])
-    last_dim_size = q.shape[-1]
-    check_last_dim_offsets_qkv = all(i * last_dim_size == x.storage_offset()
-                                     for i, x in enumerate([q, k, v]))
-    last_dim_size = k.shape[-1]
-    check_last_dim_offsets_kv = all(i * last_dim_size == x.storage_offset()
-                                    for i, x in enumerate([k, v]))
-    last_two_dims_size = q.shape[-1] * q.shape[-2]
-    check_last_two_dims_offsets_qkv = all(i * last_two_dims_size == x.storage_offset()
-                                          for i, x in enumerate([q, k, v]))
-    last_two_dims_size = k.shape[-1] * k.shape[-2]
-    check_last_two_dims_offsets_kv = all(i * last_two_dims_size == x.storage_offset()
-                                         for i, x in enumerate([k, v]))
-    qkv_layout = None
-    if (check_ptrs_qkv and check_strides_qkv and check_shapes_qkv
-        and check_last_two_dims_offsets_qkv
-        and not check_last_dim_offsets_qkv):
-        # sb3hd, bs3hd, t3hd
-        qkv_layout = qkv_format[:-2] + '3' + qkv_format[-2:]
-    elif check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_last_dim_offsets_qkv:
-        # sbh3d, bsh3d, th3d
-        qkv_layout = qkv_format[:-1] + '3' + qkv_format[-1:]
-    elif (check_ptrs_kv and check_strides_kv and check_shapes_kv
-        and check_last_two_dims_offsets_kv
-        and not check_last_dim_offsets_kv):
-        # sbhd_sb2hd, bshd_bs2hd, thd_t2hd
-        qkv_layout = qkv_format + '_' + qkv_format[:-2] + '2' + qkv_format[-2:]
-    elif (check_ptrs_kv and check_strides_kv and check_shapes_kv
-        and check_last_dim_offsets_kv):
-        # sbhd_sbh2d, bshd_bsh2d, thd_th2d
-        qkv_layout = qkv_format + '_' + qkv_format[:-1] + '2' + qkv_format[-1:]
-    elif check_strides_kv and check_shapes_kv:
-        # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd
-        qkv_layout = '_'.join(list([qkv_format])*3)
-    else:
+    def run_iteratively(q, k, v):
+        data_ptr = q.untyped_storage().data_ptr()
+        check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
+        data_ptr = k.untyped_storage().data_ptr()
+        check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])
+        stride = q.stride()
+        check_strides_qkv = all(stride == x.stride() for x in [q, k, v])
+        stride = k.stride()
+        check_strides_kv = all(stride == x.stride() for x in [k, v])
+        shape = q.shape
+        check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
+        shape = k.shape
+        check_shapes_kv = all(shape == x.shape for x in [k, v])
+        last_dim_size = q.shape[-1]
+        check_last_dim_offsets_qkv = all(i * last_dim_size == x.storage_offset()
+                                         for i, x in enumerate([q, k, v]))
+        last_dim_size = k.shape[-1]
+        check_last_dim_offsets_kv = all(i * last_dim_size == x.storage_offset()
+                                        for i, x in enumerate([k, v]))
+        last_two_dims_size = q.shape[-1] * q.shape[-2]
+        check_last_two_dims_offsets_qkv = all(i * last_two_dims_size == x.storage_offset()
+                                              for i, x in enumerate([q, k, v]))
+        last_two_dims_size = k.shape[-1] * k.shape[-2]
+        check_last_two_dims_offsets_kv = all(i * last_two_dims_size == x.storage_offset()
+                                             for i, x in enumerate([k, v]))
+        qkv_layout = None
+        if (check_ptrs_qkv and check_strides_qkv and check_shapes_qkv
+            and check_last_two_dims_offsets_qkv
+            and not check_last_dim_offsets_qkv):
+            # sb3hd, bs3hd, t3hd
+            qkv_layout = qkv_format[:-2] + '3' + qkv_format[-2:]
+        elif (check_ptrs_qkv and check_strides_qkv and check_shapes_qkv
+            and check_last_dim_offsets_qkv):
+            # sbh3d, bsh3d, th3d
+            qkv_layout = qkv_format[:-1] + '3' + qkv_format[-1:]
+        elif (check_ptrs_kv and check_strides_kv and check_shapes_kv
+            and check_last_two_dims_offsets_kv
+            and not check_last_dim_offsets_kv):
+            # sbhd_sb2hd, bshd_bs2hd, thd_t2hd
+            qkv_layout = qkv_format + '_' + qkv_format[:-2] + '2' + qkv_format[-2:]
+        elif (check_ptrs_kv and check_strides_kv and check_shapes_kv
+            and check_last_dim_offsets_kv):
+            # sbhd_sbh2d, bshd_bsh2d, thd_th2d
+            qkv_layout = qkv_format + '_' + qkv_format[:-1] + '2' + qkv_format[-1:]
+        elif check_strides_kv and check_shapes_kv:
+            # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd
+            qkv_layout = '_'.join(list([qkv_format])*3)
+        else:
+            qkv_layout = 'not_supported'
+        return qkv_layout
+
+    qkv_layout = run_iteratively(q, k, v)
+    if qkv_layout == 'not_supported':
+        # force q,k,v to be contiguous and run get_layout again
+        q, k, v = [x.contiguous() for x in [q, k, v]]
+        qkv_layout = run_iteratively(q, k, v)
+    if qkv_layout == 'not_supported':
         raise Exception("The provided qkv memory layout is not supported!")
-    return qkv_layout
+    return qkv_layout, q, k, v
class FlashAttention(torch.nn.Module):
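The layout detection above keys off storage pointers, strides, shapes, and storage offsets, so its behaviour can be reproduced with plain PyTorch views. Below is a small sketch (tensor sizes are arbitrary, and `q_rope`/`k_rope` are stand-ins for rotary-embedding outputs rather than TE calls) of why an interleaved projection is reported as `sb3hd`, and why RoPE outputs used to hit the unsupported-layout exception before this fallback existed.

```python
# Sketch of the storage properties _get_qkv_layout inspects, using an
# interleaved projection output packed as sb3hd (sizes are arbitrary).
import torch

s, b, h, d = 128, 2, 16, 64
qkv = torch.randn(s, b, 3, h, d)                      # packed qkv projection
q, k, v = [x.squeeze(2) for x in qkv.split(1, dim=2)]

# All three are views over one storage, offset from one another by h*d
# elements -- the pattern check_last_two_dims_offsets_qkv maps to 'sb3hd'.
assert all(x.untyped_storage().data_ptr() == qkv.untyped_storage().data_ptr()
           for x in (q, k, v))
assert [x.storage_offset() for x in (q, k, v)] == [0, h * d, 2 * h * d]

# RoPE rewrites q and k into fresh storage while v still points into the
# packed buffer, so the pointer/stride checks match no known layout. The old
# code raised here; with this change the function retries on contiguous
# copies and returns the separate-tensor layout 'sbhd_sbhd_sbhd' instead.
q_rope, k_rope = q * 1.0, k * 1.0                     # stand-ins for RoPE outputs
assert k_rope.stride() != v.stride()                  # would hit 'not_supported'
```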
@@ -2083,8 +2095,8 @@ class DotProductAttention(torch.nn.Module):
                 ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
                       the sequence dimension in 'key_layer' and 'value_layer'!"""
 
-        qkv_layout = _get_qkv_layout(query_layer, key_layer, value_layer,
-            qkv_format = qkv_format)
+        qkv_layout, query_layer, key_layer, value_layer = _get_qkv_layout(
+            query_layer, key_layer, value_layer, qkv_format = qkv_format)
 
         # The priority for attention backends (subject to availability and clearing the filters)
         # is: FlashAttention > FusedAttention (cuDNN) > UnfusedDotProductAttention.
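As a side note, the backend priority stated in the comment above can be summarised in a few lines; the function below is purely illustrative (its name and signature are not part of TE) and assumes the availability filters have already been applied.

```python
# Illustrative only: the selection order described in the comment above,
# applied after all availability/device/dtype filters have run.
def pick_attention_backend(use_flash_attention: bool, use_fused_attention: bool) -> str:
    if use_flash_attention:
        return "FlashAttention"
    if use_fused_attention:
        return "FusedAttention (cuDNN)"
    return "UnfusedDotProductAttention"
```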
@@ -2103,13 +2115,20 @@
             use_fused_attention = False
 
         # Filter: Device and dimensions.
-        if key_layer.shape[-1] > 64:
-            if self.device_compute_capability in ((8, 6), (8, 7)):
+        # FAv1 supports head_dim <= 128, and for >64 requires sm80/sm90
+        # FAv2 supports head_dim <= 256, and for >192 requires sm80/sm90
+        # Both FAv1 and FAv2 require head_dim % 8 == 0
+        if not _flash_attn_2_available:
+            if (key_layer.shape[-1] > 128
+                or key_layer.shape[-1] % 8 != 0
+                or (key_layer.shape[-1] > 64
+                and self.device_compute_capability not in ((8, 0), (9, 0)))):
                 use_flash_attention = False
-            elif (
-                not _flash_attn_2_available
-                and self.device_compute_capability == (8, 9)
-            ):
+        if _flash_attn_2_available:
+            if (key_layer.shape[-1] > 256
+                or key_layer.shape[-1] % 8 != 0
+                or (key_layer.shape[-1] > 192
+                and self.device_compute_capability not in ((8, 0), (9, 0)))):
                 use_flash_attention = False
 
         if not _flash_attn_2_available and self.num_gqa_groups != self.num_attention_heads:
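The head_dim bounds encoded above mirror the flash-attn support matrix spelled out in the new comments. Here is a hedged standalone restatement; the function name and boolean return convention are illustrative only, not part of TransformerEngine.

```python
# Illustrative restatement of the head_dim filter above: FAv1 allows
# head_dim <= 128 (> 64 only on sm80/sm90), FAv2 allows head_dim <= 256
# (> 192 only on sm80/sm90), and both require head_dim % 8 == 0.
from typing import Tuple


def flash_attn_supports_head_dim(
    head_dim: int,
    compute_capability: Tuple[int, int],
    flash_attn_2: bool,
) -> bool:
    if head_dim % 8 != 0:
        return False
    max_dim, arch_bound = (256, 192) if flash_attn_2 else (128, 64)
    if head_dim > max_dim:
        return False
    if head_dim > arch_bound and compute_capability not in ((8, 0), (9, 0)):
        return False
    return True


# e.g. head_dim=128 on an sm86 device was rejected by the old blanket check,
# but is allowed under flash-attn 2.x:
assert not flash_attn_supports_head_dim(128, (8, 6), flash_attn_2=False)
assert flash_attn_supports_head_dim(128, (8, 6), flash_attn_2=True)
```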
@@ -2878,7 +2897,6 @@ class MultiheadAttention(torch.nn.Module):
             q_pos_emb, k_pos_emb = rotary_pos_emb
             query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb)
             key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb)
-            value_layer = value_layer.contiguous()
 
         context_layer = self.core_attention(
             query_layer,