Unverified Commit 1cb4b25a authored by cyanguwa, committed by GitHub

Add support for multi-query and grouped-query attention (#338)



* add support for multi-query/grouped-query attention
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert to flash-attn 1.0.6 and build 2.0.0.post1 manually in CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add keyword name for DPA input
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fused attn tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix skipif for pytest
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Update transformer_engine/pytorch/attention.py
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update tests/pytorch/test_fused_attn.py
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix TP and SP case
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* add skipifs for pytest
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove higher limit for flash-attn version
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 59c0f096
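For reviewers, a minimal usage sketch of the new `num_gqa_groups` argument. Sizes are illustrative, the snippet assumes a CUDA device with this branch of Transformer Engine installed, and the settings mirror the new test rather than prescribing a recommended configuration:

```python
import torch
from transformer_engine.pytorch import TransformerLayer

# Illustrative sizes: 16 query heads sharing 4 key/value (GQA) groups.
# num_gqa_groups=1 corresponds to multi-query attention (MQA);
# num_gqa_groups=num_attention_heads recovers standard multi-head attention.
layer = TransformerLayer(
    1024,                      # hidden_size
    4 * 1024,                  # ffn_hidden_size
    16,                        # num_attention_heads
    num_gqa_groups=4,
    fuse_qkv_params=True,      # mirrors the new test
    params_dtype=torch.float16,
).cuda()

x = torch.randn(512, 2, 1024, dtype=torch.float16, device="cuda")  # [seq, batch, hidden]
y = layer(x)                   # output keeps the [512, 2, 1024] shape
```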
@@ -11,3 +11,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_attn.py
@@ -290,7 +290,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
# Framework-specific requirements
if "pytorch" in frameworks():
add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=2.0.0.post1"])
add_unique(install_reqs, ["torch", "flash-attn>=1.0.6"])
add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
if "jax" in frameworks():
if not found_pybind11():
@@ -8,11 +8,19 @@ import pytest
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
get_device_compute_capability,
)
from transformer_engine.pytorch.fp8 import is_fp8_available
from transformer_engine.pytorch import TransformerLayer
from transformer_engine.pytorch.attention import DotProductAttention
import os
from pkg_resources import packaging
from importlib.metadata import version
fp8_available, reason_for_no_fp8 = is_fp8_available()
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
class ModelConfig:
def __init__(
self, num_layers, hidden_size, num_attention_heads, head_dim, seq_len,
@@ -45,6 +53,8 @@ if torch.cuda.is_bf16_supported():
batch_sizes = [1, 2, 32]
@pytest.mark.skipif(
get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@@ -113,6 +123,8 @@ def _run_dot_product_attention(dtype, bs, config, backend):
return op, inp.grad
@pytest.mark.skipif(
get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@@ -208,12 +220,114 @@ def _run_transformer_layer(dtype, bs, config, backend):
return op, inp.grad
@pytest.mark.skipif(not _flash_attn_2_available, reason="FA2.0 is not available")
@pytest.mark.skipif(
get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_transformer_layer_gqa(dtype, bs, model):
"""Test TransformerLayer module when its DotProductAttention is enabled with
FlashAttention, FusedAttention, or UnfusedDotProductAttention backend"""
config = model_configs[model]
def find_factors(x):
f = []
for i in range(1, x + 1):
if x % i == 0:
f.append(i)
return f
num_querys_per_gqa_group = find_factors(config.num_attention_heads)
for num_q_per_gqa_group in num_querys_per_gqa_group:
flash_attn_fwd, flash_attn_bwd = _run_transformer_layer_gqa(
dtype, bs, config, "FlashAttention", num_q_per_gqa_group)
unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer_gqa(
dtype, bs, config, "UnfusedDotProductAttention", num_q_per_gqa_group)
atol, rtol = 5e-1, 5e-1
assert torch.allclose(flash_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(flash_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)
def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_group):
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
os.environ["NVTE_FLASH_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
inp = 0.1 * torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()
inp.requires_grad=True
seqlens = torch.empty(bs, dtype = torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
op_grad = 0.001 * torch.randint(0, 200, (
config.seq_len, bs, config.num_attention_heads * config.head_dim
), dtype = dtype).cuda()
sigma = 0.02
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
layer_number = 1
drop_path_rate = 0.0
drop_path_rates = [
rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
num_gqa_groups = config.num_attention_heads // num_querys_per_gqa_group,
layernorm_epsilon = 1e-5,
hidden_dropout = 0.0,
attention_dropout = config.dropout_p,
init_method = init_method,
output_layer_init_method = output_layer_init_method,
layer_number = layer_number,
kv_channels = config.head_dim,
self_attn_mask_type = config.attn_mask_type,
tp_group = None,
tp_size = 1,
params_dtype = dtype,
get_rng_state_tracker = None,
fuse_wgrad_accumulation = False,
seq_length = config.seq_len,
micro_batch_size = bs,
sequence_parallel = False,
apply_residual_connection_post_layernorm = False,
output_layernorm = False,
layer_type = "encoder",
drop_path_rate = drop_path_rates[layer_number - 1],
set_parallel_mode = True,
fuse_qkv_params = True,
zero_centered_gamma = False,
qkv_weight_interleaved = False,
ub_tp_comm_overlap = False,
bias = True,
)
.to(dtype = dtype)
.cuda()
)
op = block(inp)
op.backward(op_grad)
return op, inp.grad
model_configs_fp8 = {
"test1": ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask"),
}
batch_sizes_fp8 = [1, 4]
param_types_fp8 = [torch.float16]
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("bs", batch_sizes_fp8)
@pytest.mark.parametrize("model", model_configs_fp8.keys())
@@ -805,7 +805,7 @@ def test_dpa_accuracy(dtype, bs, model):
DotProductAttention(
config.num_attention_heads,
config.embed,
0.1, # dropout
attention_dropout=0.1, # dropout
)
.to(dtype=dtype)
.cuda()
@@ -180,6 +180,15 @@ class UnfusedDotProductAttention(torch.nn.Module):
key_layer.size(0),
)
assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!"
if key_layer.shape[2] != query_layer.shape[2]:
assert (query_layer.shape[2]%key_layer.shape[2]==0
),"The number of attention heads must be divisible by the number of GQA groups!"
key_layer = key_layer.repeat_interleave(
int(query_layer.shape[2]/key_layer.shape[2]), dim = 2)
value_layer = value_layer.repeat_interleave(
int(query_layer.shape[2]/value_layer.shape[2]), dim = 2)
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.reshape(
output_size[2], output_size[0] * output_size[1], -1
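The unfused backend handles GQA by broadcasting each key/value head across its query group with `repeat_interleave`, as added above. A self-contained sketch of just that shape manipulation, with illustrative sizes (not Transformer Engine code):

```python
import torch

# Illustrative sizes only: 16 query heads, 4 GQA groups, head dimension 64.
sq, b, nq, ng, hn = 512, 2, 16, 4, 64
query = torch.randn(sq, b, nq, hn)
key = torch.randn(sq, b, ng, hn)
value = torch.randn(sq, b, ng, hn)

# Each key/value head is repeated nq // ng times along the head dimension,
# after which the usual multi-head attention math applies unchanged.
assert nq % ng == 0
key = key.repeat_interleave(nq // ng, dim=2)
value = value.repeat_interleave(nq // ng, dim=2)
print(key.shape, value.shape)  # both torch.Size([512, 2, 16, 64])
```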
@@ -722,6 +731,14 @@ class DotProductAttention(torch.nn.Module):
number of attention heads in the transformer layer.
kv_channels : int
number of key-value channels.
num_gqa_groups : Optional[int] = None
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the queries.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
attention_dropout: float, default = 0.0
dropout probability for the dropout op during multi-head attention.
attn_mask_type: {'causal', 'padding'}, default = `causal`
@@ -744,6 +761,7 @@ class DotProductAttention(torch.nn.Module):
self,
num_attention_heads: int,
kv_channels: int,
num_gqa_groups: Optional[int] = None,
attention_dropout: float = 0.0,
attn_mask_type: str = "causal",
sequence_parallel: bool = False,
@@ -758,12 +776,16 @@ class DotProductAttention(torch.nn.Module):
self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
self.tp_group = tp_group
self.get_rng_state_tracker = get_rng_state_tracker
self.num_attention_heads = num_attention_heads
projection_size = kv_channels * num_attention_heads
self.hidden_size_per_partition = divide(projection_size, self.tp_size)
self.hidden_size_per_attention_head = divide(
projection_size, num_attention_heads)
self.hidden_size_per_attention_head = kv_channels
self.num_gqa_groups = (
num_attention_heads if num_gqa_groups is None else num_gqa_groups
)
self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)
assert (num_attention_heads % self.num_gqa_groups == 0
), "The number of attention heads must be divisible by the number of GQA groups!"
if sequence_parallel or get_rng_state_tracker is None:
attention_dropout_ctx = nullcontext
@@ -883,6 +905,10 @@ class DotProductAttention(torch.nn.Module):
Whether to use the fast path to set output tensors to 0 or not.
"""
assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
and value_layer.shape[-2] == self.num_gqa_groups_per_partition
), f"Keys and values must have {self.num_gqa_groups} heads!"
use_flash_attention = self.use_flash_attention
use_fused_attention = self.use_fused_attention
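With the new assert above, callers of `DotProductAttention` pass keys and values with `num_gqa_groups` heads (per tensor-parallel partition) while queries keep `num_attention_heads`. A minimal sketch of that calling convention, assuming a recent GPU, half precision, and `tp_size=1`; which backend actually runs depends on the environment:

```python
import torch
from transformer_engine.pytorch.attention import DotProductAttention

# Illustrative sizes: 16 query heads, 4 GQA groups, 64 channels per head.
dpa = DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    num_gqa_groups=4,
)

q = torch.randn(512, 2, 16, 64, dtype=torch.float16, device="cuda")
k = torch.randn(512, 2, 4, 64, dtype=torch.float16, device="cuda")
v = torch.randn(512, 2, 4, 64, dtype=torch.float16, device="cuda")
out = dpa(q, k, v)  # [512, 2, 16 * 64]; any other k/v head count trips the new assert
```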
@@ -898,6 +924,9 @@ class DotProductAttention(torch.nn.Module):
elif not _flash_attn_2_available and self.device_compute_capability == 8.9:
use_flash_attention = False
if not _flash_attn_2_available and self.num_gqa_groups != self.num_attention_heads:
use_flash_attention = False
if self.attn_mask_type == "padding" and attention_mask is not None:
use_flash_attention = False
use_fused_attention = False
@@ -919,7 +948,9 @@ class DotProductAttention(torch.nn.Module):
# DPA does not support FP8; for FP8, use cpp_extensions modules directly
is_backend_avail = (fused_attention_backend in
[FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]])
use_fused_attention = use_fused_attention and is_backend_avail
use_fused_attention = (use_fused_attention
and is_backend_avail
and self.num_gqa_groups == self.num_attention_heads)
if use_flash_attention:
if checkpoint_core_attention:
@@ -974,6 +1005,7 @@ class MultiHeadAttention(torch.nn.Module):
attn_mask_type: str = "causal",
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
num_gqa_groups: Optional[int] = None,
fuse_wgrad_accumulation: bool = False,
get_rng_state_tracker: Optional[Callable] = None,
sequence_parallel: bool = False,
@@ -1002,6 +1034,7 @@ class MultiHeadAttention(torch.nn.Module):
self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.init_method = init_method
self.attn_mask_type = attn_mask_type
self.num_attention_heads = num_attention_heads
if not fuse_qkv_params:
qkv_weight_interleaved = False
@@ -1017,6 +1050,15 @@ class MultiHeadAttention(torch.nn.Module):
self.hidden_size_per_attention_head = kv_channels
self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)
self.num_gqa_groups = (
num_attention_heads if num_gqa_groups is None else num_gqa_groups
)
assert (num_attention_heads % self.num_gqa_groups == 0
), "The number of attention heads must be divisible by the number of GQA groups!"
assert (num_attention_heads % tp_size == 0
), "The number of attention heads must be divisible by the tensor parallel size!"
self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)
self.hidden_size_kv = int(hidden_size * self.num_gqa_groups // num_attention_heads)
common_gemm_kwargs = {
"fuse_wgrad_accumulation": fuse_wgrad_accumulation,
@@ -1029,7 +1071,7 @@ class MultiHeadAttention(torch.nn.Module):
qkv_parallel_mode = "column" if set_parallel_mode else None
if self.attention_type == "self":
if self.attention_type == "self" and self.num_gqa_groups == self.num_attention_heads:
if self.input_layernorm:
self.layernorm_qkv = LayerNormLinear(
hidden_size,
@@ -1059,7 +1101,9 @@ class MultiHeadAttention(torch.nn.Module):
parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None,
**common_gemm_kwargs,
)
else:
elif ((self.attention_type == "cross")
or (self.attention_type == "self"
and self.num_gqa_groups != self.num_attention_heads)):
if self.input_layernorm:
self.layernorm_query = LayerNormLinear(
hidden_size,
@@ -1089,7 +1133,7 @@ class MultiHeadAttention(torch.nn.Module):
)
self.key_value = Linear(
hidden_size,
2 * hidden_size,
2 * self.hidden_size_kv,
init_method=init_method,
bias=bias,
return_bias=False,
@@ -1102,7 +1146,8 @@ class MultiHeadAttention(torch.nn.Module):
self.core_attention = DotProductAttention(
num_attention_heads,
kv_channels,
attention_dropout,
num_gqa_groups=self.num_gqa_groups,
attention_dropout=attention_dropout,
tp_size=tp_size,
get_rng_state_tracker=get_rng_state_tracker,
attn_mask_type=attn_mask_type,
@@ -1131,7 +1176,7 @@ class MultiHeadAttention(torch.nn.Module):
return torch.empty(
inference_max_sequence_len,
batch_size,
self.num_attention_heads_per_partition,
self.num_gqa_groups_per_partition,
self.hidden_size_per_attention_head,
dtype=dtype,
device=torch.cuda.current_device(),
@@ -1192,7 +1237,7 @@ class MultiHeadAttention(torch.nn.Module):
# Query, Key, and Value
# =====================
if self.attention_type == "self":
if self.attention_type == "self" and self.num_gqa_groups == self.num_attention_heads:
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
if self.input_layernorm:
layernorm_qkv_outputs = self.layernorm_qkv(
@@ -1235,17 +1280,25 @@ class MultiHeadAttention(torch.nn.Module):
query_layer, key_layer, value_layer = split_tensor_along_dim(
mixed_x_layer, split_dim, 3
)
else:
elif ((self.attention_type == "cross")
or (self.attention_type == "self"
and self.num_gqa_groups != self.num_attention_heads)):
if self.attention_type == "cross":
input_tensor = encoder_output
else:
input_tensor = hidden_states
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer = self.key_value(
encoder_output,
input_tensor,
is_first_microbatch=is_first_microbatch,
)
if self.qkv_weight_interleaved:
# [sq, b, (np * 2 * hn)] --> [sq, b, np, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
self.num_gqa_groups_per_partition,
2 * self.hidden_size_per_attention_head,
)
# split along last dimension
@@ -1253,7 +1306,7 @@ class MultiHeadAttention(torch.nn.Module):
else:
# [sq, b, (np * 2 * hn)] --> [sq, b, 2 * np, hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + (
2 * self.num_attention_heads_per_partition,
2 * self.num_gqa_groups_per_partition,
self.hidden_size_per_attention_head,
)
# split along second last dimension
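For the non-interleaved layout above, a standalone sketch of how the fused key/value projection output is reshaped and split under GQA (illustrative tensors only, not Transformer Engine code):

```python
import torch

# Illustrative sizes: sk=512, b=2, 4 GQA groups per partition, head dim 64.
sk, b, ng, hn = 512, 2, 4, 64
mixed_kv = torch.randn(sk, b, 2 * ng * hn)   # output of the key_value projection

# Non-interleaved layout: [sk, b, (2 * ng * hn)] -> [sk, b, 2 * ng, hn],
# then split along the second-last dimension into key and value.
mixed_kv = mixed_kv.view(sk, b, 2 * ng, hn)
key, value = mixed_kv.split(ng, dim=2)
print(key.shape, value.shape)  # both torch.Size([512, 2, 4, 64])
```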
@@ -86,6 +86,14 @@ class TransformerLayer(torch.nn.Module):
intermediate size to which input samples are projected.
num_attention_heads : int
number of attention heads in the transformer layer.
num_gqa_groups : int, default = `None`
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the queries.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
layernorm_epsilon : float, default = 1e-5
a value added to the denominator of layer normalization
for numerical stability.
@@ -194,6 +202,7 @@ class TransformerLayer(torch.nn.Module):
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
num_gqa_groups: Optional[int] = None,
layernorm_epsilon: float = 1e-5,
hidden_dropout: float = 0.1,
attention_dropout: float = 0.1,
@@ -293,6 +302,7 @@ class TransformerLayer(torch.nn.Module):
"layer_number": layer_number,
"tp_group": tp_group,
"tp_size": self.tp_size,
"num_gqa_groups": num_gqa_groups,
"fuse_wgrad_accumulation": fuse_wgrad_accumulation,
"get_rng_state_tracker": get_rng_state_tracker,
"sequence_parallel": self.sequence_parallel,