Commit 76960d7c authored by Jared Casper

Move rearranging query_key_value and key_value values in old checkpoints to when the checkpoint is loaded instead of at runtime.
parent c7444380
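For reference, below is a small standalone sketch (not part of this commit; the two heads per partition, head size 4, and hidden size 8 are made-up example values) of the first-dimension reshuffle that the new _transpose_first_dim helper applies to a version-0 query_key_value weight at load time:

# Toy illustration of the version-0 fix-up now done in load_checkpoint.
# num_splits=3 (query/key/value); np_, hn, h are made-up example sizes.
import torch

num_splits, np_, hn, h = 3, 2, 4, 8

# A version-0 checkpoint stores the fused weight as [num_splits * np * hn, h]
# (all queries, then all keys, then all values).
w_v0 = torch.arange(num_splits * np_ * hn * h, dtype=torch.float32)
w_v0 = w_v0.view(num_splits * np_ * hn, h)

# Same reshuffle as _transpose_first_dim(t, 3, True, model):
# view to [num_splits, np, hn, h], swap the first two dims, flatten back.
w_fixed = (w_v0.view(num_splits, np_, hn, h)
           .transpose(0, 1)
           .contiguous()
           .view(num_splits * np_ * hn, h))

# Rows are now grouped per head ([np * num_splits * hn, h]), the layout the
# current runtime code expects.
assert torch.equal(w_fixed[0:hn], w_v0[0:hn])                         # head 0 query
assert torch.equal(w_fixed[hn:2 * hn], w_v0[np_ * hn:np_ * hn + hn])  # head 0 key

Version 1.0 checkpoints store the splits last instead of first, which is why the helper's else branch transposes dimensions 1 and 2 instead of 0 and 1.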
@@ -23,9 +23,10 @@ import numpy as np
 import torch
 from torch.nn.parallel import DistributedDataParallel as torchDDP
-from megatron import mpu, get_args, update_num_microbatches
-from megatron import get_args
-from megatron import print_rank_0
+from megatron import (get_args,
+                      mpu,
+                      print_rank_0,
+                      update_num_microbatches)

 _CHECKPOINT_VERSION = None

@@ -163,6 +164,43 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     if torch.distributed.is_initialized():
         torch.distributed.barrier()
+def _transpose_first_dim(t, num_splits, num_splits_first, model):
+    input_shape = t.size()
+    # We use a self_attention module but the values extracted aren't
+    # specific to self attention so should work for cross attention as well
+    while hasattr(model, 'module'):
+        model = model.module
+    attention_module = model.language_model.encoder.layers[0].self_attention
+    hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head
+    num_attention_heads_per_partition = attention_module.num_attention_heads_per_partition
+    if num_splits_first:
+        """[num_splits * np * hn, h]
+        -->(view) [num_splits, np, hn, h]
+        -->(transpose) [np, num_splits, hn, h]
+        -->(view) [np * num_splits * hn, h] """
+        intermediate_shape = \
+            (num_splits, num_attention_heads_per_partition,
+             hidden_size_per_attention_head) + input_shape[1:]
+        t = t.view(*intermediate_shape)
+        t = t.transpose(0, 1).contiguous()
+    else:
+        """[np * hn * num_splits, h]
+        -->(view) [np, hn, num_splits, h]
+        -->(transpose) [np, num_splits, hn, h]
+        -->(view) [np * num_splits * hn, h] """
+        intermediate_shape = \
+            (num_attention_heads_per_partition,
+             hidden_size_per_attention_head, num_splits) +\
+            input_shape[1:]
+        t = t.view(*intermediate_shape)
+        t = t.transpose(1, 2).contiguous()
+    t = t.view(*input_shape)
+    return t

 def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
     """Load a model checkpoint and return the iteration.

@@ -261,6 +299,29 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     # Model.
     model.load_state_dict(state_dict['model'], strict=strict)
+    # Fix up query/key/value matrix ordering
+    if get_checkpoint_version() < 2.0:
+        checkpoint_version = get_checkpoint_version()
+        for name, param in model.named_parameters():
+            if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
+                if checkpoint_version == 0:
+                    fixed_param = _transpose_first_dim(param.data, 3, True, model)
+                elif checkpoint_version == 1.0:
+                    fixed_param = _transpose_first_dim(param.data, 3, False, model)
+                else:
+                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
+                    sys.exit()
+                param.data.copy_(fixed_param)
+            if name.endswith(('.key_value.weight', '.key_value.bias')):
+                if checkpoint_version == 0:
+                    fixed_param = _transpose_first_dim(param.data, 2, True, model)
+                elif checkpoint_version == 1.0:
+                    fixed_param = _transpose_first_dim(param.data, 2, False, model)
+                else:
+                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
+                    sys.exit()
+                param.data.copy_(fixed_param)

     # Optimizer.
     if not release and not args.finetune and not args.no_load_optim:
         try:
...

@@ -21,7 +21,6 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule
-from megatron.checkpointing import get_checkpoint_version
 from megatron.model.enums import AttnMaskType, LayerType, AttnType
 from megatron.model import import_layernorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax

@@ -185,36 +184,6 @@
                 init_method=output_layer_init_method,
                 skip_bias_add=True)

-    def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
-        input_shape = mixed_layer.size()
-        if num_splits_first:
-            """[s, b, num_splits * np * hn]
-            -->(view) [s, b, num_splits, np, hn]
-            -->(tranpose) [s, b, np, num_splits, hn]
-            -->(view) [s, b, np * num_splits * hn] """
-            intermediate_shape = input_shape[:-1] +\
-                (num_splits, self.num_attention_heads_per_partition,
-                 self.hidden_size_per_attention_head)
-            mixed_layer = mixed_layer.view(*intermediate_shape)
-            mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
-        else:
-            """[s, b, np * hn * num_splits]
-            -->(view) [s, b, np, hn, num_splits]
-            -->(tranpose) [s, b, np, num_splits, hn]
-            -->(view) [s, b, np * num_splits * hn] """
-            intermediate_shape = input_shape[:-1] +\
-                (self.num_attention_heads_per_partition,
-                 self.hidden_size_per_attention_head, num_splits)
-            mixed_layer = mixed_layer.view(*intermediate_shape)
-            mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
-        mixed_layer = mixed_layer.view(*input_shape)
-        return mixed_layer

     def forward(self, hidden_states, attention_mask, layer_past=None,
                 get_key_value=False, encoder_output=None):
         # hidden_states: [sq, b, h]

@@ -227,15 +196,6 @@
             # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
             mixed_x_layer, _ = self.query_key_value(hidden_states)

-            checkpoint_version = get_checkpoint_version()
-            if checkpoint_version is not None:
-                if checkpoint_version == 0:
-                    # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
-                    mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
-                elif checkpoint_version == 1.0:
-                    # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
-                    mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)

             # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
             new_tensor_shape = mixed_x_layer.size()[:-1] + \
                 (self.num_attention_heads_per_partition,

@@ -250,15 +210,6 @@
             # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
             mixed_kv_layer, _ = self.key_value(encoder_output)

-            checkpoint_version = get_checkpoint_version()
-            if checkpoint_version is not None:
-                if checkpoint_version == 0:
-                    # [s, b, (2 * np * hn)] --> [s, b, (np * 2 * hn)]
-                    mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, True)
-                elif checkpoint_version == 1.0:
-                    # [s, b, (np * hn * 2)] --> [s, b, (np * 2 * hn)]
-                    mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, False)

             # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
             new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                 (self.num_attention_heads_per_partition,
...

@@ -243,10 +243,6 @@ def main():
         print(f'> loading {checkpoint_name} ...')
         load_checkpoint(model_, None, None)
         print(f'> checkpoint version {get_checkpoint_version()}')
-        if get_checkpoint_version() < 2.0:
-            # Need to deal with the qkv matrix order of old versions
-            print("Checkpoints less than version 2.0 are not currently supported.")
-            exit()
         partitions.append(model_)

     # Parameter generators so we can loop through them semiltaneouly.
...