Commit 7bd25e26 authored by Jared Casper

Merge branch 'next-best-lm/merge-rope-main' into 'main'

merging rope to main

See merge request ADLR/megatron-lm!556
parents f6d36d03 74bd02ec
......@@ -521,6 +521,14 @@ def _add_network_size_args(parser):
group.add_argument('--max-position-embeddings', type=int, default=None,
help='Maximum number of position embeddings to use. '
'This is the size of position embedding.')
group.add_argument('--use-rotary-position-embeddings', action='store_true',
help='Use rotary positional embeddings or not')
group.add_argument('--rotary-percent', type=float, default=1.0,
help='Percent of rotary dimension to use, default 100%%')
group.add_argument('--no-position-embedding',
action='store_false',
help='Disable position embedding.',
dest='add_position_embedding')
group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
help='Pad the vocab size to be divisible by this value.'
'This is added for computational efficiency reasons.')
......
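For orientation, a minimal sketch of how the three new flags above combine once parsed. This is an illustration only, not the real Megatron parser; the point is that --no-position-embedding stores False into add_position_embedding, so rotary embeddings and the learned absolute position table can be toggled independently.

# Illustrative only: a stripped-down reproduction of the three new flags,
# not megatron/arguments.py itself.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use-rotary-position-embeddings', action='store_true')
parser.add_argument('--rotary-percent', type=float, default=1.0)
parser.add_argument('--no-position-embedding', action='store_false',
                    dest='add_position_embedding')

# Typical RoPE configuration: rotary embeddings on, learned absolute
# position embeddings off, half of each head dimension rotated.
args = parser.parse_args(['--use-rotary-position-embeddings',
                          '--no-position-embedding',
                          '--rotary-percent', '0.5'])

assert args.use_rotary_position_embeddings is True
assert args.add_position_embedding is False   # set to False by --no-position-embedding
assert args.rotary_percent == 0.5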
......@@ -11,6 +11,7 @@ from megatron.core import mpu, tensor_parallel
from .enums import LayerType, AttnMaskType
from .module import MegatronModule
from .retro_transformer import ParallelRetroEncoder, ParallelRetroTransformer
from .rotary_pos_embedding import apply_rotary_pos_emb, RotaryEmbedding
from .transformer import ParallelTransformer
from .utils import get_linear_layer
from .utils import init_method_normal, scaled_init_method_normal
......@@ -158,12 +159,14 @@ class Embedding(MegatronModule):
self._word_embeddings_key = 'word_embeddings'
# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(
max_sequence_length, self.hidden_size)
self._position_embeddings_key = 'position_embeddings'
# Initialize the position embeddings.
if args.perform_initialization:
self.init_method(self.position_embeddings.weight)
self.add_position_embedding = args.add_position_embedding
if self.add_position_embedding:
self.position_embeddings = torch.nn.Embedding(
max_sequence_length, self.hidden_size)
self._position_embeddings_key = 'position_embeddings'
# Initialize the position embeddings.
if args.perform_initialization:
self.init_method(self.position_embeddings.weight)
# Token type embedding.
# Add this as an optional field that can be added through
......@@ -188,8 +191,9 @@ class Embedding(MegatronModule):
"""Zero out all parameters in embedding."""
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
self.position_embeddings.weight.data.fill_(0)
self.position_embeddings.weight.shared = True
if self.add_position_embedding:
self.position_embeddings.weight.data.fill_(0)
self.position_embeddings.weight.shared = True
if self.num_tokentypes > 0:
self.tokentype_embeddings.weight.data.fill_(0)
self.tokentype_embeddings.weight.shared = True
......@@ -214,8 +218,12 @@ class Embedding(MegatronModule):
def forward(self, input_ids, position_ids, tokentype_ids=None):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
if self.add_position_embedding:
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
else:
embeddings = words_embeddings
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
......@@ -246,8 +254,9 @@ class Embedding(MegatronModule):
state_dict_[self._word_embeddings_key] \
= self.word_embeddings.state_dict(prefix=prefix,
keep_vars=keep_vars)
state_dict_[self._position_embeddings_key] \
= self.position_embeddings.state_dict(prefix=prefix,
if self.add_position_embedding:
state_dict_[self._position_embeddings_key] \
= self.position_embeddings.state_dict(prefix=prefix,
keep_vars=keep_vars)
if self.num_tokentypes > 0:
state_dict_[self._tokentype_embeddings_key] \
......@@ -272,16 +281,17 @@ class Embedding(MegatronModule):
self.word_embeddings.load_state_dict(state_dict_, strict=strict)
# Position embedding.
if self._position_embeddings_key in state_dict:
state_dict_ = state_dict[self._position_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'position_embeddings' in key:
state_dict_[key.split('position_embeddings.')[1]] \
= state_dict[key]
self.position_embeddings.load_state_dict(state_dict_, strict=strict)
if self.add_position_embedding:
if self._position_embeddings_key in state_dict:
state_dict_ = state_dict[self._position_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'position_embeddings' in key:
state_dict_[key.split('position_embeddings.')[1]] \
= state_dict[key]
self.position_embeddings.load_state_dict(state_dict_, strict=strict)
# Tokentype embedding.
if self.num_tokentypes > 0:
......@@ -351,6 +361,22 @@ class TransformerLanguageModel(MegatronModule):
self.num_tokentypes)
self._embedding_key = 'embedding'
# Rotary positional embeddings
self.use_rotary_position_embeddings = \
args.use_rotary_position_embeddings
if args.use_rotary_position_embeddings:
self.seq_length = args.seq_length
rotary_dim = args.hidden_size // args.num_attention_heads \
if args.kv_channels is None else args.kv_channels
if args.rotary_percent < 1.0:
rotary_dim = int(rotary_dim * args.rotary_percent)
# partial rotary embeddings, which is better than full rotary
# Wang and Komatsuzaki et al
# https://github.com/kingoflolz/mesh-transformer-jax/
self.rotary_pos_emb = RotaryEmbedding(rotary_dim)
# Retriever (bi-directional transformer with cross attention)
if args.retro_add_retriever:
self.retriever = ParallelRetroEncoder(
......@@ -458,6 +484,15 @@ class TransformerLanguageModel(MegatronModule):
else:
encoder_input = None
# Rotary positional embeddings
rotary_pos_emb = None
if self.use_rotary_position_embeddings:
if inference_params is not None:
rotary_pos_emb = \
self.rotary_pos_emb(inference_params.max_sequence_len)
else:
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
# Run encoder.
if enc_hidden_states is None:
if self.encoder is not None:
......@@ -472,7 +507,8 @@ class TransformerLanguageModel(MegatronModule):
encoder_output = self.encoder(
encoder_input,
enc_attn_mask,
inference_params=inference_params)
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb)
else:
encoder_output = self.encoder_hidden_state
else:
......@@ -505,7 +541,8 @@ class TransformerLanguageModel(MegatronModule):
dec_attn_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params)
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb)
if self.add_pooler and self.post_process:
return decoder_output, encoder_output, pooled_output
......
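To make the new sizing logic concrete, here is a small sketch of how the language model derives rotary_dim and builds the frequency table, using the RotaryEmbedding module introduced in the new file below. The hyperparameter values are illustrative assumptions, not values taken from this merge request.

# Sketch of how TransformerLanguageModel sizes and builds the rotary table
# in the hunks above; the numbers are illustrative, not Megatron defaults.
from megatron.model.rotary_pos_embedding import RotaryEmbedding

hidden_size = 4096
num_attention_heads = 32
kv_channels = None       # falls back to hidden_size // num_attention_heads
rotary_percent = 0.5     # from --rotary-percent 0.5
seq_length = 2048

rotary_dim = hidden_size // num_attention_heads \
    if kv_channels is None else kv_channels        # 128
if rotary_percent < 1.0:
    # Partial rotary: only a fraction of each head dimension is rotated.
    rotary_dim = int(rotary_dim * rotary_percent)  # 128 * 0.5 -> 64

rotary_pos_emb = RotaryEmbedding(rotary_dim)

# Training covers the full training sequence length; inference covers the
# maximum sequence length the key/value cache was allocated for.
freqs = rotary_pos_emb(seq_length)   # shape: [seq_length, 1, 1, rotary_dim]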
# coding=utf-8
# The following code has been taken from https://github.com/NVIDIA/NeMo/blob/ \
# 782b4e1652aaa43c8be390d9db0dc89544afa080/nemo/collections/nlp/modules/ \
# common/megatron/rotary_pos_embedding.py
import importlib.util
import torch
from torch import einsum, nn
__all__ = ['RotaryEmbedding', 'apply_rotary_pos_emb']
class RotaryEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
if importlib.util.find_spec('einops') is None:
raise RuntimeError("einops is required for Rotary Embedding")
def forward(self, max_seq_len, offset=0):
seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq)
# first part even vector components, second part odd vector components,
# 2 * dim in dimension size
emb = torch.cat((freqs, freqs), dim=-1)
# emb [seq_length, .., dim]
from einops import rearrange
return rearrange(emb, 'n d -> n 1 1 d')
def _rotate_half(x):
"""
change sign so the last dimension becomes [-odd, +even]
"""
from einops import rearrange
x = rearrange(x, '... (j d) -> ... j d', j=2)
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(t, freqs):
"""
input tensor t is of shape [seq_length, ..., dim]
rotary positional embedding tensor freqs is of shape [seq_length, ..., dim]
check https://kexue.fm/archives/8265 for detailed formulas
"""
rot_dim = freqs.shape[-1]
# ideally t_pass is empty so rotary pos embedding is applied to all tensor t
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
# first part is cosine component
# second part is sine component, need to change signs with _rotate_half method
t = (t * freqs.cos()) + (_rotate_half(t) * freqs.sin())
return torch.cat((t, t_pass), dim=-1)
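As a quick shape check, a hedged usage sketch of the two objects above. The query/key layout follows Megatron's [sq, b, np, hn] convention; the sizes themselves are made up, and einops must be installed for RotaryEmbedding to construct.

# Shape check for RotaryEmbedding / apply_rotary_pos_emb; all sizes here are
# illustrative and do not come from the merge request.
import torch
from megatron.model.rotary_pos_embedding import RotaryEmbedding, apply_rotary_pos_emb

sq, b, np_, hn = 16, 2, 4, 128   # sequence length, batch, heads, head dim
rotary_dim = 64                  # e.g. 50% partial rotary of a 128-dim head

rope = RotaryEmbedding(rotary_dim)
freqs = rope(sq)                 # [sq, 1, 1, rotary_dim], broadcasts over batch and heads

query = torch.randn(sq, b, np_, hn)
key = torch.randn(sq, b, np_, hn)

# Only the first rotary_dim channels of each head are rotated; the remaining
# hn - rotary_dim channels pass through untouched (t_pass in apply_rotary_pos_emb).
query = apply_rotary_pos_emb(query, freqs)
key = apply_rotary_pos_emb(key, freqs)

assert query.shape == (sq, b, np_, hn)
assert key.shape == (sq, b, np_, hn)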
......@@ -14,6 +14,7 @@ from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.rotary_pos_embedding import apply_rotary_pos_emb
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
try:
......@@ -467,7 +468,8 @@ class ParallelAttention(MegatronModule):
**_args_to_kwargs())
def _checkpointed_attention_forward(self, query_layer, key_layer,
value_layer, attention_mask):
value_layer, attention_mask,
rotary_pos_emb=None):
"""Forward method with activation checkpointing."""
def custom_forward(*inputs):
query_layer = inputs[0]
......@@ -478,9 +480,13 @@ class ParallelAttention(MegatronModule):
value_layer, attention_mask)
return output_
q_pos_emb, k_pos_emb = (None, None) if rotary_pos_emb is None \
else rotary_pos_emb
hidden_states = tensor_parallel.checkpoint(
custom_forward,
False, query_layer, key_layer, value_layer, attention_mask)
False, query_layer, key_layer, value_layer, attention_mask,
q_pos_emb, k_pos_emb)
return hidden_states
......@@ -494,13 +500,14 @@ class ParallelAttention(MegatronModule):
device=torch.cuda.current_device())
def forward(self, hidden_states, attention_mask,
encoder_output=None, inference_params=None):
encoder_output=None, inference_params=None,
rotary_pos_emb=None):
# hidden_states: [sq, b, h]
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
is_first_step = False
if inference_params:
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_len = inference_params.max_sequence_len
......@@ -511,6 +518,7 @@ class ParallelAttention(MegatronModule):
inf_max_seq_len, inf_max_batch_size)
inference_params.key_value_memory_dict[self.layer_number] = (
inference_key_memory, inference_value_memory)
is_first_step = True
else:
inference_key_memory, inference_value_memory = \
inference_params.key_value_memory_dict[self.layer_number]
......@@ -559,6 +567,13 @@ class ParallelAttention(MegatronModule):
# Adjust key and value for inference
# ==================================
# duplicate the pos_emb for self attention
if rotary_pos_emb is not None:
if isinstance(rotary_pos_emb, tuple):
rotary_pos_emb = rotary_pos_emb
else:
rotary_pos_emb = ((rotary_pos_emb,) * 2)
if inference_params:
batch_start = inference_params.batch_size_offset
batch_end = batch_start + key_layer.size(1)
......@@ -576,10 +591,42 @@ class ParallelAttention(MegatronModule):
value_layer = inference_value_memory[
:sequence_end, batch_start:batch_end, ...]
# adjust the key rotary positional embedding
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
# need to cross check this condition during inference
# if not set_inference_key_value_memory:
if not is_first_step:
# In inference, we compute one token at a time.
# Select the correct positional embedding
# (only the last token in the sequence)
q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
else:
# In the first forward pass of inference,
# we use the entire provided prefix.
# q_pos_emb here has the rope embeddings of the entire
# prefix + to-be-generated output so
# we slice to just the prefix.
q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
rotary_pos_emb = (q_pos_emb, k_pos_emb)
# ==================================
# core attention computation
# ==================================
# apply relative positional encoding (rotary embedding)
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb)
key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb)
# TODO, can apply positional embedding to value_layer so it has
# absolute positional embedding.
# otherwise, only relative positional embedding takes effect
# value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)
if not self.use_flash_attn:
if self.checkpoint_core_attention:
context_layer = self._checkpointed_attention_forward(
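The inference branch above keeps rotary embeddings for every cached key position but, after the first (prefill) step, hands the single query token only the embedding of the newest position. A minimal standalone illustration of that slicing, with sizes chosen arbitrarily:

# Illustration of the query/key rotary slicing used during incremental
# decoding above; sizes are arbitrary and no Megatron state is used.
import torch

max_seq_len, rotary_dim = 32, 64
pos_emb = torch.randn(max_seq_len, 1, 1, rotary_dim)   # stand-in for rope(max_seq_len)

sequence_end = 10        # number of tokens currently held in the KV cache
is_first_step = False    # True only on the prefill pass over the prompt

if not is_first_step:
    # Decoding one token at a time: the query holds a single position and
    # only needs the embedding of the last cached position.
    q_pos_emb = pos_emb[sequence_end - 1:sequence_end]   # [1, 1, 1, rotary_dim]
else:
    # Prefill: the query spans the whole prompt, so take its full prefix.
    q_pos_emb = pos_emb[:sequence_end]                   # [sequence_end, 1, 1, rotary_dim]

# Keys always need embeddings for every position in the cache.
k_pos_emb = pos_emb[:sequence_end]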
......@@ -713,7 +760,7 @@ class ParallelTransformerLayer(MegatronModule):
def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
inference_params=None, rotary_pos_emb=None):
# hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
......@@ -723,7 +770,8 @@ class ParallelTransformerLayer(MegatronModule):
self.self_attention(
layernorm_output,
attention_mask,
inference_params=inference_params)
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
......@@ -1066,7 +1114,8 @@ class ParallelTransformer(MegatronModule):
return self.layers[layer_number]
def _checkpointed_forward(self, hidden_states, attention_mask,
encoder_output, enc_dec_attn_mask, is_first_microbatch):
encoder_output, enc_dec_attn_mask,
rotary_pos_emb, is_first_microbatch):
"""Forward method with activation checkpointing."""
def custom(start, end, is_transformer_engine=False):
def custom_forward(*args, **kwargs):
......@@ -1094,12 +1143,14 @@ class ParallelTransformer(MegatronModule):
self.distribute_saved_activations,
tensor_parallel.get_cuda_rng_tracker,
mpu.get_tensor_model_parallel_group(),
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
hidden_states, attention_mask, encoder_output,
enc_dec_attn_mask, rotary_pos_emb)
else:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + self.recompute_num_layers),
self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
hidden_states, attention_mask, encoder_output,
enc_dec_attn_mask, rotary_pos_emb)
l += self.recompute_num_layers
......@@ -1115,19 +1166,23 @@ class ParallelTransformer(MegatronModule):
self.distribute_saved_activations,
tensor_parallel.get_cuda_rng_tracker,
mpu.get_tensor_model_parallel_group(),
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
hidden_states, attention_mask, encoder_output,
enc_dec_attn_mask, rotary_pos_emb)
else:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + 1),
self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
hidden_states, attention_mask, encoder_output,
enc_dec_attn_mask, rotary_pos_emb)
else:
if self.transformer_impl == 'transformer_engine':
hidden_states = custom(l, l + 1, is_transformer_engine=True)(
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
hidden_states, attention_mask, encoder_output,
enc_dec_attn_mask, rotary_pos_emb)
else:
hidden_states = custom(l, l + 1)(
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
hidden_states, attention_mask, encoder_output,
enc_dec_attn_mask, rotary_pos_emb)
else:
raise ValueError("Invalid activation recompute method.")
......@@ -1145,7 +1200,7 @@ class ParallelTransformer(MegatronModule):
def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
inference_params=None, rotary_pos_emb=None):
# hidden_states: [s, b, h]
# Checks.
......@@ -1203,12 +1258,14 @@ class ParallelTransformer(MegatronModule):
attention_mask,
encoder_output,
enc_dec_attn_mask,
rotary_pos_emb,
is_first_microbatch)
else:
forward_kwargs = {
'encoder_output': encoder_output,
'enc_dec_attn_mask': enc_dec_attn_mask,
'inference_params': inference_params,
'rotary_pos_emb': rotary_pos_emb,
}
if self.transformer_impl == 'transformer_engine':
......