Unverified Commit 92d1ba0d authored by cyanguwa's avatar cyanguwa Committed by GitHub
Browse files

[C/PyTorch] RoPE fixes and minor improvements for fused attention (#453)



* add support for h2d/2hd in 8.9.6
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* cull unit tests in fused_attn.py and add skipif for layout tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add workopt=1 flag for dpa tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update support table for arbi_seqlen backend
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix rotary position embedding and add unit tests accordingly
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* further cut down unit tests for CI efficiency
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove einops dependency
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 79f5fac7
...@@ -12,7 +12,7 @@ from transformer_engine.pytorch.utils import ( ...@@ -12,7 +12,7 @@ from transformer_engine.pytorch.utils import (
) )
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch import TransformerLayer from transformer_engine.pytorch import TransformerLayer
from transformer_engine.pytorch.attention import DotProductAttention from transformer_engine.pytorch.attention import DotProductAttention, RotaryPositionEmbedding
import os import os
from pkg_resources import packaging from pkg_resources import packaging
...@@ -21,6 +21,8 @@ from test_numerics import get_dummy_cuda_rng_tracker, reset_rng_states ...@@ -21,6 +21,8 @@ from test_numerics import get_dummy_cuda_rng_tracker, reset_rng_states
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
_flash_attn_version = packaging.version.Version(version("flash-attn")) _flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2") _flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
_cudnn_version = [int(i) for i in os.environ['CUDNN_VERSION'].split('.')]
class ModelConfig: class ModelConfig:
def __init__( def __init__(
...@@ -45,22 +47,26 @@ model_configs = { ...@@ -45,22 +47,26 @@ model_configs = {
"test5": ModelConfig(1, 1024, 16, 64, 128, 0.0, "no_mask"), "test5": ModelConfig(1, 1024, 16, 64, 128, 0.0, "no_mask"),
} }
if os.getenv('NVTE_ADDITIONAL_TESTS', '0') == '1':
model_configs["test6"] = ModelConfig(1, 1024, 16, 64, 512, 0.0, "causal")
model_configs["test7"] = ModelConfig(1, 2048, 16, 128, 512, 0.0, "causal")
model_configs["test8"] = ModelConfig(1, 2048, 16, 128, 2048, 0.0, "causal")
model_configs["test9"] = ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask")
param_types = [torch.float16] param_types = [torch.float16]
if torch.cuda.is_bf16_supported(): if torch.cuda.is_bf16_supported():
param_types.append(torch.bfloat16) param_types.append(torch.bfloat16)
batch_sizes = [1, 2] # add more if needed, e.g. 32 batch_sizes = [1, 32]
model_configs_lean = {
"test6": ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask"),
"test7": ModelConfig(1, 2048, 16, 128, 2048, 0.0, "causal"),
}
param_types_lean = [torch.bfloat16]
batch_sizes_lean = [2]
@pytest.mark.skipif( @pytest.mark.skipif(
get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.") get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes_lean)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("ckpt_attn", [True, False]) @pytest.mark.parametrize("ckpt_attn", [True, False])
@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"]) @pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
...@@ -69,7 +75,6 @@ def test_dot_product_attention(dtype, bs, model, ckpt_attn, bias_type): ...@@ -69,7 +75,6 @@ def test_dot_product_attention(dtype, bs, model, ckpt_attn, bias_type):
FlashAttention, FusedAttention and UnfusedDotProductAttention""" FlashAttention, FusedAttention and UnfusedDotProductAttention"""
config = model_configs[model] config = model_configs[model]
if bias_type == "no_bias": if bias_type == "no_bias":
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention( flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype, bs, config, "FlashAttention", ckpt_attn, bias_type) dtype, bs, config, "FlashAttention", ckpt_attn, bias_type)
...@@ -94,6 +99,7 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type) ...@@ -94,6 +99,7 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type)
os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention": if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"
inp = torch.randn( inp = torch.randn(
config.seq_len, bs, 3, config.num_attention_heads, config.head_dim, config.seq_len, bs, 3, config.num_attention_heads, config.head_dim,
...@@ -150,15 +156,17 @@ qkv_layouts = [ ...@@ -150,15 +156,17 @@ qkv_layouts = [
@pytest.mark.skipif( @pytest.mark.skipif(
get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.") get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.skipif(
@pytest.mark.parametrize("bs", batch_sizes) _cudnn_version >= [8,9,5], reason="cuDNN 8.9.5+ is required.")
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("bs", batch_sizes_lean)
@pytest.mark.parametrize("model", model_configs_lean.keys())
@pytest.mark.parametrize("workspace_opt", [True, False]) @pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", qkv_layouts) @pytest.mark.parametrize("qkv_layout", qkv_layouts)
def test_dpa_qkv_layout(dtype, bs, model, workspace_opt, qkv_layout): def test_dpa_qkv_layout(dtype, bs, model, workspace_opt, qkv_layout):
"""Test DotProductAttention module with different QKV layouts""" """Test DotProductAttention module with different QKV layouts"""
config = model_configs[model] config = model_configs_lean[model]
flash_attn_fwd, flash_attn_bwd = _run_dpa_qkv_layout( flash_attn_fwd, flash_attn_bwd = _run_dpa_qkv_layout(
dtype, bs, config, "FlashAttention", qkv_layout, workspace_opt) dtype, bs, config, "FlashAttention", qkv_layout, workspace_opt)
...@@ -188,7 +196,6 @@ def _run_dpa_qkv_layout(dtype, bs, config, backend, qkv_layout, workspace_opt): ...@@ -188,7 +196,6 @@ def _run_dpa_qkv_layout(dtype, bs, config, backend, qkv_layout, workspace_opt):
os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0" os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
dim_to_num = {'b': bs, dim_to_num = {'b': bs,
's': config.seq_len, 's': config.seq_len,
'h': config.num_attention_heads, 'h': config.num_attention_heads,
...@@ -269,23 +276,23 @@ def _run_dpa_qkv_layout(dtype, bs, config, backend, qkv_layout, workspace_opt): ...@@ -269,23 +276,23 @@ def _run_dpa_qkv_layout(dtype, bs, config, backend, qkv_layout, workspace_opt):
get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.") get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", model_configs_lean.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"]) @pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
@pytest.mark.parametrize("fused_qkv_params", [True, False]) @pytest.mark.parametrize("fused_qkv_params", [True, False])
def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type, fused_qkv_params): @pytest.mark.parametrize("RoPE", [True, False])
def test_transformer_layer(dtype, bs, model, bias_type, fused_qkv_params, RoPE):
"""Test TransformerLayer module when its DotProductAttention is enabled with """Test TransformerLayer module when its DotProductAttention is enabled with
FlashAttention, FusedAttention, or UnfusedDotProductAttention backend""" FlashAttention, FusedAttention, or UnfusedDotProductAttention backend"""
config = model_configs[model] config = model_configs_lean[model]
if bias_type == "no_bias": if bias_type == "no_bias":
flash_attn_fwd, flash_attn_bwd = _run_transformer_layer( flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
dtype, bs, config, "FlashAttention", ckpt_attn, bias_type, fused_qkv_params) dtype, bs, config, "FlashAttention", bias_type, fused_qkv_params, RoPE)
fused_attn_fwd, fused_attn_bwd = _run_transformer_layer( fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
dtype, bs, config, "FusedAttention", ckpt_attn, bias_type, fused_qkv_params) dtype, bs, config, "FusedAttention", bias_type, fused_qkv_params, RoPE)
unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer( unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
dtype, bs, config, "UnfusedDotProductAttention", ckpt_attn, bias_type, fused_qkv_params) dtype, bs, config, "UnfusedDotProductAttention", bias_type, fused_qkv_params, RoPE)
atol, rtol = (5e-1, 5e-2) atol, rtol = (5e-1, 5e-2)
if bias_type == "no_bias": if bias_type == "no_bias":
...@@ -294,7 +301,7 @@ def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type, fused_qkv_par ...@@ -294,7 +301,7 @@ def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type, fused_qkv_par
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol) torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol) torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)
def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fused_qkv_params): def _run_transformer_layer(dtype, bs, config, backend, bias_type, fused_qkv_params, RoPE):
reset_rng_states() reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FLASH_ATTN"] = "0"
...@@ -327,6 +334,11 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fus ...@@ -327,6 +334,11 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fus
else: else:
bias = None bias = None
rotary_pos_emb = None
if RoPE:
PE = RotaryPositionEmbedding(dim=config.head_dim)
rotary_pos_emb = PE(config.seq_len).cuda().to(dtype=dtype)
block = ( block = (
TransformerLayer( TransformerLayer(
config.hidden_size, config.hidden_size,
...@@ -365,7 +377,7 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fus ...@@ -365,7 +377,7 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fus
num_iters = 5 num_iters = 5
for i in range(num_iters): for i in range(num_iters):
op = block(inp, self_attn_mask_type=config.attn_mask_type, op = block(inp, self_attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=ckpt_attn, rotary_pos_emb=rotary_pos_emb,
core_attention_bias_type=bias_type, core_attention_bias_type=bias_type,
core_attention_bias=bias) core_attention_bias=bias)
loss = op.sum() loss = op.sum()
...@@ -376,14 +388,14 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fus ...@@ -376,14 +388,14 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fus
@pytest.mark.skipif(not _flash_attn_2_available, reason="FA2.0 is not available") @pytest.mark.skipif(not _flash_attn_2_available, reason="FA2.0 is not available")
@pytest.mark.skipif( @pytest.mark.skipif(
get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.") get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes_lean)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", model_configs_lean.keys())
def test_transformer_layer_gqa(dtype, bs, model): def test_transformer_layer_gqa(dtype, bs, model):
"""Test TransformerLayer module when its DotProductAttention is enabled with """Test TransformerLayer module when its DotProductAttention is enabled with
FlashAttention, FusedAttention, or UnfusedDotProductAttention backend""" FlashAttention, FusedAttention, or UnfusedDotProductAttention backend"""
config = model_configs[model] config = model_configs_lean[model]
def find_factors(x): def find_factors(x):
f = [] f = []
for i in range(1, x + 1): for i in range(1, x + 1):
......
...@@ -1338,12 +1338,13 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen, ...@@ -1338,12 +1338,13 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen,
} }
use_workspace_opt = transformer_engine::getenv<bool>( use_workspace_opt = transformer_engine::getenv<bool>(
"NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT", use_workspace_opt); "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT", use_workspace_opt);
// will not be needed in cuDNN 8.9.6 #if (CUDNN_VERSION < 8906)
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
if ((layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) if ((layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD)
|| (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D)) { || (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D)) {
use_workspace_opt = false; use_workspace_opt = false;
} }
#endif
} }
#endif #endif
...@@ -1485,12 +1486,13 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t max_seqlen_q, size_t m ...@@ -1485,12 +1486,13 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t max_seqlen_q, size_t m
} }
use_workspace_opt = transformer_engine::getenv<bool>( use_workspace_opt = transformer_engine::getenv<bool>(
"NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT", use_workspace_opt); "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT", use_workspace_opt);
// will not be needed in cuDNN 8.9.6 #if (CUDNN_VERSION < 8906)
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
if ((layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) if ((layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD)
|| (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D)) { || (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D)) {
use_workspace_opt = false; use_workspace_opt = false;
} }
#endif
} }
#endif #endif
......
...@@ -691,6 +691,64 @@ def flash_attn_forward_func_with_cp(q, k, v, cu_seqlens_q, cu_seqlens_k, ...@@ -691,6 +691,64 @@ def flash_attn_forward_func_with_cp(q, k, v, cu_seqlens_q, cu_seqlens_k,
return out return out
class RotaryPositionEmbedding(torch.nn.Module):
    """
    Generator of rotary position embedding (RoPE) angles, as introduced in
    https://arxiv.org/abs/2104.09864.

    The module stores one inverse frequency per pair of embedding channels and,
    when called, returns the position-times-frequency table for a given
    sequence length.
    """

    def __init__(
        self,
        dim: int,
        seq_len_interpolation_factor: Optional[int] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
    ):
        """
        Parameters
        ----------
        dim: int
            rotary embedding dimension
        seq_len_interpolation_factor: int
            when given together with `pretrained_max_position_embeddings`,
            positions are linearly rescaled following the position
            interpolation trick of https://arxiv.org/abs/2306.15595
        pretrained_max_position_embeddings: int
            max_position_embeddings the model was pre-trained with, i.e.
            before position interpolation
        """
        super().__init__()
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings

        # One inverse frequency for every second channel index: 10000^(-2i/dim).
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (10000 ** exponents)
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, max_seq_len: int, offset: int = 0):
        """
        Build the table of rotary angles for positions
        `offset, ..., offset + max_seq_len - 1`.

        Parameters
        ----------
        max_seq_len: int
            sequence length of a sample
        offset: int, default = 0
            fixed offset added to every position index

        Returns
        -------
        torch.Tensor
            tensor of shape `[max_seq_len, 1, 1, 2 * len(inv_freq)]`, created on
            the device and in the dtype of the `inv_freq` buffer
        """
        inv_freq = self.inv_freq
        positions = torch.arange(max_seq_len, device=inv_freq.device) + offset
        positions = positions.type_as(inv_freq)

        pretrained_len = self.pretrained_max_position_embeddings
        factor = self.seq_len_interpolation_factor
        # Interpolation is only applied when both settings were provided.
        if pretrained_len is not None and factor is not None:
            if max_seq_len > pretrained_len * factor:
                # dynamic linear scaling: the requested length exceeds what the
                # fixed factor covers, so squeeze positions by the actual ratio
                scale = 1 / (max_seq_len / pretrained_len)
            else:
                # fixed linear scaling
                scale = 1 / factor
            positions *= scale

        # Outer product: one row per position, one column per frequency.
        freqs = torch.einsum('i , j -> i j', positions, inv_freq)
        # The same angles are laid out twice, side by side, along the last axis.
        emb = torch.cat((freqs, freqs), dim=-1)
        # Insert two singleton axes between the position and channel axes.
        num_positions, width = emb.shape
        return emb.reshape(num_positions, 1, 1, width)
def _rotate_half(x: torch.Tensor) -> torch.Tensor: def _rotate_half(x: torch.Tensor) -> torch.Tensor:
""" """
change sign so the last dimension becomes [-odd, +even] change sign so the last dimension becomes [-odd, +even]
...@@ -1488,9 +1546,10 @@ class FusedAttention(torch.nn.Module): ...@@ -1488,9 +1546,10 @@ class FusedAttention(torch.nn.Module):
| qkv_layout | | | | qkv_layout | | |
| - qkv | qkv_interleaved | qkv_interleaved | | - qkv | qkv_interleaved | qkv_interleaved |
| - (q,kv) | kv_interleaved | | | - (q,kv) | kv_interleaved | |
| - (q,k,v) | sb3hd, bs3hd | sb3hd, bs3hd | | - (q,k,v) | sb3hd, bs3hd | sb3hd, bs3hd, sbh3d, bsh3d |
| | sbhd_sb2hd, bshd_bs2hd | sbhd_sb2hd, bshd_bs2hd | | | sbhd_sb2hd, bshd_bs2hd | sbhd_sb2hd, bshd_bs2hd |
| | bshd_bshd_bshd | sbhd_sbhd_sbhd, bshd_bshd_bshd | | | bshd_bshd_bshd | sbhd_sbh2d, bshd_bsh2d |
| | | sbhd_sbhd_sbhd, bshd_bshd_bshd |
| mask_type | causal/no_mask | causal | | mask_type | causal/no_mask | causal |
| bias_type | no_bias/post_scale_bias | no_bias | | bias_type | no_bias/post_scale_bias | no_bias |
| dropout | yes | yes | | dropout | yes | yes |
...@@ -2736,6 +2795,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -2736,6 +2795,7 @@ class MultiheadAttention(torch.nn.Module):
q_pos_emb, k_pos_emb = rotary_pos_emb q_pos_emb, k_pos_emb = rotary_pos_emb
query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb) query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb)
key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb) key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb)
value_layer = value_layer.contiguous()
context_layer = self.core_attention( context_layer = self.core_attention(
query_layer, query_layer,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment