Unverified Commit 156a074a authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] Rotary Position Embedding (#246)



* Rotary Position Embedding
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* remove einops
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Improve docstr
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 871fdf51
@@ -41,6 +41,30 @@ _flash_attn_version_required = packaging.version.Version("1.0.2")
__all__ = ["DotProductAttention"]

def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    change sign so the last dimension becomes [-odd, +even]
    """
    x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """
    input tensor t is of shape [seq_length, ..., dim]
    rotary positional embedding tensor `freqs` is of shape [seq_length, ..., dim]
    """
    rot_dim = freqs.shape[-1]
    # ideally t_pass is empty so the rotary pos embedding is applied to all of tensor t
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # first part is cosine component
    # second part is sine component, need to change signs with the _rotate_half method
    t = (t * freqs.cos()) + (_rotate_half(t) * freqs.sin())
    return torch.cat((t, t_pass), dim=-1)

class _SplitLastDim(torch.autograd.Function):
    """"""
@@ -722,6 +746,7 @@ class MultiHeadAttention(torch.nn.Module):
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
        inference_params: Optional[Any] = None,
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """MultiHeadAttention FWD"""
        # hidden_states: [sq, b, h]
@@ -735,6 +760,7 @@ class MultiHeadAttention(torch.nn.Module):
        # Pre-allocate memory for key-values for inference.
        # =================================================

        is_first_step = False
        if inference_params and self.layer_number is not None:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
@@ -749,6 +775,7 @@ class MultiHeadAttention(torch.nn.Module):
                    inference_key_memory,
                    inference_value_memory,
                )
                is_first_step = True
            else:
                (
                    inference_key_memory,
@@ -861,6 +888,11 @@ class MultiHeadAttention(torch.nn.Module):
        # Adjust key and value for inference
        # ==================================

        # duplicate the pos_emb for self attention
        if rotary_pos_emb is not None:
            if not isinstance(rotary_pos_emb, tuple):
                rotary_pos_emb = ((rotary_pos_emb,) * 2)

        if inference_params and self.layer_number is not None:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
@@ -880,10 +912,36 @@ class MultiHeadAttention(torch.nn.Module):
                :sequence_end, batch_start:batch_end, ...
            ]

            # adjust the key rotary positional embedding
            if rotary_pos_emb is not None:
                q_pos_emb, k_pos_emb = rotary_pos_emb
                # need to cross check this condition during inference
                # if not set_inference_key_value_memory:
                if not is_first_step:
                    # In inference, we compute one token at a time.
                    # Select the correct positional embedding
                    # (only the last token in the sequence)
                    q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
                else:
                    # In the first forward pass of inference,
                    # we use the entire provided prefix.
                    # q_pos_emb here has the rope embeddings of the entire
                    # prefix + to-be-generated output so
                    # we slice to just the prefix.
                    q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
                k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
                rotary_pos_emb = (q_pos_emb, k_pos_emb)

        # ==================================
        # core attention computation
        # ==================================

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb)

        context_layer = self.core_attention(
            query_layer,
            key_layer,
...
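
To make the inference-time branch above concrete, here is a small standalone sketch of the slicing it performs: on the prefill step the query embeddings cover the whole provided prefix, on subsequent steps only the newest position, while the key embeddings always cover everything cached so far. The sizes and the `build_rotary_freqs` helper (sketched earlier) are illustrative assumptions, not part of this commit.

```python
# Illustrative sketch of the prefill vs. incremental-decode slicing above.
q_pos_emb = build_rotary_freqs(seq_length=32, rot_dim=64)        # [32, 1, 1, 64]
k_pos_emb = q_pos_emb.clone()

sequence_end = 5          # tokens processed so far, including the current one
is_first_step = False     # True only on the prefill pass

if not is_first_step:
    # Incremental decoding: one new token per step, so only its embedding is needed.
    q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]       # [1, 1, 1, 64]
else:
    # Prefill: slice the full prefix out of the precomputed table.
    q_pos_emb = q_pos_emb[:sequence_end]                         # [5, 1, 1, 64]
k_pos_emb = k_pos_emb[:sequence_end]                             # keys cover the whole cache
```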
@@ -6,7 +6,7 @@
import os
import warnings
from contextlib import nullcontext
from typing import Any, Callable, Optional, Tuple, Union

import torch
@@ -63,8 +63,6 @@ class DropPath(torch.nn.Module):
        return output


class TransformerLayer(torch.nn.Module):
    r"""
    TransformerLayer is made up of an attention block and a feedforward network (MLP).
@@ -394,6 +392,7 @@ class TransformerLayer(torch.nn.Module):
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
        inference_params: Optional[Any] = None,
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
    ) -> torch.Tensor:
        """
        Transformer Layer: attention block and a feedforward network (MLP)
@@ -433,6 +432,9 @@ class TransformerLayer(torch.nn.Module):
                                   during the backward pass in order to save memory that would
                                   otherwise be occupied to store the forward activations until
                                   backprop.
        rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no input embedding is applied.
        """
        hidden_states = hidden_states.contiguous()
@@ -460,6 +462,7 @@ class TransformerLayer(torch.nn.Module):
            inference_params=inference_params,
            is_first_microbatch=is_first_microbatch,
            checkpoint_core_attention=checkpoint_core_attention,
            rotary_pos_emb=rotary_pos_emb,
        )

        if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
...
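
Putting the two files together, a caller might thread rotary embeddings through `TransformerLayer` roughly as follows. This is a hedged sketch: it assumes a CUDA device with Transformer Engine installed, reuses the illustrative `build_rotary_freqs` helper from above, and uses arbitrary sizes; the only interface taken from this commit is the new `rotary_pos_emb` argument.

```python
import torch
import transformer_engine.pytorch as te

seq_len, batch, hidden, heads = 128, 2, 1024, 16
layer = te.TransformerLayer(
    hidden_size=hidden,
    ffn_hidden_size=4 * hidden,
    num_attention_heads=heads,
).cuda()

x = torch.randn(seq_len, batch, hidden, device="cuda")

# A single tensor is duplicated for queries and keys inside MultiHeadAttention;
# a (q_pos_emb, k_pos_emb) tuple may be passed instead.
freqs = build_rotary_freqs(seq_len, hidden // heads).cuda()
y = layer(x, attention_mask=None, rotary_pos_emb=freqs)
```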