Commit 7bd25e26 authored by Jared Casper

Merge branch 'next-best-lm/merge-rope-main' into 'main'

merging rope to main

See merge request ADLR/megatron-lm!556
parents f6d36d03 74bd02ec
@@ -521,6 +521,14 @@ def _add_network_size_args(parser):
    group.add_argument('--max-position-embeddings', type=int, default=None,
                       help='Maximum number of position embeddings to use. '
                       'This is the size of position embedding.')
    group.add_argument('--use-rotary-position-embeddings', action='store_true',
                       help='Use rotary positional embeddings or not')
    group.add_argument('--rotary-percent', type=float, default=1.0,
                       help='Percent of rotary dimension to use, default 100%')
    group.add_argument('--no-position-embedding',
                       action='store_false',
                       help='Disable position embedding.',
                       dest='add_position_embedding')
    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                       help='Pad the vocab size to be divisible by this value.'
                       'This is added for computational efficieny reasons.')
......
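The three new flags compose as follows: --use-rotary-position-embeddings switches the model to RoPE, --rotary-percent controls how much of each head dimension is rotated, and --no-position-embedding stores False into args.add_position_embedding so the learned absolute embeddings can be dropped. Below is a minimal, standalone argparse sketch of just that behavior; it is illustrative only and not part of the merge request, and the sample command line is an assumption.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use-rotary-position-embeddings', action='store_true',
                    help='Use rotary positional embeddings or not')
parser.add_argument('--rotary-percent', type=float, default=1.0,
                    help='Percent of rotary dimension to use, default 100%%')
parser.add_argument('--no-position-embedding', action='store_false',
                    dest='add_position_embedding',
                    help='Disable position embedding.')

# Assumed example invocation: RoPE on, learned absolute embeddings off,
# half of each head dimension rotated.
args = parser.parse_args(['--use-rotary-position-embeddings',
                          '--rotary-percent', '0.5',
                          '--no-position-embedding'])
assert args.use_rotary_position_embeddings is True
assert args.add_position_embedding is False
assert args.rotary_percent == 0.5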
@@ -11,6 +11,7 @@ from megatron.core import mpu, tensor_parallel
from .enums import LayerType, AttnMaskType
from .module import MegatronModule
from .retro_transformer import ParallelRetroEncoder, ParallelRetroTransformer
from .rotary_pos_embedding import apply_rotary_pos_emb, RotaryEmbedding
from .transformer import ParallelTransformer
from .utils import get_linear_layer
from .utils import init_method_normal, scaled_init_method_normal
@@ -158,6 +159,8 @@ class Embedding(MegatronModule):
        self._word_embeddings_key = 'word_embeddings'

        # Position embedding (serial).
        self.add_position_embedding = args.add_position_embedding
        if self.add_position_embedding:
            self.position_embeddings = torch.nn.Embedding(
                max_sequence_length, self.hidden_size)
            self._position_embeddings_key = 'position_embeddings'
@@ -188,6 +191,7 @@
        """Zero out all parameters in embedding."""
        self.word_embeddings.weight.data.fill_(0)
        self.word_embeddings.weight.shared = True
        if self.add_position_embedding:
            self.position_embeddings.weight.data.fill_(0)
            self.position_embeddings.weight.shared = True
        if self.num_tokentypes > 0:
@@ -214,8 +218,12 @@
    def forward(self, input_ids, position_ids, tokentype_ids=None):
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        if self.add_position_embedding:
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = words_embeddings + position_embeddings
        else:
            embeddings = words_embeddings
        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
@@ -246,6 +254,7 @@
        state_dict_[self._word_embeddings_key] \
            = self.word_embeddings.state_dict(prefix=prefix,
                                              keep_vars=keep_vars)
        if self.add_position_embedding:
            state_dict_[self._position_embeddings_key] \
                = self.position_embeddings.state_dict(prefix=prefix,
                                                      keep_vars=keep_vars)
@@ -272,6 +281,7 @@
        self.word_embeddings.load_state_dict(state_dict_, strict=strict)

        # Position embedding.
        if self.add_position_embedding:
            if self._position_embeddings_key in state_dict:
                state_dict_ = state_dict[self._position_embeddings_key]
            else:
@@ -351,6 +361,22 @@ class TransformerLanguageModel(MegatronModule):
                                       self.num_tokentypes)
            self._embedding_key = 'embedding'

        # Rotary positional embeddings
        self.use_rotary_position_embeddings = \
            args.use_rotary_position_embeddings
        if args.use_rotary_position_embeddings:
            self.seq_length = args.seq_length
            rotary_dim = args.hidden_size // args.num_attention_heads \
                if args.kv_channels is None else args.kv_channels

            if args.rotary_percent < 1.0:
                rotary_dim = int(rotary_dim * args.rotary_percent)

            # partial rotary embeddings, which is better than full rotary
            # Wang and Komatsuzaki et al
            # https://github.com/kingoflolz/mesh-transformer-jax/
            self.rotary_pos_emb = RotaryEmbedding(rotary_dim)

        # Retriever (bi-directional transformer with cross attention)
        if args.retro_add_retriever:
            self.retriever = ParallelRetroEncoder(
@@ -458,6 +484,15 @@ class TransformerLanguageModel(MegatronModule):
        else:
            encoder_input = None

        # Rotary positional embeddings
        rotary_pos_emb = None
        if self.use_rotary_position_embeddings:
            if inference_params is not None:
                rotary_pos_emb = \
                    self.rotary_pos_emb(inference_params.max_sequence_len)
            else:
                rotary_pos_emb = self.rotary_pos_emb(self.seq_length)

        # Run encoder.
        if enc_hidden_states is None:
            if self.encoder is not None:
@@ -472,7 +507,8 @@
                encoder_output = self.encoder(
                    encoder_input,
                    enc_attn_mask,
                    inference_params=inference_params,
                    rotary_pos_emb=rotary_pos_emb)
            else:
                encoder_output = self.encoder_hidden_state
        else:
@@ -505,7 +541,8 @@
            dec_attn_mask,
            encoder_output=encoder_output,
            enc_dec_attn_mask=enc_dec_attn_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb)

        if self.add_pooler and self.post_process:
            return decoder_output, encoder_output, pooled_output
......
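In the TransformerLanguageModel hunk above, the rotary dimension is derived from the per-head size and then optionally truncated by --rotary-percent. A hedged, standalone sketch of that arithmetic follows; the sizes (hidden_size=4096, num_attention_heads=32, rotary_percent=0.5) are assumed example values, and only the formula mirrors the diff.

# Assumed example sizes; only the formula mirrors the diff above.
hidden_size, num_attention_heads, kv_channels = 4096, 32, None
rotary_percent = 0.5

# per-head channels unless kv_channels overrides them
rotary_dim = hidden_size // num_attention_heads \
    if kv_channels is None else kv_channels
if rotary_percent < 1.0:
    # partial rotary: only this leading slice of each head is rotated
    rotary_dim = int(rotary_dim * rotary_percent)

print(rotary_dim)  # 128 channels per head * 0.5 -> 64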
# coding=utf-8

# The following code has been taken from https://github.com/NVIDIA/NeMo/blob/ \
# 782b4e1652aaa43c8be390d9db0dc89544afa080/nemo/collections/nlp/modules/ \
# common/megatron/rotary_pos_embedding.py

import importlib.util
import torch

from torch import einsum, nn

__all__ = ['RotaryEmbedding', 'apply_rotary_pos_emb']


class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        if importlib.util.find_spec('einops') is None:
            raise RuntimeError("einops is required for Rotary Embedding")

    def forward(self, max_seq_len, offset=0):
        seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
        freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq)
        # first part even vector components, second part odd vector components,
        # 2 * dim in dimension size
        emb = torch.cat((freqs, freqs), dim=-1)
        # emb [seq_length, .., dim]
        from einops import rearrange
        return rearrange(emb, 'n d -> n 1 1 d')


def _rotate_half(x):
    """
    change sign so the last dimension becomes [-odd, +even]
    """
    from einops import rearrange
    x = rearrange(x, '... (j d) -> ... j d', j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(t, freqs):
    """
    input tensor t is of shape [seq_length, ..., dim]
    rotary positional embeding tensor freqs is of shape [seq_length, ..., dim]
    check https://kexue.fm/archives/8265 for detailed formulas
    """
    rot_dim = freqs.shape[-1]
    # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # first part is cosine component
    # second part is sine component, need to change signs with _rotate_half method
    t = (t * freqs.cos()) + (_rotate_half(t) * freqs.sin())
    return torch.cat((t, t_pass), dim=-1)
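A hedged usage sketch for the new module above. It assumes torch and einops are installed and that the file is importable as megatron.model.rotary_pos_embedding (the path used by the imports in this merge request); the tensor sizes are illustrative. The [seq, batch, heads, head_dim] layout matches the query/key tensors that ParallelAttention passes to apply_rotary_pos_emb, and when the rotary dimension is smaller than the head dimension only the leading channels are rotated while the trailing channels pass through unchanged.

import torch
from megatron.model.rotary_pos_embedding import RotaryEmbedding, apply_rotary_pos_emb

seq_len, batch, heads, head_dim = 16, 2, 4, 64   # assumed sizes
q = torch.randn(seq_len, batch, heads, head_dim)
k = torch.randn(seq_len, batch, heads, head_dim)

# Full rotary: every head channel is rotated.
full_rope = RotaryEmbedding(head_dim)
freqs = full_rope(seq_len)                       # [seq_len, 1, 1, head_dim]
q_rot = apply_rotary_pos_emb(q, freqs)
assert q_rot.shape == q.shape

# Partial rotary (e.g. --rotary-percent 0.5): only the first 32 channels are
# rotated; the trailing channels are concatenated back untouched (t_pass).
partial_rope = RotaryEmbedding(head_dim // 2)
freqs_half = partial_rope(seq_len)               # [seq_len, 1, 1, 32]
k_rot = apply_rotary_pos_emb(k, freqs_half)
assert torch.equal(k_rot[..., head_dim // 2:], k[..., head_dim // 2:])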
@@ -14,6 +14,7 @@ from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.rotary_pos_embedding import apply_rotary_pos_emb
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

try:
@@ -467,7 +468,8 @@
            **_args_to_kwargs())

    def _checkpointed_attention_forward(self, query_layer, key_layer,
                                        value_layer, attention_mask,
                                        rotary_pos_emb=None):
        """Forward method with activation checkpointing."""
        def custom_forward(*inputs):
            query_layer = inputs[0]
@@ -478,9 +480,13 @@
                                          value_layer, attention_mask)
            return output_

        q_pos_emb, k_pos_emb = (None, None) if rotary_pos_emb is None \
            else rotary_pos_emb

        hidden_states = tensor_parallel.checkpoint(
            custom_forward,
            False, query_layer, key_layer, value_layer, attention_mask,
            q_pos_emb, k_pos_emb)

        return hidden_states
@@ -494,13 +500,14 @@
            device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None,
                rotary_pos_emb=None):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        is_first_step = False
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
@@ -511,6 +518,7 @@
                    inf_max_seq_len, inf_max_batch_size)
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
                is_first_step = True
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]
@@ -559,6 +567,13 @@
        # Adjust key and value for inference
        # ==================================

        # duplicate the pos_emb for self attention
        if rotary_pos_emb is not None:
            if isinstance(rotary_pos_emb, tuple):
                rotary_pos_emb = rotary_pos_emb
            else:
                rotary_pos_emb = ((rotary_pos_emb,) * 2)

        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
@@ -576,10 +591,42 @@
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...]

            # adjust the key rotary positional embedding
            if rotary_pos_emb is not None:
                q_pos_emb, k_pos_emb = rotary_pos_emb
                # need to cross check this condition during inference
                # if not set_inference_key_value_memory:
                if not is_first_step:
                    # In inference, we compute one token at a time.
                    # Select the correct positional embedding
                    # (only the last token in the sequence)
                    q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
                else:
                    # In the first forward pass of inference,
                    # we use the entire provided prefix.
                    # q_pos_emb here has the rope embeddings of the entire
                    # prefix + to-be-generated output so
                    # we slice to just the prefix.
                    q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
                k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
                rotary_pos_emb = (q_pos_emb, k_pos_emb)
        # ==================================
        # core attention computation
        # ==================================

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb)
            # TODO, can apply positional embedding to value_layer so it has
            # absolute positional embedding.
            # otherwise, only relative positional embedding takes effect
            # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)

        if not self.use_flash_attn:
            if self.checkpoint_core_attention:
                context_layer = self._checkpointed_attention_forward(
@@ -713,7 +760,7 @@
    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None, rotary_pos_emb=None):
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
@@ -723,7 +770,8 @@
            self.self_attention(
                layernorm_output,
                attention_mask,
                inference_params=inference_params,
                rotary_pos_emb=rotary_pos_emb)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
@@ -1066,7 +1114,8 @@
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask,
                              rotary_pos_emb, is_first_microbatch):
        """Forward method with activation checkpointing."""
        def custom(start, end, is_transformer_engine=False):
            def custom_forward(*args, **kwargs):
@@ -1094,12 +1143,14 @@
                        self.distribute_saved_activations,
                        tensor_parallel.get_cuda_rng_tracker,
                        mpu.get_tensor_model_parallel_group(),
                        hidden_states, attention_mask, encoder_output,
                        enc_dec_attn_mask, rotary_pos_emb)
                else:
                    hidden_states = tensor_parallel.checkpoint(
                        custom(l, l + self.recompute_num_layers),
                        self.distribute_saved_activations,
                        hidden_states, attention_mask, encoder_output,
                        enc_dec_attn_mask, rotary_pos_emb)

                l += self.recompute_num_layers
@@ -1115,19 +1166,23 @@
                            self.distribute_saved_activations,
                            tensor_parallel.get_cuda_rng_tracker,
                            mpu.get_tensor_model_parallel_group(),
                            hidden_states, attention_mask, encoder_output,
                            enc_dec_attn_mask, rotary_pos_emb)
                    else:
                        hidden_states = tensor_parallel.checkpoint(
                            custom(l, l + 1),
                            self.distribute_saved_activations,
                            hidden_states, attention_mask, encoder_output,
                            enc_dec_attn_mask, rotary_pos_emb)
                else:
                    if self.transformer_impl == 'transformer_engine':
                        hidden_states = custom(l, l + 1, is_transformer_engine=True)(
                            hidden_states, attention_mask, encoder_output,
                            enc_dec_attn_mask, rotary_pos_emb)
                    else:
                        hidden_states = custom(l, l + 1)(
                            hidden_states, attention_mask, encoder_output,
                            enc_dec_attn_mask, rotary_pos_emb)
        else:
            raise ValueError("Invalid activation recompute method.")
@@ -1145,7 +1200,7 @@
    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None, rotary_pos_emb=None):
        # hidden_states: [s, b, h]

        # Checks.
@@ -1203,12 +1258,14 @@
                    attention_mask,
                    encoder_output,
                    enc_dec_attn_mask,
                    rotary_pos_emb,
                    is_first_microbatch)
            else:
                forward_kwargs = {
                    'encoder_output': encoder_output,
                    'enc_dec_attn_mask': enc_dec_attn_mask,
                    'inference_params': inference_params,
                    'rotary_pos_emb': rotary_pos_emb,
                }

                if self.transformer_impl == 'transformer_engine':
......
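A hedged, self-contained sketch of how the attention changes above treat the rotary frequencies: a single tensor coming from the language model is duplicated into a (query, key) pair for self-attention, and during incremental decoding the query keeps only its own position while the keys cover the whole prefix. The helper name and the decode position are assumptions for illustration; only the slicing mirrors the diff.

import torch

def split_rope_for_self_attention(rotary_pos_emb):
    # Hypothetical helper mirroring the duplication branch in
    # ParallelAttention.forward: one tensor -> (q_pos_emb, k_pos_emb).
    if isinstance(rotary_pos_emb, tuple):
        return rotary_pos_emb
    return (rotary_pos_emb,) * 2

max_seq_len, rot_dim = 16, 64                 # assumed sizes
freqs = torch.randn(max_seq_len, 1, 1, rot_dim)
q_pos_emb, k_pos_emb = split_rope_for_self_attention(freqs)

sequence_end = 10                             # assumed current decode position
# Past the first step, only the newest token needs a query embedding ...
q_step = q_pos_emb[sequence_end - 1:sequence_end]
# ... while the key embeddings always cover the whole prefix so far.
k_prefix = k_pos_emb[:sequence_end, :, :, :]
assert q_step.shape[0] == 1 and k_prefix.shape[0] == sequence_end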