"docs/getting_started/getting_started_utils_jax.py" did not exist on "c90a9214091badd1234b2d9ca851bd97f8edb0f6"
Unverified commit 5d5fe819, authored by hXl3s, committed by GitHub

feat(pytorch): Allow TransformerLayer and MultiheadAttention to accept sequence length parameters (#1066)

* Added the ability to pass sequence-length parameters to the Transformer and MHA layers
Signed-off-by: Lukasz Pierscieniewski <lukaszp@nvidia.com>

* Documentation for new parameters
Signed-off-by: Lukasz Pierscieniewski <lukaszp@nvidia.com>

* Added tests for the THD layout and an assert rejecting the THD layout with KV-cache
Signed-off-by: Lukasz Pierscieniewski <lukaszp@nvidia.com>

* Fixed tests
Signed-off-by: Lukasz Pierscieniewski <lukaszp@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Moved the THD logic into the shape calculation and added a missing Optional to the params
Signed-off-by: Lukasz Pierscieniewski <lukaszp@nvidia.com>

* Skip the THD test on GPUs older than Ampere
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

---------
Signed-off-by: Lukasz Pierscieniewski <lukaszp@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Przemek Tredak <ptredak@nvidia.com>
parent ee541e83
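
The change lets both layers consume packed (THD) batches, where all sequences are concatenated along the token axis and delimited by cumulative offsets. A minimal usage sketch, mirroring the new test in the diff below (sizes and dtype are illustrative; per the test, THD requires a non-FP32 dtype and an Ampere-or-newer GPU):

    import torch
    from transformer_engine.pytorch import TransformerLayer

    bs, seq_len, hidden_size, num_heads = 2, 128, 1024, 16

    layer = TransformerLayer(
        hidden_size,
        4 * hidden_size,                      # ffn_hidden_size
        num_heads,
        params_dtype=torch.bfloat16,
        device="cuda",
        attn_input_format="thd",              # packed [total_tokens, hidden] activations
        self_attn_mask_type="padding_causal",
    )

    # Equal-length sequences packed back to back: offsets are multiples of seq_len.
    x_thd = torch.randn(bs * seq_len, hidden_size, device="cuda", dtype=torch.bfloat16)
    cu_seqlens = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * seq_len

    y_thd = layer(
        x_thd,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_kv=seq_len,
    )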
@@ -34,11 +34,13 @@ from transformer_engine.pytorch import (
 from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
 from transformer_engine.pytorch.cpp_extensions import fp8_gemm, fp8_grouped_gemm, gemm, grouped_gemm
 from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
+from transformer_engine.pytorch.utils import get_device_compute_capability
 import transformer_engine_torch as tex

 # Only run FP8 tests on H100.
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+sm_80plus = get_device_compute_capability() >= (8, 0)

 seed = 1234
 torch.manual_seed(seed)
@@ -1548,8 +1550,29 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
         attn_input_format="bshd",
     )
-    for (n1, p1), (n2, p2) in zip(block_bshd.named_parameters(), block_sbhd.named_parameters()):
-        assert torch.all(torch.eq(p1, p2)), f"{n1}, {n2} not identical"
+    torch.manual_seed(0)
+    block_thd = TransformerLayer(
+        config.hidden_size,
+        4 * config.hidden_size,
+        config.num_attention_heads,
+        layernorm_epsilon=config.eps,
+        init_method=init_method,
+        output_layer_init_method=output_layer_init_method,
+        hidden_dropout=0,
+        attention_dropout=0,
+        kv_channels=config.embed,
+        params_dtype=dtype,
+        apply_residual_connection_post_layernorm=False,
+        output_layernorm=False,
+        device="cuda",
+        attn_input_format="thd",
+        self_attn_mask_type="padding_causal",
+    )
+    for (n1, p1), (n2, p2), (n3, p3) in zip(
+        block_bshd.named_parameters(), block_sbhd.named_parameters(), block_thd.named_parameters()
+    ):
+        assert torch.all(torch.eq(p1, p2) & torch.eq(p1, p3)), f"{n1}, {n2} and {n3} not identical"

     x_sbhd = torch.randn(
         (config.seq_len, bs, config.hidden_size),
@@ -1559,6 +1582,8 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
     )

     x_bshd = x_sbhd.transpose(0, 1).contiguous()
+    x_thd = x_bshd.reshape(bs * config.seq_len, config.hidden_size).contiguous()
+    x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.seq_len

     # To make sure forward is also identical (just in case some module decides
     # to act fancy)
@@ -1576,6 +1601,24 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
         y_sbhd.transpose(0, 1).contiguous(),
     )

+    # THD is not supported in float32 and on GPUs older than Ampere, skip the test here
+    if dtype != torch.float32 and sm_80plus:
+        # To make sure forward is also identical (just in case some module decides
+        # to act fancy)
+        torch.manual_seed(0)
+        y_thd = block_thd(
+            x_thd,
+            cu_seqlens_q=x_thd_cumsum,
+            cu_seqlens_kv=x_thd_cumsum,
+            max_seqlen_q=config.seq_len,
+            max_seqlen_kv=config.seq_len,
+        )
+        torch.testing.assert_close(
+            y_bshd,
+            y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
+        )
+

 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
......
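
The test packs equal-length sequences, so the offsets above reduce to torch.arange(bs + 1) * seq_len. For variable-length batches, a sketch of how the same [batch_size + 1] int32 offset tensor would be built (seq_lens here is a hypothetical per-sample length tensor, not part of the commit):

    import torch

    # Hypothetical token counts for a batch of three packed sequences.
    seq_lens = torch.tensor([37, 64, 12], device="cuda", dtype=torch.int32)

    # cu_seqlens is the prefix sum with a leading zero: [0, 37, 101, 113].
    cu_seqlens = torch.zeros(len(seq_lens) + 1, device="cuda", dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)

    max_seqlen = int(seq_lens.max())  # usable as max_seqlen_q / max_seqlen_kv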
@@ -7048,6 +7048,10 @@ class MultiheadAttention(torch.nn.Module):
         core_attention_bias_type: str = "no_bias",
         core_attention_bias: Optional[torch.Tensor] = None,
         alibi_slopes: Optional[torch.Tensor] = None,
+        cu_seqlens_q: Optional[torch.Tensor] = None,
+        cu_seqlens_kv: Optional[torch.Tensor] = None,
+        max_seqlen_q: Optional[int] = None,
+        max_seqlen_kv: Optional[int] = None,
         fast_zero_fill: bool = True,
     ) -> Tuple[Union[torch.Tensor, None], ...]:
         """
@@ -7113,6 +7117,18 @@ class MultiheadAttention(torch.nn.Module):
             ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
             It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
             to the attention score of query i and key j.
+        cu_seqlens_q: Optional[torch.Tensor], default = `None`
+            Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
+            with shape [batch_size + 1] and dtype torch.int32.
+        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
+            Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
+            and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
+        max_seqlen_q: Optional[int], default = `None`
+            Maximum sequence length in `query_layer`.
+            Calculated from `cu_seqlens_q` if not provided.
+        max_seqlen_kv: Optional[int], default = `None`
+            Maximum sequence length in `key_layer` and `value_layer`.
+            Calculated from `cu_seqlens_kv` if not provided.
         fast_zero_fill: bool, default = `True`
             Whether to set output tensors to 0 or not before use.
         """
@@ -7139,6 +7155,9 @@ class MultiheadAttention(torch.nn.Module):
         # =================================================

         if inference_params and self.layer_number is not None:
+            assert (
+                self.qkv_format != "thd"
+            ), "qkv_format == thd is not supported for an inference with KV-cache!"
             if self.layer_number not in inference_params.key_value_memory_dict:
                 inf_max_seq_len = inference_params.max_sequence_length
                 inf_max_batch_size = inference_params.max_batch_size
@@ -7221,13 +7240,18 @@ class MultiheadAttention(torch.nn.Module):
                     dim=split_dim,
                 )

-            # query: -> [sq, b, np, hn]
-            # key, value: -> [sq, b, ng, hn]
-            query_layer, key_layer, value_layer = (
-                x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
-                for x in (query_layer, key_layer, value_layer)
-            )
+            if self.qkv_format == "thd":
+                query_layer, key_layer, value_layer = (
+                    x.reshape(x.size(0), -1, self.hidden_size_per_attention_head)
+                    for x in (query_layer, key_layer, value_layer)
+                )
+            else:
+                # query: -> [sq, b, np, hn]
+                # key, value: -> [sq, b, ng, hn]
+                query_layer, key_layer, value_layer = (
+                    x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
+                    for x in (query_layer, key_layer, value_layer)
+                )

         elif self.attention_type == "cross":
             # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)]
             mixed_kv_layer = self.key_value(
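
In the THD branch the activations carry no batch axis, so heads are split from [total_tokens, hidden] instead of [seq, batch, hidden]. A toy illustration of the two reshapes (sizes are illustrative; head_dim stands in for hidden_size_per_attention_head):

    import torch

    num_heads, head_dim = 16, 64
    hidden = num_heads * head_dim
    total_tokens, seq, bs = 256, 128, 2

    # THD: [t, h] -> [t, np, hn], no batch dimension.
    q_thd = torch.randn(total_tokens, hidden).reshape(total_tokens, -1, head_dim)

    # SBHD: [sq, b, h] -> [sq, b, np, hn], batch dimension kept.
    q_sbhd = torch.randn(seq, bs, hidden).reshape(seq, bs, -1, head_dim)

    print(q_thd.shape)   # torch.Size([256, 16, 64])
    print(q_sbhd.shape)  # torch.Size([128, 2, 16, 64])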
@@ -7341,8 +7365,10 @@ class MultiheadAttention(torch.nn.Module):
             key_layer,
             value_layer,
             qkv_format=self.qkv_format,
-            cu_seqlens_q=None,
-            cu_seqlens_kv=None,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_kv=cu_seqlens_kv,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_kv=max_seqlen_kv,
             attention_mask=attention_mask,
             attn_mask_type=attn_mask_type,
             window_size=window_size,
......
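
MultiheadAttention can also be driven directly with the new arguments. A sketch, assuming the module's qkv_format and attn_mask_type constructor arguments (the forward keywords match the signature added above; all sizes are illustrative):

    import torch
    from transformer_engine.pytorch import MultiheadAttention

    bs, seq_len, hidden_size, num_heads = 2, 128, 1024, 16

    mha = MultiheadAttention(
        hidden_size,
        num_heads,
        qkv_format="thd",                    # assumed constructor argument
        attn_mask_type="padding_causal",
        params_dtype=torch.bfloat16,
        device="cuda",
    )

    x_thd = torch.randn(bs * seq_len, hidden_size, device="cuda", dtype=torch.bfloat16)
    cu_seqlens = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * seq_len

    out = mha(
        x_thd,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_kv=seq_len,
    )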
@@ -13,6 +13,7 @@ from torch.nn import init
 from .base import (
     get_workspace,
+    _ub_communicators,
     get_ub,
     TransformerEngineBaseModule,
     _2X_ACC_FPROP,
@@ -1297,7 +1298,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         self.gemm_gelu_fusion = (
             bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0")))
             and self.activation == "gelu"
-            and not get_ub("fc1_fprop").is_atomic_gemm()
+            and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm()))
         )

         if tp_group is None:
......
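
The widened condition only matters when the fusion is requested: gemm_gelu_fusion stays off unless the NVTE_GEMM_GELU_FUSION environment variable is set, and with the guard above get_ub() is no longer reached when userbuffers communicators were never initialized. Opting in (a sketch; set before the module is constructed):

    import os

    # Request the fused GEMM+GeLU forward path for LayerNormMLP.
    os.environ["NVTE_GEMM_GELU_FUSION"] = "1"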
@@ -529,6 +529,10 @@ class TransformerLayer(torch.nn.Module):
         core_attention_bias_type: str = "no_bias",
         core_attention_bias: Optional[torch.Tensor] = None,
         alibi_slopes: Optional[torch.Tensor] = None,
+        cu_seqlens_q: Optional[torch.Tensor] = None,
+        cu_seqlens_kv: Optional[torch.Tensor] = None,
+        max_seqlen_q: Optional[int] = None,
+        max_seqlen_kv: Optional[int] = None,
         fast_zero_fill: bool = True,
     ) -> torch.Tensor:
         """
@@ -604,6 +608,18 @@ class TransformerLayer(torch.nn.Module):
             ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
             It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
             to the attention score of query i and key j.
+        cu_seqlens_q: Optional[torch.Tensor], default = `None`
+            Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
+            with shape [batch_size + 1] and dtype torch.int32.
+        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
+            Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
+            and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
+        max_seqlen_q: Optional[int], default = `None`
+            Maximum sequence length in `query_layer`.
+            Calculated from `cu_seqlens_q` if not provided.
+        max_seqlen_kv: Optional[int], default = `None`
+            Maximum sequence length in `key_layer` and `value_layer`.
+            Calculated from `cu_seqlens_kv` if not provided.
         fast_zero_fill: bool, default = `True`
             Whether to set output tensors to 0 or not before use.
         inference_params: InferenceParams, default = None
@@ -664,6 +680,10 @@ class TransformerLayer(torch.nn.Module):
             core_attention_bias_type=core_attention_bias_type,
             core_attention_bias=core_attention_bias,
             alibi_slopes=alibi_slopes,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_kv=cu_seqlens_kv,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_kv=max_seqlen_kv,
             fast_zero_fill=fast_zero_fill,
         )
......