Unverified Commit 37339478 authored by Kshitij Lakhani, committed by GitHub

Refactoring attention.py part 1 (#1542)



* Create pytorch/dot_product_attention module and pytorch/d_p_a/utils.py
Move attention logging into a separate class in pytorch/d_p_a/utils.py
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Create FlashAttentionUtils class in pytorch/d_p_a/utils.py for versioning info
Move versioning info out of pytorch/attention.py
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
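For reference, a minimal sketch of the new versioning access pattern, using only attribute names that appear in the updated tests in this PR (illustrative, not an exhaustive API):

# Illustrative only: version flags formerly exposed as module-level booleans
# (_flash_attn_2_plus, _flash_attn_2_3_plus, ...) now live on FlashAttentionUtils.
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils

if FlashAttentionUtils.is_installed and FlashAttentionUtils.v2_3_plus:
    print("flash-attn 2.3+ is available")
if FlashAttentionUtils.v3_is_installed:
    print("flash-attn 3 is available")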

* Move AttentionParams and get_attention_backend from attention.py to d_p_a/utils.py
Fix tests and imports for the above refactor change
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
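A hedged sketch of the relocated backend-selection API; the import path and the five-element return value come from the updated tests, while the AttentionParams construction below is abbreviated (the real class carries many more model/config fields than shown):

# Sketch only: AttentionParams takes many more fields than shown here.
from transformer_engine.pytorch.dot_product_attention.utils import (
    AttentionParams,
    get_attention_backend,
)

attention_params = AttentionParams(
    fp8=False,
    fp8_meta=None,
    # ... plus the shape/dtype/mask fields omitted in this sketch
)
(
    use_flash_attention,
    use_fused_attention,
    fused_attention_backend,
    use_unfused_attention,
    available_backends,
) = get_attention_backend(attention_params)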

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Move get_qkv_layout(), get_full_mask(), get_alibi(), get_attention_quantizers() to d_p_a/utils.py
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
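After this move, callers import these helpers from the new utils module, for example (import path as used in the updated attention notebook; calls are not shown since the signatures are unchanged by this PR):

from transformer_engine.pytorch.dot_product_attention.utils import (
    get_qkv_layout,  # determines which layout a set of q, k, v tensors have
    get_full_mask,
    get_alibi,
    get_attention_quantizers,
)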

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Move tensor packing and unpacking helper functions from pyt/attention.py to d_p_a/utils.py
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Move cumulative seqlens and indices methods from pyt/attention.py to d_p_a/utils.py
Rename cumulative functions from _cu_ to _cumul_ to differentiate them from the CUDA cu naming convention
Rename tensor packing methods with a leading underscore to mark them as internal to the file
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove unnecessary imports in pytorch/attention.py and d_p_a/utils.py
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Create d_p_a/inference.py and move InferenceParams from pyt/attention.py to it
Modify tests and other files to import InferenceParams correctly
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

Update the docs API entry for InferenceParams
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
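Downstream code switches to the new import path; a minimal, self-contained example:

# InferenceParams now lives in dot_product_attention/inference.py.
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams

inference_params = InferenceParams(max_batch_size=4, max_sequence_length=2048)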

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Create d_p_a/rope.py and move RoPE methods from pytorch/attention.py to it
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
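Typical usage against the new module path, as a minimal sketch (the tensor shapes below are placeholders, not values from this PR):

import torch
from transformer_engine.pytorch.dot_product_attention.rope import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
)

rope = RotaryPositionEmbedding(dim=64)
freqs = rope(max_seq_len=128)                   # [s, 1, 1, d] frequency tensor
q = torch.randn(128, 2, 16, 64, device="cuda")  # [s, b, h, d]
q = apply_rotary_pos_emb(q, freqs, tensor_format="sbhd")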

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Code cleanup
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix a bug surfaced by QA testing
Code cleanup
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix incorrect pack_tensor arg type
Code clean up
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* nit: Resolve lint errors
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove the FAUtils alias for FlashAttentionUtils
Use attn_log instead of att_log
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

Fix lint error
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* nit: Revert the function name from get_cumul back to the earlier get_cu
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* nit: Fix typos, make imports explicit, and remove extra comments
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

---------
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent c257bf31
......@@ -31,7 +31,7 @@ pyTorch
.. autoapiclass:: transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
:members: forward, set_context_parallel_group, set_tensor_parallel_group
.. autoapiclass:: transformer_engine.pytorch.InferenceParams(max_batch_size, max_sequence_length)
.. autoapiclass:: transformer_engine.pytorch.dot_product_attention.inference.InferenceParams(max_batch_size, max_sequence_length)
.. autoapiclass:: transformer_engine.pytorch.CudaRNGStatesTracker()
:members: reset, get_states, set_states, add, fork
......
......@@ -458,7 +458,7 @@
" </tr>\n",
"</table>\n",
"\n",
"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
"\n",
"<div class=\"alert alert-info\">\n",
"<b>Note</b>\n",
......
......@@ -11,7 +11,7 @@ import torch
from torch import nn
import transformer_engine as te
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
from transformer_engine.pytorch.dot_product_attention.rope import RotaryPositionEmbedding
from transformer_engine.pytorch.fp8 import fp8_model_init
import transformers
......
......@@ -18,15 +18,15 @@ from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model
from transformer_engine.pytorch.attention import (
DotProductAttention,
MultiheadAttention,
RotaryPositionEmbedding,
_attention_backends,
)
from transformer_engine.pytorch.dot_product_attention.utils import (
FlashAttentionUtils,
get_attention_backend,
_flash_attn_is_installed,
_flash_attn_2_3_plus,
_flash_attn_3_is_installed,
check_set_window_size,
AttentionParams,
_attention_backends,
)
from transformer_engine.pytorch.dot_product_attention.rope import RotaryPositionEmbedding
from transformer_engine.pytorch.constants import TE_DType
import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
......@@ -191,9 +191,20 @@ def _get_attention_backends(
fp8=fp8,
fp8_meta=fp8_meta,
)
_, _, fused_attention_backend, _, available_backends = get_attention_backend(
attention_params
)
(
use_flash_attention,
use_fused_attention,
fused_attention_backend,
use_unfused_attention,
available_backends,
) = get_attention_backend(attention_params)
# Set attention.py _attention_backends var using return value
# from get_attention_backend()
_attention_backends["use_flash_attention"] = use_flash_attention
_attention_backends["use_fused_attention"] = use_fused_attention
_attention_backends["fused_attention_backend"] = fused_attention_backend
_attention_backends["use_unfused_attention"] = use_unfused_attention
_attention_backends["backend_selection_requires_update"] = False
return available_backends, fused_attention_backend
backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
......@@ -269,12 +280,12 @@ def test_dot_product_attention(
# manually pads and unpads the input and output of FlashAttention for testing purposes
if (
pad_between_seqs
and _flash_attn_is_installed
and FlashAttentionUtils.is_installed
and not (
config.max_seqlen_q != config.max_seqlen_kv
and config.attn_mask_type in ["causal", "padding_causal"]
)
and (config.window_size[0] == -1 or _flash_attn_2_3_plus)
and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
):
flash_attn_supported = True
......@@ -581,7 +592,7 @@ model_configs_swa = {
}
@pytest.mark.skipif(not _flash_attn_2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
@pytest.mark.parametrize("model", model_configs_swa.keys())
......@@ -603,7 +614,7 @@ model_configs_alibi_slopes = {
}
@pytest.mark.skipif(not _flash_attn_2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
@pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
......@@ -1445,7 +1456,11 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
):
pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7")
if _flash_attn_3_is_installed and not is_training and "padding" not in config.attn_mask_type:
if (
FlashAttentionUtils.v3_is_installed
and not is_training
and "padding" not in config.attn_mask_type
):
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
......@@ -1471,7 +1486,11 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
rtol = 5e-1
rmse_tol = 0.15
logging.debug("========== {:^25s} ==========".format("forward output"))
if _flash_attn_3_is_installed and not is_training and "padding" not in config.attn_mask_type:
if (
FlashAttentionUtils.v3_is_installed
and not is_training
and "padding" not in config.attn_mask_type
):
_error(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
......@@ -1656,7 +1675,11 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
if _flash_attn_3_is_installed and not is_training and "padding" not in config.attn_mask_type:
if (
FlashAttentionUtils.v3_is_installed
and not is_training
and "padding" not in config.attn_mask_type
):
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
......@@ -1685,7 +1708,11 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
rmse_tol = 0.11
bwd_names = ["dq", "dk", "dv"]
logging.debug("========== {:^25s} ==========".format("forward output"))
if _flash_attn_3_is_installed and not is_training and "padding" not in config.attn_mask_type:
if (
FlashAttentionUtils.v3_is_installed
and not is_training
and "padding" not in config.attn_mask_type
):
_error(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
......
......@@ -7,15 +7,11 @@ import subprocess
import pytest
import torch
from transformer_engine.pytorch.attention import (
_flash_attn_2_plus,
_flash_attn_2_3_plus,
)
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
get_cudnn_version,
)
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils
from test_fused_attn import ModelConfig
model_configs_flash_attn = {
......@@ -54,7 +50,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
return args
@pytest.mark.skipif(not _flash_attn_2_plus, reason="Flash-attn 2.0+ is required.")
@pytest.mark.skipif(not FlashAttentionUtils.v2_plus, reason="Flash-attn 2.0+ is required.")
@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
......
......@@ -5,7 +5,7 @@ import math
import pytest
import torch
from typing import Callable, Tuple, Union
from transformer_engine.pytorch.attention import (
from transformer_engine.pytorch.dot_product_attention.rope import (
RotaryPositionEmbedding,
apply_rotary_pos_emb,
)
......
......@@ -34,10 +34,10 @@ from transformer_engine.pytorch import (
RMSNorm,
TransformerLayer,
LayerNorm,
InferenceParams,
Fp8Padding,
Fp8Unpadding,
)
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
......
......@@ -89,8 +89,8 @@ from transformer_engine.pytorch.module import Fp8Padding, Fp8Unpadding
from transformer_engine.pytorch.module import initialize_ub
from transformer_engine.pytorch.module import destroy_ub
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention import MultiheadAttention
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.transformer import TransformerLayer
from transformer_engine.pytorch.permutation import (
moe_permute,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Python interface for dot product attention"""
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Inference classes for attention
"""
class InferenceParams: # pylint: disable=too-few-public-methods
"""
Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference.
Parameters
----------
max_batch_size : int
maximum batch size during inference.
max_sequence_length : int
maximum sequence length during inference.
"""
def __init__(self, max_batch_size, max_sequence_length):
self.max_sequence_length = max_sequence_length
self.max_batch_size = max_batch_size
self.sequence_len_offset = 0
self.batch_size_offset = 0
self.key_value_memory_dict = {}
def swap_key_value_dict(self, batch_indices):
"""
Reorders the KV cache using the specified batch indices.
Parameters
----------
batch_indices : List[int]
Sequence of indices to reorder along the batch dimensions of
the KV cache. Must have a length equal to the batch size.
"""
if len(self.key_value_memory_dict) == 0:
raise ValueError("should not swap when dict in empty")
for layer_number, inference_memory in self.key_value_memory_dict.items():
inference_key_memory, inference_value_memory = inference_memory
assert (
len(batch_indices) == inference_key_memory.shape[1]
) # make sure batch size is the same
new_inference_key_memory = inference_key_memory[:, batch_indices]
new_inference_value_memory = inference_value_memory[:, batch_indices]
self.key_value_memory_dict[layer_number] = (
new_inference_key_memory,
new_inference_value_memory,
)
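A brief usage sketch of the relocated class (illustrative; the layer number and cache shapes below are placeholders, not taken from this diff):

import torch
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams

params = InferenceParams(max_batch_size=2, max_sequence_length=16)
# Layers register (key, value) caches keyed by layer number; the batch dimension
# sits at index 1, matching the assert in swap_key_value_dict.
params.key_value_memory_dict[1] = (
    torch.zeros(16, 2, 4, 8),  # hypothetical [s, b, h, d] key cache
    torch.zeros(16, 2, 4, 8),  # hypothetical [s, b, h, d] value cache
)
# Reorder the KV cache along the batch dimension, e.g. after beam reordering.
params.swap_key_value_dict([1, 0])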
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Rotary Position Embedding implementations of different types, along with helper functions
"""
from typing import Optional, Tuple, Union
import torch
import transformer_engine_torch as tex
class RotaryPositionEmbedding(torch.nn.Module):
"""
Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
"""
def __init__(
self,
dim: int,
rotary_percent: float = 1.0,
seq_len_interpolation_factor: Optional[int] = None,
pretrained_max_position_embeddings: Optional[int] = None,
rotary_base: float = 10000.0,
):
"""
Parameters
----------
dim: int
rotary embedding dimension
rotary_percent: float
Percent of rotary dimension to use for rotary position embeddings.
seq_len_interpolation_factor: int
if not None, discrete positions will be interpolated by this factor via the trick in
https://arxiv.org/abs/2306.15595
pretrained_max_position_embeddings: int
pre-trained max_position_embeddings before position interpolation
rotary_base: float, default = 10000.0
base of the geometric progression used to compute the rotary frequencies
"""
super().__init__()
if rotary_percent < 1.0:
dim = int(dim * rotary_percent)
self.seq_len_interpolation_factor = seq_len_interpolation_factor
self.rotary_base = rotary_base
inv_freq = 1.0 / (
self.rotary_base
** (
torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device())
/ dim
)
)
self.register_buffer("inv_freq", inv_freq)
self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
def forward(self, max_seq_len: int, offset: int = 0):
"""
Create rotary position embedding frequencies
Parameters
----------
max_seq_len: int
sequence length of a sample
offset: int, default = 0
fixed offset for frequencies
"""
seq = (
torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ offset
)
if (
self.pretrained_max_position_embeddings is not None
and self.seq_len_interpolation_factor is not None
):
if (
max_seq_len
> self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor
):
# dynamic linear scaling (length > position we have learned)
seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
else:
# fixed linear scaling
seq *= 1 / self.seq_len_interpolation_factor
freqs = torch.einsum("i , j -> i j", seq, self.inv_freq)
# first part even vector components, second part odd vector components,
# 2 * dim in dimension size
emb = torch.cat((freqs, freqs), dim=-1)
# emb [seq_length, .., dim]
return emb.reshape(emb.size(0), 1, 1, emb.size(1))
class FusedRoPEFunc(torch.autograd.Function):
"""
Function for FusedRoPE
This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.
"""
@staticmethod
def forward(
ctx,
t: torch.Tensor,
freqs: torch.Tensor,
tensor_format: str = "sbhd",
cu_seqlens: Union[torch.Tensor, None] = None,
cp_size: int = 1,
cp_rank: int = 0,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
if freqs.dtype != torch.float32:
freqs = freqs.float()
if tensor_format == "sbhd":
output = tex.fused_rope_forward(t, freqs, False)
elif tensor_format == "bshd":
output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1)
elif tensor_format == "thd":
output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, cp_size, cp_rank)
else:
raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
ctx.save_for_backward(freqs, cu_seqlens)
ctx.tensor_format = tensor_format
ctx.cp_size = cp_size
ctx.cp_rank = cp_rank
return output
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
freqs, cu_seqlens = ctx.saved_tensors
if ctx.tensor_format == "sbhd":
grad_input = tex.fused_rope_backward(grad_output, freqs, False)
elif ctx.tensor_format == "bshd":
grad_input = tex.fused_rope_backward(
grad_output.transpose(0, 1), freqs, True
).transpose(0, 1)
elif ctx.tensor_format == "thd":
grad_input = tex.fused_rope_thd_backward(
grad_output, cu_seqlens, freqs, ctx.cp_size, ctx.cp_rank
)
else:
raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")
return grad_input, None, None, None, None, None
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
"""
change sign so the last dimension becomes [-odd, +even]
"""
x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
tensor_format: str = "sbhd",
fused: bool = False,
cu_seqlens: Union[torch.Tensor, None] = None,
cp_size: int = 1,
cp_rank: int = 0,
) -> torch.Tensor:
"""
Apply rotary positional embedding tensor to the input tensor.
Parameters
----------
t: torch.Tensor
Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
rotary positional embedding will be applied.
freqs: torch.Tensor
Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
with `s2 >= s` and `d2 <= d`.
fused: bool, default = False
Whether to use the fused RoPE implementation.
tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
cu_seqlens: torch.Tensor, default = None.
Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
dtype torch.int32. Only valid when `tensor_format` is 'thd'.
Should be `cu_seqlens_padded` when cp_size > 1.
cp_size: int, default = 1.
Context parallel world size. Only valid when `tensor_format` is 'thd' and `fused` is True.
cp_rank: int, default = 0.
Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True.
"""
if fused:
assert (
tensor_format != "thd" or cu_seqlens is not None
), "cu_seqlens must not be None when tensor_format is 'thd'."
return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens, cp_size, cp_rank)
assert tensor_format in ("sbhd", "bshd"), (
"Only formats `sbhd` or `bshd` are supported for input tensor `t` "
f"when fused is False, got {tensor_format}."
)
max_seq_len = freqs.shape[0]
cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]
# Only apply the rotary embeddings up to the sequence length of the running
# input.
assert (
cur_seq_len <= max_seq_len
), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
freqs = freqs[:cur_seq_len]
if tensor_format == "bshd":
freqs = freqs.transpose(0, 1) # [seq, 1, 1, dim] -> [1, seq, 1, dim]
# cos/sin first then dtype conversion for better precision
cos_ = torch.cos(freqs).to(t.dtype)
sin_ = torch.sin(freqs).to(t.dtype)
rot_dim = freqs.shape[-1]
# ideally t_pass is empty so rotary pos embedding is applied to all tensor t
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
# first part is cosine component
# second part is sine component, need to change signs with _rotate_half method
t = (t * cos_) + (_rotate_half(t) * sin_)
return torch.cat((t, t_pass), dim=-1)
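To complement the docstring above, a hedged sketch of the fused 'thd' path (sequence lengths, head counts, and dtypes are placeholders chosen for illustration):

import torch
from transformer_engine.pytorch.dot_product_attention.rope import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
)

rope = RotaryPositionEmbedding(dim=64)
freqs = rope(max_seq_len=32)  # float32, shape [32, 1, 1, 64]
# Two packed sequences of lengths 10 and 22, i.e. t = 32 total tokens.
cu_seqlens = torch.tensor([0, 10, 32], dtype=torch.int32, device="cuda")
qkv = torch.randn(32, 16, 64, device="cuda", dtype=torch.float16)  # [t, h, d]
out = apply_rotary_pos_emb(qkv, freqs, tensor_format="thd", fused=True, cu_seqlens=cu_seqlens)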
......@@ -12,10 +12,10 @@ import torch
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
from transformer_engine.pytorch.attention import (
InferenceParams,
MultiheadAttention,
check_set_window_size,
)
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.dot_product_attention.utils import check_set_window_size
from transformer_engine.pytorch.jit import (
set_jit_fusion_options,
warmup_jit_bias_dropout_add_all_dtypes,
......