Commit c63906a6 authored by Mohammad Shoeybi

Merge branch 'vijay/weight_mapping' into 'main'

Vijay/weight mapping

See merge request ADLR/megatron-lm!138
parents 577b4657 64e45f29
megatron/checkpointing.py

@@ -27,6 +27,17 @@ from megatron import mpu, get_args
 from megatron import get_args
 from megatron import print_rank_0
 
+_CHECKPOINT_VERSION = None
+
+def set_checkpoint_version(value):
+    global _CHECKPOINT_VERSION
+    assert _CHECKPOINT_VERSION is None, \
+        "checkpoint version already set"
+    _CHECKPOINT_VERSION = value
+
+def get_checkpoint_version():
+    global _CHECKPOINT_VERSION
+    return _CHECKPOINT_VERSION
 
 def check_checkpoint_args(checkpoint_args):
     """Ensure fixed arguments for a model are the same for the input
@@ -90,6 +101,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     # Arguments, iteration, and model.
     state_dict = {}
     state_dict['args'] = args
+    state_dict['checkpoint_version'] = 1.0
     state_dict['iteration'] = iteration
     state_dict['model'] = model.state_dict_for_save_checkpoint()
@@ -184,6 +196,9 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
         print_rank_0('could not load the checkpoint')
         sys.exit()
 
+    # set checkpoint version
+    set_checkpoint_version(state_dict.get('checkpoint_version', 0))
+
     # Set iteration.
     if args.finetune or release:
         iteration = 0
@@ -198,6 +213,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
                          'iteration from checkpoint {}, exiting'.format(
                              checkpoint_name))
            sys.exit()
+
     # Check arguments.
     if 'args' in state_dict:
......
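The three checkpointing changes above work together: save_checkpoint now stamps every new checkpoint with checkpoint_version = 1.0, load_checkpoint records the stored version via set_checkpoint_version (falling back to 0 for checkpoints written before this merge request), and model code can query it later through get_checkpoint_version. Below is a minimal, self-contained sketch of that round trip; the dicts merely stand in for real checkpoint state_dicts.

```python
_CHECKPOINT_VERSION = None

def set_checkpoint_version(value):
    # A version may only be recorded once per process.
    global _CHECKPOINT_VERSION
    assert _CHECKPOINT_VERSION is None, "checkpoint version already set"
    _CHECKPOINT_VERSION = value

def get_checkpoint_version():
    return _CHECKPOINT_VERSION

# New checkpoints carry an explicit version key ...
new_style_state_dict = {'checkpoint_version': 1.0}
# ... while checkpoints written before this change have none, so they
# default to version 0 on load and can be detected downstream.
old_style_state_dict = {}

set_checkpoint_version(old_style_state_dict.get('checkpoint_version', 0))
assert get_checkpoint_version() == 0
```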
megatron/model/transformer.py

@@ -23,6 +23,7 @@ from megatron import get_args
 from megatron import mpu
 from megatron.mpu import LayerNorm
 from megatron.module import MegatronModule
+from megatron.checkpointing import get_checkpoint_version
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import openai_gelu, erf_gelu
@@ -170,7 +171,23 @@ class ParallelSelfAttention(MegatronModule):
             input_is_parallel=True,
             init_method=output_layer_init_method,
             skip_bias_add=True)
 
+    def _transpose_last_dim(self, mixed_layer):
+        """[s, b, 3 * hp] -->(view) [s, b, 3, hp] -->(transpose)
+        [s, b, hp, 3] -->(view) [s, b, 3 * hp]"""
+        input_shape = mixed_layer.size()
+        last_dim = input_shape[-1]
+        assert last_dim % 3 == 0, "expected QKV dimension"
+        last_dim_split = last_dim // 3
+        intermediate_shape = input_shape[:-1] + \
+            (3, last_dim_split)
+        mixed_layer = mixed_layer.view(*intermediate_shape)
+        mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
+        mixed_layer = mixed_layer.view(*input_shape)
+        return mixed_layer
+
     def forward(self, hidden_states, attention_mask, layer_past=None,
                 get_key_value=False):
@@ -180,20 +197,25 @@ class ParallelSelfAttention(MegatronModule):
         # Query, Key, and Value
         # =====================
 
-        # Attention heads [s, b, hp] --> [s, b, 3 * hp]
+        # Attention heads [s, b, hp] --> [s, b, hp * 3]
         mixed_x_layer, _ = self.query_key_value(hidden_states)
 
-        # [s, b, 3 * hp] --> [s, b, np, 3 * hn]
+        checkpoint_version = get_checkpoint_version()
+        if checkpoint_version is not None and \
+           checkpoint_version == 0:
+            # [s, b, 3 * hp] --> [s, b, hp * 3]
+            mixed_x_layer = self._transpose_last_dim(mixed_x_layer)
+
+        # [s, b, hp * 3] --> [s, b, np, hn, 3]
         new_tensor_shape = mixed_x_layer.size()[:-1] + \
             (self.num_attention_heads_per_partition,
-             3 * self.hidden_size_per_attention_head)
+             self.hidden_size_per_attention_head, 3)
         mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
 
-        # [s, b, np, 3 * hn] --> 3 [s, b, np, hn]
-        (query_layer,
-         key_layer,
-         value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
+        # [s, b, np, hn, 3] --> 3 [s, b, np, hn]
+        query_layer = mixed_x_layer[:,:,:,:,0]
+        key_layer = mixed_x_layer[:,:,:,:,1]
+        value_layer = mixed_x_layer[:,:,:,:,2]
 
         # ==================================
         # Adjust key and value for inference
......
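The attention change hinges on what _transpose_last_dim actually does to the fused QKV dimension. The sketch below is a standalone copy of that reordering applied to a tiny tensor (the shapes s=1, b=1, np=2, hn=2 are illustrative, not from the commit): three contiguous blocks in the last dimension become element-wise interleaved triples, which is exactly the layout the new [s, b, np, hn, 3] view and [..., 0/1/2] indexing expect.

```python
import torch

# Illustrative shapes, not from the commit: s = 1, b = 1, np = 2 heads, hn = 2.
num_heads, head_dim = 2, 2
hp = num_heads * head_dim                       # 4
x = torch.arange(3 * hp).view(1, 1, 3 * hp)     # last dim: [0..3 | 4..7 | 8..11]

def transpose_last_dim(mixed_layer):
    """Standalone copy of the reordering in _transpose_last_dim:
    [s, b, 3*hp] -> view [s, b, 3, hp] -> transpose [s, b, hp, 3] -> view [s, b, 3*hp]."""
    input_shape = mixed_layer.size()
    last_dim = input_shape[-1]
    assert last_dim % 3 == 0, "expected QKV dimension"
    split = last_dim // 3
    out = mixed_layer.view(*input_shape[:-1], 3, split)
    out = out.transpose(-1, -2).contiguous()
    return out.view(*input_shape)

y = transpose_last_dim(x)
print(y[0, 0])     # tensor([ 0,  4,  8,  1,  5,  9,  2,  6, 10,  3,  7, 11])

# The new forward path then views [s, b, hp * 3] as [s, b, np, hn, 3] and slices
# the last dimension, so each original contiguous block comes back out whole:
z = y.view(1, 1, num_heads, head_dim, 3)
q, k, v = z[..., 0], z[..., 1], z[..., 2]
print(q[0, 0])     # [[0, 1], [2, 3]]   -- the first block, split across the two heads
print(k[0, 0])     # [[4, 5], [6, 7]]
print(v[0, 0])     # [[8, 9], [10, 11]]
```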