Commit ec561daa authored by Jared Casper

Better handling of padding in embedding table.

parent cdf0a5d4
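This change threads the true (unpadded) vocab size and `make-vocab-size-divisible-by` from the loader's metadata through to the saver, so the saver can re-pad the embedding table for the target tensor-parallel size instead of keeping the source table's padding. For context, Megatron's usual rule is to pad the vocab up to a multiple of `make-vocab-size-divisible-by * tensor-model-parallel-size`; a minimal sketch of that arithmetic (hypothetical `pad_vocab_size` helper, not the library's `_vocab_size_with_padding`):

```python
def pad_vocab_size(true_vocab_size, make_vocab_size_divisible_by, tensor_model_parallel_size):
    # Round up to the nearest multiple of
    # make_vocab_size_divisible_by * tensor_model_parallel_size.
    multiple = make_vocab_size_divisible_by * tensor_model_parallel_size
    return ((true_vocab_size + multiple - 1) // multiple) * multiple

# GPT-2's 50257-token vocab, divisible-by 128:
assert pad_vocab_size(50257, 128, 1) == 50304   # tensor parallel 1
assert pad_vocab_size(50257, 128, 2) == 50432   # tensor parallel 2
```

The first set of hunks below touches the loader (`_load_checkpoint`); the hunks after it touch the saver (`save_checkpoint`).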
import json
import os
import sys
import types
@@ -7,6 +8,11 @@ import torch

def add_arguments(parser):
    group = parser.add_argument_group(title='Megatron loader')

    group.add_argument('--true-vocab-size', type=int, default=None,
                       help='original size of vocab, if specified will trim padding from embedding table.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file. If specified will use this to get vocab size and '
                       'trim padding from the embedding table.')
    group.add_argument('--megatron-path', type=str, default=None,
                       help='Base directory of deepspeed repository')
@@ -21,7 +27,7 @@ def _load_checkpoint(queue, args):
    try:
        from megatron.arguments import parse_args, validate_args
        from megatron.global_vars import set_args, set_global_variables, rebuild_tokenizer
        from megatron.global_vars import set_args, set_global_variables
        from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint
        from megatron.model import ModelType
        from megatron import mpu, fused_kernels
@@ -111,6 +117,19 @@ def _load_checkpoint(queue, args):
    mpu.initialize.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
    fused_kernels.load(margs)

    # Get true (non-padded) vocab size
    if args.true_vocab_size is not None:
        true_vocab_size = args.true_vocab_size
    elif args.vocab_file is not None:
        vocab = json.load(open(args.vocab_file))
        true_vocab_size = len(vocab)
        if args.true_vocab_size is not None and true_vocab_size != args.true_vocab_size:
            print("Both --true-vocab-size and --vocab-file specified and the vocab size does not match, aborting.")
            queue.put("exit")
            exit(1)
    else:
        true_vocab_size = None

    # short aliases
    tp_size = margs.tensor_model_parallel_size
    pp_size = margs.pipeline_model_parallel_size
@@ -129,6 +148,8 @@ def _load_checkpoint(queue, args):
    md.bert_binary_head = margs.bert_binary_head
    md.previous_tensor_parallel_size = margs.tensor_model_parallel_size
    md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size
    md.true_vocab_size = true_vocab_size
    md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by
    queue.put(md)

    # Get first pipe stage
@@ -137,6 +158,7 @@ def _load_checkpoint(queue, args):
    models = get_models(tp_size, md.params_dtype, True, post_process)

    # Send embeddings
    word_embed = []
    for tp_rank in range(tp_size):
        if tp_rank == 0:
@@ -144,6 +166,7 @@ def _load_checkpoint(queue, args):
            queue.put(models[tp_rank].language_model.embedding.position_embeddings.weight.data)
        word_embed.append(models[tp_rank].language_model.embedding.word_embeddings.weight.data)
    full_word_embed = torch.cat(word_embed, dim=0)
    print("Sending word embeddings")
    queue.put(full_word_embed)
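Between this loader hunk and the saver hunks below, the word-embedding table travels over the queue as a single tensor: the loader concatenates the per-tensor-parallel-rank shards along the vocab dimension, and the saver later re-chunks it for the target tensor-parallel size. A toy round-trip of that cat/chunk step (illustrative shapes only, not Megatron code):

```python
import torch

# Toy shapes: padded vocab 16, hidden size 4, source tensor-parallel size 2.
word_embed = [torch.randn(8, 4) for _ in range(2)]        # shard held by each source tp_rank

full_word_embed = torch.cat(word_embed, dim=0)            # loader: (16, 4), sent over the queue
out_word_embed = torch.chunk(full_word_embed, 4, dim=0)   # saver: re-split for a target tp size of 4

assert full_word_embed.shape == (16, 4)
assert all(chunk.shape == (4, 4) for chunk in out_word_embed)
```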
@@ -31,6 +31,7 @@ def save_checkpoint(queue, args):
        from megatron.checkpointing import save_checkpoint
        from megatron.global_vars import set_global_variables, get_args
        from megatron.model import ModelType
        from megatron.tokenizer.tokenizer import _vocab_size_with_padding
        from megatron import mpu, fused_kernels
    except ModuleNotFoundError:
        print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
@@ -91,6 +92,9 @@ def save_checkpoint(queue, args):
                '--save-interval', '1',
                '--save', args.save_dir
                ]
    if md.make_vocab_size_divisible_by is not None:
        sys.argv.extend(['--make-vocab-size-divisible-by', str(md.make_vocab_size_divisible_by)])

    if md.params_dtype == torch.float16:
        sys.argv.append('--fp16')
    elif md.params_dtype == torch.bfloat16:
@@ -127,13 +131,33 @@ def save_checkpoint(queue, args):
    # Embeddings
    #-----------
    pos_embed = queue_get()
    full_word_embed = queue_get()
    orig_word_embed = queue_get()

    # Tell Megatron what our full size is
    margs.padded_vocab_size = full_word_embed.shape[0]
    if margs.padded_vocab_size % args.target_tensor_parallel_size != 0:
        print("source vocab size is not evenly divisble by target tensor parallel size")
        exit(1)

    # Deal with padding
    if md.true_vocab_size is not None:
        # figure out what our padded vocab size is
        orig_vocab_size = orig_word_embed.shape[0]
        margs.padded_vocab_size = _vocab_size_with_padding(md.true_vocab_size, margs)

        # Cut out extra padding we don't need
        if orig_vocab_size > margs.padded_vocab_size:
            full_word_embed = orig_word_embed[0:margs.padded_vocab_size,:]

        # Expanding embedding to larger size by replicating final entry
        elif orig_vocab_size < margs.padded_vocab_size:
            padding_size = margs.padded_vocab_size - orig_vocab_size
            full_word_embed = torch.cat((
                orig_word_embed,
                orig_word_embed[-1].unsqueeze(0).expand(padding_size, -1)))

        # Same size!
        else:
            full_word_embed = orig_word_embed
    else:
        print("Original vocab size not specified, leaving embedding table as-is. "
              "If you've changed the tensor parallel size this could cause problems.")
        full_word_embed = orig_word_embed

    # Split into new tensor model parallel sizes
    out_word_embed = torch.chunk(full_word_embed, args.target_tensor_parallel_size, dim=0)
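As a worked example of the re-padding above (numbers assumed; `pad_vocab_size` is the hypothetical helper sketched near the top, standing in for `_vocab_size_with_padding`): a 50257-token vocab padded for a source tensor-parallel size of 2 gives a 50432-row table, while a target tensor-parallel size of 1 only needs 50304 rows, so the saver trims the last 128; in the opposite direction it replicates the final row to fill the gap.

```python
import torch

def pad_vocab_size(true_vocab, divisible_by, tp_size):
    # Hypothetical stand-in for _vocab_size_with_padding (see the sketch near the top).
    multiple = divisible_by * tp_size
    return ((true_vocab + multiple - 1) // multiple) * multiple

true_vocab, hidden = 50257, 8
orig_word_embed = torch.randn(pad_vocab_size(true_vocab, 128, 2), hidden)  # 50432 rows from the source
padded_vocab_size = pad_vocab_size(true_vocab, 128, 1)                     # 50304 rows for the target

if orig_word_embed.shape[0] > padded_vocab_size:
    # Trim the extra padding rows.
    full_word_embed = orig_word_embed[:padded_vocab_size, :]
elif orig_word_embed.shape[0] < padded_vocab_size:
    # Grow by replicating the final row.
    padding_size = padded_vocab_size - orig_word_embed.shape[0]
    full_word_embed = torch.cat((orig_word_embed,
                                 orig_word_embed[-1].unsqueeze(0).expand(padding_size, -1)))
else:
    full_word_embed = orig_word_embed

assert full_word_embed.shape == (50304, hidden)
```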
@@ -143,6 +167,7 @@ def save_checkpoint(queue, args):
    post_process = args.target_pipeline_parallel_size == 1
    models = get_models(args.target_tensor_parallel_size, md.params_dtype, True, post_process)
    for tp_rank, model in enumerate(models):
        print(f"word embeddings shape {model.language_model.embedding.word_embeddings.weight.shape}")
        model.language_model.embedding.word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
        model.language_model.embedding.position_embeddings.weight.data.copy_(pos_embed)