"tests/vscode:/vscode.git/clone" did not exist on "e1ac502890356f5f649db0d17305af480f8c361d"
Unverified Commit 1cb4b25a authored by cyanguwa's avatar cyanguwa Committed by GitHub
Browse files

Add support for multi-query and grouped-query attention (#338)



* add support for multi-query/grouped-query attention
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert to flash-attn 1.0.6 and build 2.0.0.post1 manually in CI
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add keyword name for DPA input
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fused attn tests
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix skipif for pytest
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Update transformer_engine/pytorch/attention.py
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update tests/pytorch/test_fused_attn.py
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix TP and SP case
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* add skipifs for pytest
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove higher limit for flash-attn version
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 59c0f096
...@@ -11,3 +11,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py ...@@ -11,3 +11,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_attn.py
...@@ -290,7 +290,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ...@@ -290,7 +290,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
# Framework-specific requirements # Framework-specific requirements
if "pytorch" in frameworks(): if "pytorch" in frameworks():
add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=2.0.0.post1"]) add_unique(install_reqs, ["torch", "flash-attn>=1.0.6"])
add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"]) add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
if "jax" in frameworks(): if "jax" in frameworks():
if not found_pybind11(): if not found_pybind11():
......
...@@ -8,11 +8,19 @@ import pytest ...@@ -8,11 +8,19 @@ import pytest
from transformer_engine.pytorch.utils import ( from transformer_engine.pytorch.utils import (
init_method_normal, init_method_normal,
scaled_init_method_normal, scaled_init_method_normal,
get_device_compute_capability,
) )
from transformer_engine.pytorch.fp8 import is_fp8_available
from transformer_engine.pytorch import TransformerLayer from transformer_engine.pytorch import TransformerLayer
from transformer_engine.pytorch.attention import DotProductAttention from transformer_engine.pytorch.attention import DotProductAttention
import os import os
from pkg_resources import packaging
from importlib.metadata import version
fp8_available, reason_for_no_fp8 = is_fp8_available()
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
class ModelConfig: class ModelConfig:
def __init__( def __init__(
self, num_layers, hidden_size, num_attention_heads, head_dim, seq_len, self, num_layers, hidden_size, num_attention_heads, head_dim, seq_len,
...@@ -45,6 +53,8 @@ if torch.cuda.is_bf16_supported(): ...@@ -45,6 +53,8 @@ if torch.cuda.is_bf16_supported():
batch_sizes = [1, 2, 32] batch_sizes = [1, 2, 32]
@pytest.mark.skipif(
get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", model_configs.keys())
...@@ -113,6 +123,8 @@ def _run_dot_product_attention(dtype, bs, config, backend): ...@@ -113,6 +123,8 @@ def _run_dot_product_attention(dtype, bs, config, backend):
return op, inp.grad return op, inp.grad
@pytest.mark.skipif(
get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", model_configs.keys())
...@@ -208,12 +220,114 @@ def _run_transformer_layer(dtype, bs, config, backend): ...@@ -208,12 +220,114 @@ def _run_transformer_layer(dtype, bs, config, backend):
return op, inp.grad return op, inp.grad
@pytest.mark.skipif(not _flash_attn_2_available, reason="FA2.0 is not available")
@pytest.mark.skipif(
    get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_transformer_layer_gqa(dtype, bs, model):
    """Compare the FlashAttention and UnfusedDotProductAttention backends of
    TransformerLayer for every valid GQA group size (every divisor of the
    model's attention-head count)."""
    config = model_configs[model]

    def _divisors(n):
        # All positive integers that evenly divide n.
        return [d for d in range(1, n + 1) if n % d == 0]

    for group_size in _divisors(config.num_attention_heads):
        flash_fwd, flash_bwd = _run_transformer_layer_gqa(
            dtype, bs, config, "FlashAttention", group_size)
        unfused_fwd, unfused_bwd = _run_transformer_layer_gqa(
            dtype, bs, config, "UnfusedDotProductAttention", group_size)

        # Loose tolerances: fp16/bf16 attention paths differ numerically.
        tolerances = dict(atol=5e-1, rtol=5e-1)
        assert torch.allclose(flash_fwd, unfused_fwd, **tolerances)
        assert torch.allclose(flash_bwd, unfused_bwd, **tolerances)
def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_group):
    """Build a TransformerLayer configured for grouped-query attention and run
    one forward/backward pass.

    Arguments:
        dtype: parameter/activation dtype (e.g. torch.float16).
        bs: micro batch size.
        config: ModelConfig describing layer sizes and dropout.
        backend: "FlashAttention" selects the flash path; any other value
            falls through to UnfusedDotProductAttention.
        num_querys_per_gqa_group: number of query heads sharing one KV head;
            must divide config.num_attention_heads.

    Returns:
        (output, input_gradient) tensors for numerical comparison.
    """
    torch.manual_seed(1234)
    torch.cuda.manual_seed(1234)

    # Backend selection is read from env vars at layer construction time.
    os.environ["NVTE_FLASH_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"

    inp = 0.1 * torch.randn(
        config.seq_len, bs, config.num_attention_heads * config.head_dim,
        dtype = dtype).cuda()
    inp.requires_grad = True
    seqlens = torch.empty(bs, dtype = torch.int32).cuda()
    seqlens.fill_(config.seq_len)
    cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
    op_grad = 0.001 * torch.randint(0, 200, (
        config.seq_len, bs, config.num_attention_heads * config.head_dim
        ), dtype = dtype).cuda()

    sigma = 0.02
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    layer_number = 1
    drop_path_rate = 0.0
    drop_path_rates = [
        rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            # BUG FIX: use floor division. True division ("/") produces a
            # float (e.g. 16 / 4 -> 4.0), but num_gqa_groups must be an int —
            # it feeds head-count arithmetic (%, //, tensor shapes) downstream.
            num_gqa_groups = config.num_attention_heads // num_querys_per_gqa_group,
            layernorm_epsilon = 1e-5,
            hidden_dropout = 0.0,
            attention_dropout = config.dropout_p,
            init_method = init_method,
            output_layer_init_method = output_layer_init_method,
            layer_number = layer_number,
            kv_channels = config.head_dim,
            self_attn_mask_type = config.attn_mask_type,
            tp_group = None,
            tp_size = 1,
            params_dtype = dtype,
            get_rng_state_tracker = None,
            fuse_wgrad_accumulation = False,
            seq_length = config.seq_len,
            micro_batch_size = bs,
            sequence_parallel = False,
            apply_residual_connection_post_layernorm = False,
            output_layernorm = False,
            layer_type = "encoder",
            drop_path_rate = drop_path_rates[layer_number - 1],
            set_parallel_mode = True,
            fuse_qkv_params = True,
            zero_centered_gamma = False,
            qkv_weight_interleaved = False,
            ub_tp_comm_overlap = False,
            bias = True,
        )
        .to(dtype = dtype)
        .cuda()
    )

    op = block(inp)
    op.backward(op_grad)

    return op, inp.grad
model_configs_fp8 = { model_configs_fp8 = {
"test1": ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask"), "test1": ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask"),
} }
batch_sizes_fp8 = [1, 4] batch_sizes_fp8 = [1, 4]
param_types_fp8 = [torch.float16] param_types_fp8 = [torch.float16]
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("dtype", param_types_fp8) @pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("bs", batch_sizes_fp8) @pytest.mark.parametrize("bs", batch_sizes_fp8)
@pytest.mark.parametrize("model", model_configs_fp8.keys()) @pytest.mark.parametrize("model", model_configs_fp8.keys())
......
...@@ -805,7 +805,7 @@ def test_dpa_accuracy(dtype, bs, model): ...@@ -805,7 +805,7 @@ def test_dpa_accuracy(dtype, bs, model):
DotProductAttention( DotProductAttention(
config.num_attention_heads, config.num_attention_heads,
config.embed, config.embed,
0.1, # dropout attention_dropout=0.1, # dropout
) )
.to(dtype=dtype) .to(dtype=dtype)
.cuda() .cuda()
......
...@@ -180,6 +180,15 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -180,6 +180,15 @@ class UnfusedDotProductAttention(torch.nn.Module):
key_layer.size(0), key_layer.size(0),
) )
assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!"
if key_layer.shape[2] != query_layer.shape[2]:
assert (query_layer.shape[2]%key_layer.shape[2]==0
),"The number of attention heads must be divisible by the number of GQA groups!"
key_layer = key_layer.repeat_interleave(
int(query_layer.shape[2]/key_layer.shape[2]), dim = 2)
value_layer = value_layer.repeat_interleave(
int(query_layer.shape[2]/value_layer.shape[2]), dim = 2)
# [sq, b, np, hn] -> [sq, b * np, hn] # [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.reshape( query_layer = query_layer.reshape(
output_size[2], output_size[0] * output_size[1], -1 output_size[2], output_size[0] * output_size[1], -1
...@@ -722,6 +731,14 @@ class DotProductAttention(torch.nn.Module): ...@@ -722,6 +731,14 @@ class DotProductAttention(torch.nn.Module):
number of attention heads in the transformer layer. number of attention heads in the transformer layer.
kv_channels : int kv_channels : int
number of key-value channels. number of key-value channels.
num_gqa_groups : Optional[int] = None
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the queries.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
attention_dropout: float, default = 0.0 attention_dropout: float, default = 0.0
dropout probability for the dropout op during multi-head attention. dropout probability for the dropout op during multi-head attention.
attn_mask_type: {'causal', 'padding'}, default = `causal` attn_mask_type: {'causal', 'padding'}, default = `causal`
...@@ -744,6 +761,7 @@ class DotProductAttention(torch.nn.Module): ...@@ -744,6 +761,7 @@ class DotProductAttention(torch.nn.Module):
self, self,
num_attention_heads: int, num_attention_heads: int,
kv_channels: int, kv_channels: int,
num_gqa_groups: Optional[int] = None,
attention_dropout: float = 0.0, attention_dropout: float = 0.0,
attn_mask_type: str = "causal", attn_mask_type: str = "causal",
sequence_parallel: bool = False, sequence_parallel: bool = False,
...@@ -758,12 +776,16 @@ class DotProductAttention(torch.nn.Module): ...@@ -758,12 +776,16 @@ class DotProductAttention(torch.nn.Module):
self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group) self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
self.tp_group = tp_group self.tp_group = tp_group
self.get_rng_state_tracker = get_rng_state_tracker self.get_rng_state_tracker = get_rng_state_tracker
self.num_attention_heads = num_attention_heads
projection_size = kv_channels * num_attention_heads self.hidden_size_per_attention_head = kv_channels
self.hidden_size_per_partition = divide(projection_size, self.tp_size) self.num_gqa_groups = (
self.hidden_size_per_attention_head = divide( num_attention_heads if num_gqa_groups is None else num_gqa_groups
projection_size, num_attention_heads
) )
self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)
assert (num_attention_heads % self.num_gqa_groups == 0
), "The number of attention heads must be divisible by the number of GQA groups!"
if sequence_parallel or get_rng_state_tracker is None: if sequence_parallel or get_rng_state_tracker is None:
attention_dropout_ctx = nullcontext attention_dropout_ctx = nullcontext
...@@ -883,6 +905,10 @@ class DotProductAttention(torch.nn.Module): ...@@ -883,6 +905,10 @@ class DotProductAttention(torch.nn.Module):
Whether to use the fast path to set output tensors to 0 or not. Whether to use the fast path to set output tensors to 0 or not.
""" """
assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
and value_layer.shape[-2] == self.num_gqa_groups_per_partition
), f"Keys and values must have {self.num_gqa_groups} heads!"
use_flash_attention = self.use_flash_attention use_flash_attention = self.use_flash_attention
use_fused_attention = self.use_fused_attention use_fused_attention = self.use_fused_attention
...@@ -898,6 +924,9 @@ class DotProductAttention(torch.nn.Module): ...@@ -898,6 +924,9 @@ class DotProductAttention(torch.nn.Module):
elif not _flash_attn_2_available and self.device_compute_capability == 8.9: elif not _flash_attn_2_available and self.device_compute_capability == 8.9:
use_flash_attention = False use_flash_attention = False
if not _flash_attn_2_available and self.num_gqa_groups != self.num_attention_heads:
use_flash_attention = False
if self.attn_mask_type == "padding" and attention_mask is not None: if self.attn_mask_type == "padding" and attention_mask is not None:
use_flash_attention = False use_flash_attention = False
use_fused_attention = False use_fused_attention = False
...@@ -919,7 +948,9 @@ class DotProductAttention(torch.nn.Module): ...@@ -919,7 +948,9 @@ class DotProductAttention(torch.nn.Module):
# DPA does not support FP8; for FP8, use cpp_extensions modules directly # DPA does not support FP8; for FP8, use cpp_extensions modules directly
is_backend_avail = (fused_attention_backend in is_backend_avail = (fused_attention_backend in
[FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]]) [FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]])
use_fused_attention = use_fused_attention and is_backend_avail use_fused_attention = (use_fused_attention
and is_backend_avail
and self.num_gqa_groups == self.num_attention_heads)
if use_flash_attention: if use_flash_attention:
if checkpoint_core_attention: if checkpoint_core_attention:
...@@ -974,6 +1005,7 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -974,6 +1005,7 @@ class MultiHeadAttention(torch.nn.Module):
attn_mask_type: str = "causal", attn_mask_type: str = "causal",
tp_group: Optional[dist_group_type] = None, tp_group: Optional[dist_group_type] = None,
tp_size: int = 1, tp_size: int = 1,
num_gqa_groups: Optional[int] = None,
fuse_wgrad_accumulation: bool = False, fuse_wgrad_accumulation: bool = False,
get_rng_state_tracker: Optional[Callable] = None, get_rng_state_tracker: Optional[Callable] = None,
sequence_parallel: bool = False, sequence_parallel: bool = False,
...@@ -1002,6 +1034,7 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -1002,6 +1034,7 @@ class MultiHeadAttention(torch.nn.Module):
self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.init_method = init_method self.init_method = init_method
self.attn_mask_type = attn_mask_type self.attn_mask_type = attn_mask_type
self.num_attention_heads = num_attention_heads
if not fuse_qkv_params: if not fuse_qkv_params:
qkv_weight_interleaved = False qkv_weight_interleaved = False
...@@ -1017,6 +1050,15 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -1017,6 +1050,15 @@ class MultiHeadAttention(torch.nn.Module):
self.hidden_size_per_attention_head = kv_channels self.hidden_size_per_attention_head = kv_channels
self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size) self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)
self.num_gqa_groups = (
num_attention_heads if num_gqa_groups is None else num_gqa_groups
)
assert (num_attention_heads % self.num_gqa_groups == 0
), "The number of GQA groups must be divisible by the number of attention heads!"
assert (num_attention_heads % tp_size == 0
), "The number of GQA groups must be divisible by tensor parallel size!"
self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)
self.hidden_size_kv = int(hidden_size * self.num_gqa_groups // num_attention_heads)
common_gemm_kwargs = { common_gemm_kwargs = {
"fuse_wgrad_accumulation": fuse_wgrad_accumulation, "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
...@@ -1029,7 +1071,7 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -1029,7 +1071,7 @@ class MultiHeadAttention(torch.nn.Module):
qkv_parallel_mode = "column" if set_parallel_mode else None qkv_parallel_mode = "column" if set_parallel_mode else None
if self.attention_type == "self": if self.attention_type == "self" and self.num_gqa_groups == self.num_attention_heads:
if self.input_layernorm: if self.input_layernorm:
self.layernorm_qkv = LayerNormLinear( self.layernorm_qkv = LayerNormLinear(
hidden_size, hidden_size,
...@@ -1059,7 +1101,9 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -1059,7 +1101,9 @@ class MultiHeadAttention(torch.nn.Module):
parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None, parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None,
**common_gemm_kwargs, **common_gemm_kwargs,
) )
else: elif ((self.attention_type == "cross")
or (self.attention_type == "self"
and self.num_gqa_groups != self.num_attention_heads)):
if self.input_layernorm: if self.input_layernorm:
self.layernorm_query = LayerNormLinear( self.layernorm_query = LayerNormLinear(
hidden_size, hidden_size,
...@@ -1089,7 +1133,7 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -1089,7 +1133,7 @@ class MultiHeadAttention(torch.nn.Module):
) )
self.key_value = Linear( self.key_value = Linear(
hidden_size, hidden_size,
2 * hidden_size, 2 * self.hidden_size_kv,
init_method=init_method, init_method=init_method,
bias=bias, bias=bias,
return_bias=False, return_bias=False,
...@@ -1102,7 +1146,8 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -1102,7 +1146,8 @@ class MultiHeadAttention(torch.nn.Module):
self.core_attention = DotProductAttention( self.core_attention = DotProductAttention(
num_attention_heads, num_attention_heads,
kv_channels, kv_channels,
attention_dropout, num_gqa_groups=self.num_gqa_groups,
attention_dropout=attention_dropout,
tp_size=tp_size, tp_size=tp_size,
get_rng_state_tracker=get_rng_state_tracker, get_rng_state_tracker=get_rng_state_tracker,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
...@@ -1131,7 +1176,7 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -1131,7 +1176,7 @@ class MultiHeadAttention(torch.nn.Module):
return torch.empty( return torch.empty(
inference_max_sequence_len, inference_max_sequence_len,
batch_size, batch_size,
self.num_attention_heads_per_partition, self.num_gqa_groups_per_partition,
self.hidden_size_per_attention_head, self.hidden_size_per_attention_head,
dtype=dtype, dtype=dtype,
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
...@@ -1192,7 +1237,7 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -1192,7 +1237,7 @@ class MultiHeadAttention(torch.nn.Module):
# Query, Key, and Value # Query, Key, and Value
# ===================== # =====================
if self.attention_type == "self": if self.attention_type == "self" and self.num_gqa_groups == self.num_attention_heads:
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
if self.input_layernorm: if self.input_layernorm:
layernorm_qkv_outputs = self.layernorm_qkv( layernorm_qkv_outputs = self.layernorm_qkv(
...@@ -1235,17 +1280,25 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -1235,17 +1280,25 @@ class MultiHeadAttention(torch.nn.Module):
query_layer, key_layer, value_layer = split_tensor_along_dim( query_layer, key_layer, value_layer = split_tensor_along_dim(
mixed_x_layer, split_dim, 3 mixed_x_layer, split_dim, 3
) )
else: elif ((self.attention_type == "cross")
or (self.attention_type == "self"
and self.num_gqa_groups != self.num_attention_heads)):
if self.attention_type == "cross":
input_tensor = encoder_output
else:
input_tensor = hidden_states
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer = self.key_value( mixed_kv_layer = self.key_value(
encoder_output, input_tensor,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
) )
if self.qkv_weight_interleaved: if self.qkv_weight_interleaved:
# [sq, b, (np * 2 * hn)] --> [sq, b, np, 2 * hn] # [sq, b, (np * 2 * hn)] --> [sq, b, np, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + ( new_tensor_shape = mixed_kv_layer.size()[:-1] + (
self.num_attention_heads_per_partition, self.num_gqa_groups_per_partition,
2 * self.hidden_size_per_attention_head, 2 * self.hidden_size_per_attention_head,
) )
# split along last dimension # split along last dimension
...@@ -1253,7 +1306,7 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -1253,7 +1306,7 @@ class MultiHeadAttention(torch.nn.Module):
else: else:
# [sq, b, (np * 2 * hn)] --> [sq, b, 2 * np, hn] # [sq, b, (np * 2 * hn)] --> [sq, b, 2 * np, hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + ( new_tensor_shape = mixed_kv_layer.size()[:-1] + (
2 * self.num_attention_heads_per_partition, 2 * self.num_gqa_groups_per_partition,
self.hidden_size_per_attention_head, self.hidden_size_per_attention_head,
) )
# split along second last dimension # split along second last dimension
......
...@@ -86,6 +86,14 @@ class TransformerLayer(torch.nn.Module): ...@@ -86,6 +86,14 @@ class TransformerLayer(torch.nn.Module):
intermediate size to which input samples are projected. intermediate size to which input samples are projected.
num_attention_heads : int num_attention_heads : int
number of attention heads in the transformer layer. number of attention heads in the transformer layer.
num_gqa_groups : int, default = `None`
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the queries.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
layernorm_epsilon : float, default = 1e-5 layernorm_epsilon : float, default = 1e-5
a value added to the denominator of layer normalization a value added to the denominator of layer normalization
for numerical stability. for numerical stability.
...@@ -194,6 +202,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -194,6 +202,7 @@ class TransformerLayer(torch.nn.Module):
hidden_size: int, hidden_size: int,
ffn_hidden_size: int, ffn_hidden_size: int,
num_attention_heads: int, num_attention_heads: int,
num_gqa_groups: Optional[int] = None,
layernorm_epsilon: float = 1e-5, layernorm_epsilon: float = 1e-5,
hidden_dropout: float = 0.1, hidden_dropout: float = 0.1,
attention_dropout: float = 0.1, attention_dropout: float = 0.1,
...@@ -293,6 +302,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -293,6 +302,7 @@ class TransformerLayer(torch.nn.Module):
"layer_number": layer_number, "layer_number": layer_number,
"tp_group": tp_group, "tp_group": tp_group,
"tp_size": self.tp_size, "tp_size": self.tp_size,
"num_gqa_groups": num_gqa_groups,
"fuse_wgrad_accumulation": fuse_wgrad_accumulation, "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
"get_rng_state_tracker": get_rng_state_tracker, "get_rng_state_tracker": get_rng_state_tracker,
"sequence_parallel": self.sequence_parallel, "sequence_parallel": self.sequence_parallel,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment