Commit ec561daa authored by Jared Casper

Better handling of padding in embedding table.

parent cdf0a5d4
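Megatron pads the word-embedding table past the real vocabulary so that its row count divides evenly across tensor-parallel ranks, so a checkpoint saved under one parallel configuration can carry a different amount of padding than a target configuration expects; this commit has the loader report the true vocab size and lets the saver re-pad for the target layout. A minimal sketch of the rounding rule, assuming it matches Megatron's _vocab_size_with_padding (round up to a multiple of make-vocab-size-divisible-by times the tensor-parallel size); the helper name below is illustrative only:

    def padded_vocab_size(true_vocab_size, make_vocab_size_divisible_by, tensor_model_parallel_size):
        # Round up so every tensor-parallel shard of the embedding table gets the same number of rows.
        multiple = make_vocab_size_divisible_by * tensor_model_parallel_size
        return ((true_vocab_size + multiple - 1) // multiple) * multiple

    # Example: a 50257-token GPT-2 vocab, divisible-by 128, tensor-parallel size 2 -> 50432 rows.
    assert padded_vocab_size(50257, 128, 2) == 50432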
+import json
 import os
 import sys
 import types
@@ -7,6 +8,11 @@ import torch

 def add_arguments(parser):
     group = parser.add_argument_group(title='Megatron loader')

+    group.add_argument('--true-vocab-size', type=int, default=None,
+                       help='original size of vocab, if specified will trim padding from embedding table.')
+    group.add_argument('--vocab-file', type=str, default=None,
+                       help='Path to the vocab file. If specified will use this to get vocab size and '
+                       'trim padding from the embedding table.')
     group.add_argument('--megatron-path', type=str, default=None,
                        help='Base directory of deepspeed repository')
@@ -21,7 +27,7 @@ def _load_checkpoint(queue, args):
     try:
         from megatron.arguments import parse_args, validate_args
-        from megatron.global_vars import set_args, set_global_variables, rebuild_tokenizer
+        from megatron.global_vars import set_args, set_global_variables
         from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint
         from megatron.model import ModelType
         from megatron import mpu, fused_kernels
@@ -111,6 +117,19 @@ def _load_checkpoint(queue, args):
     mpu.initialize.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
     fused_kernels.load(margs)

+    # Get true (non-padded) vocab size
+    if args.true_vocab_size is not None:
+        true_vocab_size = args.true_vocab_size
+    elif args.vocab_file is not None:
+        vocab = json.load(open(args.vocab_file))
+        true_vocab_size = len(vocab)
+        if args.true_vocab_size is not None and true_vocab_size != args.true_vocab_size:
+            print("Both --true-vocab-size and --vocab-file specified and the vocab size does not match, aborting.")
+            queue.put("exit")
+            exit(1)
+    else:
+        true_vocab_size = None
+
     # short aliases
     tp_size = margs.tensor_model_parallel_size
     pp_size = margs.pipeline_model_parallel_size
@@ -129,6 +148,8 @@ def _load_checkpoint(queue, args):
     md.bert_binary_head = margs.bert_binary_head
     md.previous_tensor_parallel_size = margs.tensor_model_parallel_size
     md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size
+    md.true_vocab_size = true_vocab_size
+    md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by
     queue.put(md)

     # Get first pipe stage
@@ -137,6 +158,7 @@ def _load_checkpoint(queue, args):
     models = get_models(tp_size, md.params_dtype, True, post_process)

     # Send embeddings
     word_embed = []
     for tp_rank in range(tp_size):
         if tp_rank == 0:
@@ -144,6 +166,7 @@ def _load_checkpoint(queue, args):
             queue.put(models[tp_rank].language_model.embedding.position_embeddings.weight.data)
         word_embed.append(models[tp_rank].language_model.embedding.word_embeddings.weight.data)
     full_word_embed = torch.cat(word_embed, dim=0)

     print("Sending word embeddings")
     queue.put(full_word_embed)
...
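Each tensor-parallel rank holds a contiguous [padded_vocab_size // tp_size, hidden_size] slice of the word-embedding table, so the loader rebuilds the full table by concatenating the per-rank shards along dim 0, as in the loop above. A toy illustration with made-up shapes (padded vocab 8, hidden size 4, tensor-parallel size 2):

    import torch

    # Two hypothetical shards stand in for the per-rank word_embeddings.weight.data tensors.
    shards = [torch.randn(4, 4) for _ in range(2)]
    full_word_embed = torch.cat(shards, dim=0)
    assert full_word_embed.shape == (8, 4)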
@@ -31,6 +31,7 @@ def save_checkpoint(queue, args):
         from megatron.checkpointing import save_checkpoint
         from megatron.global_vars import set_global_variables, get_args
         from megatron.model import ModelType
+        from megatron.tokenizer.tokenizer import _vocab_size_with_padding
         from megatron import mpu, fused_kernels
     except ModuleNotFoundError:
         print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
@@ -91,6 +92,9 @@ def save_checkpoint(queue, args):
                 '--save-interval', '1',
                 '--save', args.save_dir
                 ]

+    if md.make_vocab_size_divisible_by is not None:
+        sys.argv.extend(['--make-vocab-size-divisible-by', str(md.make_vocab_size_divisible_by)])
+
     if md.params_dtype == torch.float16:
         sys.argv.append('--fp16')
     elif md.params_dtype == torch.bfloat16:
@@ -127,13 +131,33 @@ def save_checkpoint(queue, args):
     # Embeddings
     #-----------
     pos_embed = queue_get()
-    full_word_embed = queue_get()
+    orig_word_embed = queue_get()

-    # Tell Megatron what our full size is
-    margs.padded_vocab_size = full_word_embed.shape[0]
-    if margs.padded_vocab_size % args.target_tensor_parallel_size != 0:
-        print("source vocab size is not evenly divisble by target tensor parallel size")
-        exit(1)
+    # Deal with padding
+    if md.true_vocab_size is not None:
+        # figure out what our padded vocab size is
+        orig_vocab_size = orig_word_embed.shape[0]
+        margs.padded_vocab_size = _vocab_size_with_padding(md.true_vocab_size, margs)
+        # Cut out extra padding we don't need
+        if orig_vocab_size > margs.padded_vocab_size:
+            full_word_embed = orig_word_embed[0:margs.padded_vocab_size,:]
+        # Expanding embedding to larger size by replicating final entry
+        elif orig_vocab_size < margs.padded_vocab_size:
+            padding_size = margs.padded_vocab_size - orig_vocab_size
+            full_word_embed = torch.cat((
+                orig_word_embed,
+                orig_word_embed[-1].unsqueeze(0).expand(padding_size, -1)))
+        # Same size!
+        else:
+            full_word_embed = orig_word_embed
+    else:
+        print("Original vocab size not specified, leaving embedding table as-is. "
+              "If you've changed the tensor parallel size this could cause problems.")
+        full_word_embed = orig_word_embed

     # Split into new tensor model parallel sizes
     out_word_embed = torch.chunk(full_word_embed, args.target_tensor_parallel_size, dim=0)
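Read as a standalone helper, the padding branch above trims rows when the source table is larger than the target's padded size and replicates the last row when it is smaller. The sketch below is a hypothetical paraphrase (resize_embedding is not a name in the diff), not the saver's actual structure:

    import torch

    def resize_embedding(orig_word_embed, padded_vocab_size):
        # Trim or grow a [vocab, hidden] embedding table to exactly padded_vocab_size rows.
        orig_vocab_size = orig_word_embed.shape[0]
        if orig_vocab_size > padded_vocab_size:
            # Cut out padding rows the target configuration does not need.
            return orig_word_embed[:padded_vocab_size, :]
        if orig_vocab_size < padded_vocab_size:
            # Grow by replicating the final entry, as the added code above does.
            padding_size = padded_vocab_size - orig_vocab_size
            return torch.cat((orig_word_embed,
                              orig_word_embed[-1].unsqueeze(0).expand(padding_size, -1)))
        return orig_word_embed

Since margs here describes the target configuration, _vocab_size_with_padding should return a multiple of the target tensor-parallel size, so the torch.chunk call above splits the resized table evenly; the removed check that errored out on an indivisible source vocab is no longer needed.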
@@ -143,6 +167,7 @@ def save_checkpoint(queue, args):
     post_process = args.target_pipeline_parallel_size == 1
     models = get_models(args.target_tensor_parallel_size, md.params_dtype, True, post_process)
     for tp_rank, model in enumerate(models):
+        print(f"word embeddings shape {model.language_model.embedding.word_embeddings.weight.shape}")
         model.language_model.embedding.word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
         model.language_model.embedding.position_embeddings.weight.data.copy_(pos_embed)
...