Unverified commit 156a074a authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] Rotary Position Embedding (#246)



* Rotary Position Embedding
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* remove einops
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Improve docstr
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 871fdf51
......@@ -41,6 +41,30 @@ _flash_attn_version_required = packaging.version.Version("1.0.2")
__all__ = ["DotProductAttention"]
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Change sign so the last dimension becomes [-odd, +even].
    """
    x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """
    Input tensor `t` is of shape [seq_length, ..., dim];
    the rotary positional embedding tensor `freqs` is of shape [seq_length, ..., dim].
    """
    rot_dim = freqs.shape[-1]
    # Ideally t_pass is empty, so the rotary pos embedding is applied to the entire tensor t.
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # The first part is the cosine component; the second part is the sine
    # component, whose signs are flipped by _rotate_half.
    t = (t * freqs.cos()) + (_rotate_half(t) * freqs.sin())
    return torch.cat((t, t_pass), dim=-1)


class _SplitLastDim(torch.autograd.Function):
    """"""
......@@ -722,6 +746,7 @@ class MultiHeadAttention(torch.nn.Module):
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
        inference_params: Optional[Any] = None,
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """MultiHeadAttention FWD"""
        # hidden_states: [sq, b, h]
......@@ -735,6 +760,7 @@ class MultiHeadAttention(torch.nn.Module):
        # Pre-allocate memory for key-values for inference.
        # =================================================

        is_first_step = False
        if inference_params and self.layer_number is not None:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
......@@ -749,6 +775,7 @@ class MultiHeadAttention(torch.nn.Module):
                    inference_key_memory,
                    inference_value_memory,
                )
                is_first_step = True
            else:
                (
                    inference_key_memory,
......@@ -861,6 +888,11 @@ class MultiHeadAttention(torch.nn.Module):
        # Adjust key and value for inference
        # ==================================

        # duplicate the pos_emb for self attention
        if rotary_pos_emb is not None:
            if not isinstance(rotary_pos_emb, tuple):
                rotary_pos_emb = ((rotary_pos_emb,) * 2)

        if inference_params and self.layer_number is not None:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
......@@ -880,10 +912,36 @@ class MultiHeadAttention(torch.nn.Module):
                :sequence_end, batch_start:batch_end, ...
            ]

            # adjust the key rotary positional embedding
            if rotary_pos_emb is not None:
                q_pos_emb, k_pos_emb = rotary_pos_emb
                # need to cross check this condition during inference
                # if not set_inference_key_value_memory:
                if not is_first_step:
                    # In inference, we compute one token at a time.
                    # Select the correct positional embedding
                    # (only the last token in the sequence).
                    q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
                else:
                    # In the first forward pass of inference,
                    # we use the entire provided prefix.
                    # q_pos_emb here has the rope embeddings of the entire
                    # prefix + to-be-generated output, so
                    # we slice to just the prefix.
                    q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
                k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
                rotary_pos_emb = (q_pos_emb, k_pos_emb)

        # ==================================
        # core attention computation
        # ==================================

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb)

        context_layer = self.core_attention(
            query_layer,
            key_layer,
......
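
The per-step slicing above is what keeps incremental decoding consistent: on the first pass the whole prompt is processed at once, and afterwards only the newest token is. A rough stand-alone illustration of that indexing, with hypothetical sizes (none of these values come from the diff):

import torch

max_seq_len, rot_dim = 32, 16
q_pos_emb = torch.randn(max_seq_len, 1, 1, rot_dim)  # precomputed rotary table
k_pos_emb = q_pos_emb.clone()

# First step (is_first_step == True): a prompt of, say, 5 tokens is processed
# in one call, so the query uses embeddings for every prompt position.
sequence_end = 5
q_first = q_pos_emb[:sequence_end]                    # [5, 1, 1, rot_dim]

# Subsequent steps: one token per call, so the query only needs the embedding
# of the newest position, while the key covers everything in the KV cache.
sequence_end = 6
q_step = q_pos_emb[sequence_end - 1 : sequence_end]   # [1, 1, 1, rot_dim]
k_step = k_pos_emb[:sequence_end]                     # [6, 1, 1, rot_dim]
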
......@@ -6,7 +6,7 @@
import os
import warnings
from contextlib import nullcontext
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Tuple, Union
import torch
......@@ -63,8 +63,6 @@ class DropPath(torch.nn.Module):
        return output


class TransformerLayer(torch.nn.Module):
    r"""
    TransformerLayer is made up of an attention block and a feedforward network (MLP).
......@@ -394,6 +392,7 @@ class TransformerLayer(torch.nn.Module):
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
        inference_params: Optional[Any] = None,
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
    ) -> torch.Tensor:
        """
        Transformer Layer: attention block and a feedforward network (MLP)
......@@ -433,6 +432,9 @@ class TransformerLayer(torch.nn.Module):
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
        rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
                       Embeddings for the query and key tensors for applying rotary position
                       embedding. By default no rotary position embedding is applied.
        """
        hidden_states = hidden_states.contiguous()
......@@ -460,6 +462,7 @@ class TransformerLayer(torch.nn.Module):
            inference_params=inference_params,
            is_first_microbatch=is_first_microbatch,
            checkpoint_core_attention=checkpoint_core_attention,
            rotary_pos_emb=rotary_pos_emb,
        )

        if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
......
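
Taken together, the two diffs let a caller thread rotary embeddings from the layer API down to the attention core. A hedged end-to-end sketch follows; the TransformerLayer constructor arguments, the module path, and the frequency construction are assumptions about the surrounding library and the standard RoPE recipe, not details established by this commit:

import torch
import transformer_engine.pytorch as te

hidden_size, ffn_hidden_size, num_heads = 1024, 4096, 16
seq_len, batch = 128, 2
head_dim = hidden_size // num_heads

# Assumed constructor signature; check the installed version for the exact API.
layer = te.TransformerLayer(hidden_size, ffn_hidden_size, num_heads).cuda()

x = torch.randn(seq_len, batch, hidden_size, device="cuda")

# Rotary table of shape [seq_len, 1, 1, head_dim]; passing a single tensor is
# enough, since MultiHeadAttention duplicates it for query and key.
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device="cuda").float() / head_dim))
freqs = torch.einsum("s,d->sd", torch.arange(seq_len, device="cuda").float(), inv_freq)
rotary_pos_emb = torch.cat((freqs, freqs), dim=-1)[:, None, None, :]

out = layer(x, attention_mask=None, rotary_pos_emb=rotary_pos_emb)  # [seq_len, batch, hidden_size]
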