Unverified commit 66736890 authored by Paweł Gadziński, committed by GitHub

Different dimension for attention (#833)



* Fixed Llama tutorial. Changed batch size and added fused=True.
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: root <root@ipp2-0037.nvidia.com>

* Tutorial updated but not complete yet.
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: root <root@ipp2-0037.nvidia.com>

* Tutorial notebook reset - removed fuse=true
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: root <root@ipp2-0037.nvidia.com>

* Removed fused=true
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: root <root@ipp2-0037.nvidia.com>

* Batch size back to 8
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: root <root@ipp2-0037.nvidia.com>

* Typo and commented-out line
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: root <root@ipp2-0037.nvidia.com>

* Fixed whitespace
Signed-off-by: root <root@ipp2-0037.nvidia.com>

* Fixed whitespace
Signed-off-by: root <root@ipp2-0037.nvidia.com>

* Added comment to attention line. Fixed a potential bug with loading weights - loading now works correctly, confirmed by the generation code.
Signed-off-by: root <root@ipp2-1661.nvidia.com>

* Comments
Signed-off-by: root <root@ipp2-1661.nvidia.com>

* Model cast added again
Signed-off-by: root <root@ipp2-1661.nvidia.com>

* Weight download info
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Moved parameter gate_proj_size to config
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* gate_proj_size removed; intermediate_size used instead
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Llama 3 added to tutorial
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Typo fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Typo fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Fixed model loading
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Loading fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Different dim for attention
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Reverted other commit
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Changed name to kv_channels
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Fixed typo
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Back to kv_channels in transformer layer
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Back to kv_channels in transformer layer
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Small bug fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Small bug fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Test fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Changed file modes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Lint fix and resolved conflict
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Lint fix and resolved conflict
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Lint fix, hopefully the last
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: root <root@ipp2-0037.nvidia.com>
Signed-off-by: root <root@ipp2-1661.nvidia.com>
Co-authored-by: root <root@ipp2-2373.nvidia.com>
Co-authored-by: root <root@ipp2-1588.nvidia.com>
Co-authored-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: root <root@ipp2-0037.nvidia.com>
Co-authored-by: root <root@ipp2-1661.nvidia.com>
Co-authored-by: root <root@ipp2-2371.nvidia.com>
Co-authored-by: root <root@ipp2-1589.nvidia.com>
Co-authored-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7c4887b2
@@ -18,3 +18,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
 NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
 pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
 pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py
+pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest
import torch

import transformer_engine.pytorch as te

batch_size = 32
seq_length = 2048
num_heads = 16
head_dim = 64
dtype = torch.bfloat16
num_attn_head = 16
ffn_hidden_size = 1024


@pytest.mark.parametrize("kv_channels", [128, 256])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("num_gqa_groups", [1, 2, 4, 8, 16])
def test_gqa(
    kv_channels,
    hidden_size,
    num_gqa_groups,
) -> None:
    model = te.TransformerLayer(
        hidden_size,
        ffn_hidden_size,
        num_attn_head,
        num_gqa_groups,
        kv_channels=kv_channels,
    )

    # Run forward pass
    x = torch.randn((batch_size, 1, hidden_size)).cuda()
    model(x)

    # Check shapes of weights.
    assert model.self_attention.layernorm_qkv.key_weight.shape[0] == kv_channels * num_gqa_groups
    assert model.self_attention.layernorm_qkv.key_weight.shape[1] == hidden_size
    assert model.self_attention.layernorm_qkv.query_weight.shape[0] == kv_channels * num_attn_head
    assert model.self_attention.layernorm_qkv.query_weight.shape[1] == hidden_size
    assert model.self_attention.layernorm_qkv.value_weight.shape[0] == kv_channels * num_gqa_groups
    assert model.self_attention.layernorm_qkv.value_weight.shape[1] == hidden_size
    assert model.self_attention.proj.weight.shape[0] == hidden_size
    assert model.self_attention.proj.weight.shape[1] == kv_channels * num_attn_head
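
As a sanity check, the shape arithmetic the test asserts can be reproduced by hand. A minimal sketch for one parametrized case (hidden_size=128, kv_channels=256, num_gqa_groups=4; it assumes only the te.TransformerLayer API already used above):

import transformer_engine.pytorch as te

# kv_channels is decoupled from hidden_size // num_attention_heads here:
# 128 // 16 = 8, but the per-head dimension is set to 256.
layer = te.TransformerLayer(
    128,             # hidden_size
    1024,            # ffn_hidden_size
    16,              # num_attention_heads
    num_gqa_groups=4,
    kv_channels=256,
)

# Query projection: hidden_size -> num_attention_heads * kv_channels = 16 * 256.
print(layer.self_attention.layernorm_qkv.query_weight.shape)  # torch.Size([4096, 128])
# Key/value projections: hidden_size -> num_gqa_groups * kv_channels = 4 * 256.
print(layer.self_attention.layernorm_qkv.key_weight.shape)    # torch.Size([1024, 128])
print(layer.self_attention.layernorm_qkv.value_weight.shape)  # torch.Size([1024, 128])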
@@ -3197,7 +3197,7 @@ class DotProductAttention(torch.nn.Module):
     num_attention_heads : int
                          number of attention heads in the transformer layer.
     kv_channels : int
-                 number of key-value channels.
+                 number of key-query-value channels per attention head.
     num_gqa_groups : Optional[int] = None
                      number of GQA groups in the transformer layer.
                      Grouped Query Attention is described in
@@ -3302,6 +3302,7 @@ class DotProductAttention(torch.nn.Module):
         self.cp_stream = cp_stream
+        self.hidden_size_per_attention_head = kv_channels
         self.num_gqa_groups = (
             num_attention_heads if num_gqa_groups is None else num_gqa_groups
         )
@@ -3318,7 +3319,7 @@ class DotProductAttention(torch.nn.Module):
                 set_all_rng_states(self.rng_states_tracker.get_states())
                 attention_dropout_ctx = self.rng_states_tracker.fork
-        norm_factor = math.sqrt(self.hidden_size_per_attention_head)
+        norm_factor = math.sqrt(kv_channels)
         self.device_compute_capability = get_device_compute_capability()
         self.deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) \
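
Either spelling of the norm factor computes the standard 1/sqrt(d_head) softmax scaling, with d_head = kv_channels. A plain-PyTorch sketch of the effect, with illustrative sizes (not TE code):

import math

import torch

# Illustrative sizes only: 4 tokens, kv_channels (per-head dim) of 64.
q = torch.randn(4, 64)
k = torch.randn(4, 64)

# Scores are divided by sqrt(kv_channels) before the softmax, which keeps
# their variance independent of the head dimension.
scores = (q @ k.T) / math.sqrt(q.shape[-1])
probs = torch.softmax(scores, dim=-1)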
@@ -3449,9 +3450,11 @@ class DotProductAttention(torch.nn.Module):
         .. note::
-            Input tensors :attr:`query_layer`, :attr:`key_layer`, and :attr:`value_layer`
+            Input tensor :attr:`query_layer` must be of shape
+            (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`,
+            :attr:`kv_channels`) and the tensors :attr:`key_layer` and :attr:`value_layer`
             must each be of shape (:attr:`sequence_length`, :attr:`batch_size`,
-            :attr:`num_attention_heads`, :attr:`kv_channels`). Output of shape
+            :attr:`num_gqa_groups`, :attr:`kv_channels`). Output of shape
             (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`
             * :attr:`kv_channels`) is returned.
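
Concretely, under the documented layout, queries carry one slice per attention head while keys and values carry one slice per GQA group. A minimal sketch, assuming the standalone te.DotProductAttention module and illustrative sizes:

import torch
import transformer_engine.pytorch as te

s, b, h, g, d = 2048, 2, 16, 4, 256  # seq, batch, heads, GQA groups, kv_channels
attn = te.DotProductAttention(num_attention_heads=h, kv_channels=d, num_gqa_groups=g)

q = torch.randn(s, b, h, d, dtype=torch.bfloat16, device="cuda")  # one slice per head
k = torch.randn(s, b, g, d, dtype=torch.bfloat16, device="cuda")  # one slice per group
v = torch.randn(s, b, g, d, dtype=torch.bfloat16, device="cuda")

out = attn(q, k, v)
print(out.shape)  # torch.Size([2048, 2, 4096]) == (s, b, h * d)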
@@ -4131,7 +4134,7 @@ class MultiheadAttention(torch.nn.Module):
         bias: bool = True,
         normalization: str = "LayerNorm",
         device: Union[torch.device, str] = "cuda",
-        qkv_format: str = "sbhd",
+        qkv_format: str = "sbhd"
     ) -> None:
         super().__init__()
@@ -4168,7 +4171,6 @@ class MultiheadAttention(torch.nn.Module):
         self.tp_size = tp_size
         self.sequence_parallel = (tp_size > 1) and sequence_parallel
-        self.hidden_size_per_attention_head = kv_channels
         self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)
         self.num_gqa_groups = (
             num_attention_heads if num_gqa_groups is None else num_gqa_groups
@@ -4178,7 +4180,10 @@ class MultiheadAttention(torch.nn.Module):
         assert (self.num_gqa_groups % tp_size == 0
                 ), "The number of GQA groups must be divisible by tensor parallel size!"
         self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)
-        self.hidden_size_kv = int(hidden_size * self.num_gqa_groups // num_attention_heads)
+        self.hidden_size_per_attention_head = kv_channels
+        self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads
+        self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups
         common_gemm_kwargs = {
             "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
@@ -4196,14 +4201,14 @@ class MultiheadAttention(torch.nn.Module):
         parameters_split = None
         if not fuse_qkv_params:
             parameters_split = collections.OrderedDict([
-                ("query", hidden_size),
+                ("query", self.hidden_size_q),
                 ("key", self.hidden_size_kv),
                 ("value", self.hidden_size_kv),
             ])
         if self.input_layernorm:
             self.layernorm_qkv = LayerNormLinear(
                 hidden_size,
-                hidden_size + 2 * self.hidden_size_kv,
+                self.hidden_size_q + 2 * self.hidden_size_kv,
                 eps=layernorm_epsilon,
                 init_method=init_method,
                 bias=bias,
@@ -4223,7 +4228,7 @@ class MultiheadAttention(torch.nn.Module):
         else:
             self.qkv = Linear(
                 hidden_size,
-                hidden_size + 2 * self.hidden_size_kv,
+                self.hidden_size_q + 2 * self.hidden_size_kv,
                 init_method=init_method,
                 bias=bias,
                 return_bias=False,
@@ -4235,7 +4240,7 @@ class MultiheadAttention(torch.nn.Module):
             if self.input_layernorm:
                 self.layernorm_query = LayerNormLinear(
                     hidden_size,
-                    hidden_size,
+                    self.hidden_size_q,
                     eps=layernorm_epsilon,
                     init_method=init_method,
                     bias=bias,
@@ -4255,7 +4260,7 @@ class MultiheadAttention(torch.nn.Module):
             else:
                 self.query_layer = Linear(
                     hidden_size,
-                    hidden_size,
+                    self.hidden_size_q,
                     init_method=init_method,
                     bias=bias,
                     return_bias=False,
@@ -4276,7 +4281,7 @@ class MultiheadAttention(torch.nn.Module):
         # Attention.
         self.core_attention = DotProductAttention(
             num_attention_heads,
-            kv_channels,
+            self.hidden_size_per_attention_head,
             num_gqa_groups=self.num_gqa_groups,
             attention_dropout=attention_dropout,
             qkv_format=self.qkv_format,
@@ -4290,7 +4295,7 @@ class MultiheadAttention(torch.nn.Module):
         # Linear
         self.proj = Linear(
-            hidden_size,
+            self.hidden_size_q,
             hidden_size,
             init_method=output_layer_init_method,
             bias=bias,
@@ -129,7 +129,7 @@ class TransformerLayer(torch.nn.Module):
                    This can be used for structures like `T5` Transformer in conjunction with the
                    `encoder` option.
     kv_channels: int, default = `None`
-                number of key-value channels. defaults to
+                number of query-key-value channels per attention head. defaults to
                 :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
     self_attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'},
                         default = `causal`