Commit 748a9a7d authored by Jared Casper

Add error checking by wrapping all weights and biases into named messages with named tensors.

parent 2755bcb8
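Previously the loader pushed bare tensors onto the inter-process queue one `queue.put` at a time, so a single missing or reordered send silently misassigned every weight that followed. Each logical unit (embeddings, one transformer layer, the heads) now travels as a dictionary tagged with a "name"; the saver checks the name, pops each key it expects, and stops if anything unexpected remains. A minimal self-contained sketch of the pattern, with plain lists standing in for tensors and a local `queue.Queue` in place of the multiprocessing queue the converter actually uses:

```python
import queue

q = queue.Queue()

def queue_put(name, msg):
    # Tag the message so the receiver can verify it arrives in order.
    msg["name"] = name
    q.put(msg)

def queue_get(expected):
    msg = q.get()
    if msg["name"] != expected:
        raise RuntimeError(f'expected "{expected}", got "{msg["name"]}"')
    return msg

def check_message(msg):
    # Once the receiver has popped every key it understands, anything
    # left besides the tag means loader and saver disagree on the format.
    leftover = set(msg) - {"name"}
    if leftover:
        raise RuntimeError(f"unexpected keys: {leftover}")

queue_put("embeddings", {"position embeddings": [0.1], "word embeddings": [0.2]})

msg = queue_get("embeddings")
pos_embed = msg.pop("position embeddings")
word_embed = msg.pop("word embeddings")
check_message(msg)  # passes: nothing unexpected remains
```

In the real converter these checks are gated on a new `args.checking` flag, so `--no-checking` restores the old permissive behavior.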
@@ -170,18 +170,20 @@ def _load_checkpoint(queue, args):
     md.consumed_valid_samples = consumed_valid_samples
     queue.put(md)
-    # Send embeddings
+    def queue_put(name, msg):
+        print(f"sending {name}")
+        msg["name"] = name
+        queue.put(msg)
-    word_embed = []
-    for tp_rank in range(tp_size):
-        if tp_rank == 0:
-            print("Sending position embeddings")
-            queue.put(models[tp_rank].language_model.embedding.position_embeddings.weight.data)
-        word_embed.append(models[tp_rank].language_model.embedding.word_embeddings.weight.data)
-    full_word_embed = torch.cat(word_embed, dim=0)
+    # Send embeddings
+    message = {
+        "position embeddings": models[0].language_model.embedding.position_embeddings.weight.data,
+        "word embeddings": torch.cat(
+            [models[tp_rank].language_model.embedding.word_embeddings.weight.data for tp_rank in range(tp_size)],
+            dim = 0)
+    }
-    print("Sending word embeddings")
-    queue.put(full_word_embed)
+    queue_put("embeddings", message)
     total_layer_num = 0
     for pp_rank in range(pp_size):
@@ -190,23 +192,24 @@ def _load_checkpoint(queue, args):
         post_process = pp_rank == pp_size - 1
         models = get_models(tp_size, md.params_dtype, False, post_process)
         for layer_num in range(len(models[0].language_model.encoder.layers)):
+            message = {}
+            # Get non-parallel tensors from tp_rank 0
+            layer = models[0].language_model.encoder.layers[layer_num]
+            message["input layernorm weight"] = layer.input_layernorm.weight.data
+            message["input layernorm bias"] = layer.input_layernorm.bias.data
+            message["dense bias"] = layer.self_attention.dense.bias.data
+            message["post layernorm weight"] = layer.post_attention_layernorm.weight.data
+            message["post layernorm bias"] = layer.post_attention_layernorm.bias.data
+            message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data
+            # Grab all parallel tensors for this layer
             qkv_weight = []
             qkv_bias = []
             dense_weight = []
             mlp_l0_weight = []
             mlp_l0_bias = []
             mlp_l1_weight = []
-            # Get non-parallel tensors from tp_rank 0
-            layer = models[0].language_model.encoder.layers[layer_num]
-            input_layernorm_weight = layer.input_layernorm.weight.data
-            input_layernorm_bias = layer.input_layernorm.bias.data
-            dense_bias = layer.self_attention.dense.bias.data
-            post_layernorm_weight = layer.post_attention_layernorm.weight.data
-            post_layernorm_bias = layer.post_attention_layernorm.bias.data
-            mlp_l1_bias = layer.mlp.dense_4h_to_h.bias.data
-            # Grab all parallel tensors for this layer
             for tp_rank, model in enumerate(models):
                 layer = model.language_model.encoder.layers[layer_num]
                 qkv_weight.append(layer.self_attention.query_key_value.weight.data)
@@ -216,47 +219,50 @@ def _load_checkpoint(queue, args):
                 mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)
                 mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)
-            # send everything in order while concatenating them
-            print(f"Sending layer {layer_num} of pipeline rank {pp_rank} (total layer {total_layer_num})")
-            queue.put(input_layernorm_weight)
-            queue.put(input_layernorm_bias)
-            queue.put(torch.cat(qkv_weight, dim=0))
-            queue.put(torch.cat(qkv_bias, dim=0))
-            queue.put(torch.cat(dense_weight, dim=1))
-            queue.put(dense_bias)
-            queue.put(post_layernorm_weight)
-            queue.put(post_layernorm_bias)
-            queue.put(torch.cat(mlp_l0_weight, dim=0))
-            queue.put(torch.cat(mlp_l0_bias, dim=0))
-            queue.put(torch.cat(mlp_l1_weight, dim=1))
-            queue.put(mlp_l1_bias)
+            # concat them
+            message["qkv weight"] = torch.cat(qkv_weight, dim=0)
+            message["qkv bias"] = torch.cat(qkv_bias, dim=0)
+            message["dense weight"] = torch.cat(dense_weight, dim=1)
+            message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)
+            message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)
+            message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)
+            queue_put(f"transformer layer {total_layer_num}", message)
             total_layer_num = total_layer_num + 1
     # Send final layernorm from tp_rank 0
-    print("Sending final layernorm")
-    queue.put(models[0].language_model.encoder.final_layernorm.weight.data)
-    queue.put(models[0].language_model.encoder.final_layernorm.bias.data)
+    message = {
+        "weight": models[0].language_model.encoder.final_layernorm.weight.data,
+        "bias": models[0].language_model.encoder.final_layernorm.bias.data
+    }
+    queue_put("final layernorm", message)
     # Send BERT lm head and binary head if it exists
     if md.model_type == 'BERT':
-        print("Sending LM Pooler")
-        queue.put("pooler")
-        queue.put(models[0].language_model.pooler.dense.weight.data)
-        queue.put(models[0].language_model.pooler.dense.bias.data)
-        print("Sending BERT LM head")
-        queue.put("lm head")
-        queue.put(models[0].lm_head.dense.weight.data)
-        queue.put(models[0].lm_head.dense.bias.data)
-        queue.put(models[0].lm_head.layernorm.weight.data)
-        queue.put(models[0].lm_head.layernorm.bias.data)
+        message = {
+            "weight": models[0].language_model.pooler.dense.weight.data,
+            "bias": models[0].language_model.pooler.dense.bias.data
+        }
+        queue_put("pooler", message)
+        message = {
+            "dense weight": models[0].lm_head.dense.weight.data,
+            "dense bias": models[0].lm_head.dense.bias.data,
+            "layernorm weight": models[0].lm_head.layernorm.weight.data,
+            "layernorm bias": models[0].lm_head.layernorm.bias.data
+        }
+        queue_put("lm head", message)
        if md.bert_binary_head:
-            print("Sending BERT Binary head")
-            queue.put("binary head")
-            queue.put(models[0].binary_head.weight.data)
-            queue.put(models[0].binary_head.bias.data)
+            message = {
+                "weight": models[0].binary_head.weight.data,
+                "bias": models[0].binary_head.bias.data
+            }
+            queue_put("binary head", message)
     queue.put("done")

 def load_checkpoint(queue, args):
...
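A note on the dimensions in the loader code above: in Megatron's scheme, column-parallel layers (`query_key_value`, `dense_h_to_4h`) shard the output dimension of their weights, so their tensor-parallel shards are merged with `torch.cat(..., dim=0)`, while row-parallel layers (the attention `dense` and `dense_4h_to_h`) shard the input dimension and merge along `dim=1`. The saver below inverts this with `torch.chunk` along the same dimensions. A small sketch of the round trip, with invented toy shapes:

```python
import torch

# Column-parallel: output dim sharded, merge along dim 0.
col_shards = [torch.randn(4, 8), torch.randn(4, 8)]
full_col = torch.cat(col_shards, dim=0)   # -> (8, 8)

# Row-parallel: input dim sharded, merge along dim 1.
row_shards = [torch.randn(8, 4), torch.randn(8, 4)]
full_row = torch.cat(row_shards, dim=1)   # -> (8, 8)

# The saver splits the full tensors back out for the target tp size.
assert all(t.shape == (4, 8) for t in torch.chunk(full_col, 2, dim=0))
assert all(t.shape == (8, 4) for t in torch.chunk(full_row, 2, dim=1))
```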
 import argparse
+from collections.abc import Mapping
 import concurrent.futures
 import os
 import sys
@@ -38,13 +39,31 @@ def save_checkpoint(queue, args):
         print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
         exit(1)
-    def queue_get():
+    def queue_get(name=None):
         val = queue.get()
         if val == "exit":
             print("Loader exited, exiting saver")
             exit(1)
+        if name is not None and args.checking and val["name"] != name:
+            val_name = val["name"]
+            print(f'Unexpected message. Expecting "{name}" but got "{val_name}". Exiting saver.')
+            exit(1)
+        if name is not None:
+            print(f"received {name}")
         return val
+    def check_message(msg):
+        if not args.checking:
+            return
+        msg_name = msg.pop("name")
+        if len(msg.keys()) > 0:
+            print(f"Unexpected values in {msg_name}:")
+            for key in msg.keys():
+                print(f"   {key}")
+            print("Exiting. If you want to ignore this, use the argument --no-checking.")
+            exit(1)
     md = queue_get()
     if args.target_tensor_parallel_size is None:
@@ -141,8 +160,11 @@ def save_checkpoint(queue, args):
     # Embeddings
     #-----------
-    pos_embed = queue_get()
-    orig_word_embed = queue_get()
+    embeddings_msg = queue_get("embeddings")
+    pos_embed = embeddings_msg.pop("position embeddings")
+    orig_word_embed = embeddings_msg.pop("word embeddings")
+    check_message(embeddings_msg)
     # Deal with padding
     if md.true_vocab_size is not None:
@@ -185,6 +207,7 @@ def save_checkpoint(queue, args):
     # Transformer layers
     #-------------------
+    total_layer_num = 0
     for pp_rank in range(args.target_pipeline_parallel_size):
         # For later pipeline parallel ranks, make the new models
         if pp_rank > 0:
@@ -193,47 +216,47 @@ def save_checkpoint(queue, args):
             models = get_models(args.target_tensor_parallel_size, md.params_dtype, False, post_process)
         for layer in range(len(models[0].language_model.encoder.layers)):
-            # get full tensors
-            input_layernorm_weight = queue_get()
-            input_layernorm_bias = queue_get()
-            full_qkv_weight = queue_get()
-            full_qkv_bias = queue_get()
-            full_dense_weight = queue_get()
-            dense_bias = queue_get()
-            post_layernorm_weight = queue_get()
-            post_layernorm_bias = queue_get()
-            full_mlp_l0_weight = queue_get()
-            full_mlp_l0_bias = queue_get()
-            full_mlp_l1_weight = queue_get()
-            mlp_l1_bias = queue_get()
+            msg = queue_get(f"transformer layer {total_layer_num}")
+            # duplicated tensors
+            input_layernorm_weight = msg.pop("input layernorm weight")
+            input_layernorm_bias = msg.pop("input layernorm bias")
+            dense_bias = msg.pop("dense bias")
+            post_layernorm_weight = msg.pop("post layernorm weight")
+            post_layernorm_bias = msg.pop("post layernorm bias")
+            mlp_l1_bias = msg.pop("mlp l1 bias")
             # Split up the parallel tensors
-            out_qkv_weight = torch.chunk(full_qkv_weight, args.target_tensor_parallel_size, dim=0)
-            out_qkv_bias = torch.chunk(full_qkv_bias, args.target_tensor_parallel_size, dim=0)
-            out_dense_weight = torch.chunk(full_dense_weight, args.target_tensor_parallel_size, dim=1)
-            out_mlp_l0_weight = torch.chunk(full_mlp_l0_weight, args.target_tensor_parallel_size, dim=0)
-            out_mlp_l0_bias = torch.chunk(full_mlp_l0_bias, args.target_tensor_parallel_size, dim=0)
-            out_mlp_l1_weight = torch.chunk(full_mlp_l1_weight, args.target_tensor_parallel_size, dim=1)
+            qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0)
+            qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
+            dense_weight = torch.chunk(msg.pop("dense weight"), args.target_tensor_parallel_size, dim=1)
+            mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0)
+            mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0)
+            mlp_l1_weight = torch.chunk(msg.pop("mlp l1 weight"), args.target_tensor_parallel_size, dim=1)
             # Save them to the model
             for tp_rank in range(args.target_tensor_parallel_size):
                 l = models[tp_rank].language_model.encoder.layers[layer]
                 l.input_layernorm.weight.data.copy_(input_layernorm_weight)
                 l.input_layernorm.bias.data.copy_(input_layernorm_bias)
-                l.self_attention.query_key_value.weight.data.copy_(out_qkv_weight[tp_rank])
-                l.self_attention.query_key_value.bias.data.copy_(out_qkv_bias[tp_rank])
-                l.self_attention.dense.weight.data.copy_(out_dense_weight[tp_rank])
+                l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
+                l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
+                l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank])
                 l.self_attention.dense.bias.data.copy_(dense_bias)
                 l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight)
                 l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias)
-                l.mlp.dense_h_to_4h.weight.data.copy_(out_mlp_l0_weight[tp_rank])
-                l.mlp.dense_h_to_4h.bias.data.copy_(out_mlp_l0_bias[tp_rank])
-                l.mlp.dense_4h_to_h.weight.data.copy_(out_mlp_l1_weight[tp_rank])
+                l.mlp.dense_h_to_4h.weight.data.copy_(mlp_l0_weight[tp_rank])
+                l.mlp.dense_h_to_4h.bias.data.copy_(mlp_l0_bias[tp_rank])
+                l.mlp.dense_4h_to_h.weight.data.copy_(mlp_l1_weight[tp_rank])
                 l.mlp.dense_4h_to_h.bias.data.copy_(mlp_l1_bias)
+            total_layer_num = total_layer_num + 1
+            check_message(msg)
         if post_process:
-            final_layernorm_weight = queue_get()
-            final_layernorm_bias = queue_get()
+            msg = queue_get("final layernorm")
+            final_layernorm_weight = msg.pop("weight")
+            final_layernorm_bias = msg.pop("bias")
             for tp_rank in range(args.target_tensor_parallel_size):
                 models[tp_rank].language_model.encoder.final_layernorm.weight.data.copy_(final_layernorm_weight)
                 models[tp_rank].language_model.encoder.final_layernorm.bias.data.copy_(final_layernorm_bias)
@@ -242,49 +265,56 @@ def save_checkpoint(queue, args):
                     models[tp_rank].word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
             del final_layernorm_weight
             del final_layernorm_bias
+            check_message(msg)
-            name = queue_get()
-            if name == "pooler":
+            msg = queue_get()
+            if msg != "done" and msg["name"] == "pooler":
                 if not hasattr(models[0].language_model, 'pooler'):
                     print("ERROR: got a pooler, but model does not have one")
                     exit(1)
-                pooler_weight = queue_get()
-                pooler_bias = queue_get()
+                print("received pooler")
+                pooler_weight = msg.pop("weight")
+                pooler_bias = msg.pop("bias")
                 for tp_rank in range(args.target_tensor_parallel_size):
                     models[tp_rank].language_model.pooler.dense.weight.data.copy_(pooler_weight)
                     models[tp_rank].language_model.pooler.dense.bias.data.copy_(pooler_bias)
-                name = queue_get()
+                del pooler_weight
+                del pooler_bias
+                check_message(msg)
+                msg = queue_get()
-            if name == "lm head":
+            if msg != "done" and msg["name"] == "lm head":
                 if not hasattr(models[0], 'lm_head'):
                     print("ERROR: got an lm head, but model does not have one")
                     exit(1)
-                lm_head_dense_weight = queue_get()
-                lm_head_dense_bias = queue_get()
-                lm_head_layernorm_weight = queue_get()
-                lm_head_layernorm_bias = queue_get()
+                print("received lm head")
+                lm_head_dense_weight = msg.pop("dense weight")
+                lm_head_dense_bias = msg.pop("dense bias")
+                lm_head_layernorm_weight = msg.pop("layernorm weight")
+                lm_head_layernorm_bias = msg.pop("layernorm bias")
                 for tp_rank in range(args.target_tensor_parallel_size):
                     models[tp_rank].lm_head.dense.weight.data.copy_(lm_head_dense_weight)
                     models[tp_rank].lm_head.dense.bias.data.copy_(lm_head_dense_bias)
                     models[tp_rank].lm_head.layernorm.weight.data.copy_(lm_head_layernorm_weight)
                     models[tp_rank].lm_head.layernorm.bias.data.copy_(lm_head_layernorm_bias)
-                name = queue_get()
+                check_message(msg)
+                msg = queue_get()
-            if name == "binary head":
+            if msg != "done" and msg["name"] == "binary head":
                 if not hasattr(models[0], 'binary_head'):
                     print("ERROR: got a binary head, but model does not have one")
                     exit(1)
-                binary_head_weight = queue_get()
-                binary_head_bias = queue_get()
+                print("received binary head")
+                binary_head_weight = msg.pop("weight")
+                binary_head_bias = msg.pop("bias")
                 for tp_rank in range(args.target_tensor_parallel_size):
                     models[tp_rank].binary_head.weight.data.copy_(binary_head_weight)
                     models[tp_rank].binary_head.bias.data.copy_(binary_head_bias)
-                name = queue_get()
+                check_message(msg)
+                msg = queue_get()
-            if name != "done":
-                print("ERROR: got some more data but were expecting to be done")
+            if msg != "done":
+                print("ERROR: got some more data but was expecting to be done")
         for tp_rank in range(args.target_tensor_parallel_size):
             mpu.initialize.set_tensor_model_parallel_rank(tp_rank)
...
@@ -12,10 +12,12 @@ import sys
 # load_checkpoint
 # The loader and saver process are each given a queue, the loader
-# should load the checkpoint and send the weights in the following
-# order, the saver should receive them in this order and save the
-# checkpoints. Note that the weight sent over the queue are the full
-# model weights, nothing split.
+# should load the checkpoint and send the weights in messages in the
+# following order, the saver should receive them in this order and
+# save the checkpoints. A message consists of a Python dictionary with
+# a "name" for error checking and an entry for each tensor as
+# indicated below. Note that the weights sent over the queue are the
+# full model weights, nothing split.
 # If the loader ever sends "exit" to the queue, that means something
 # went wrong and it is exiting.
@@ -37,35 +39,51 @@ import sys
 #   make_vocab_size_divisible_by
 #   consumed_train_samples
 #   consumed_valid_samples
-# - Position embeddings
-# - Word embeddings
-# - For each transformer layer:
-#   - input layernorm weights
-#   - input layernorm bias
-#   - qkv weight
-#   - qkv bias
-#   - dense weight
-#   - dense bias
-#   - post attention layernorm weight
-#   - post attention layernorm bias
-#   - mlp layer 0 (h to 4h) weight
-#   - mlp layer 0 (h to 4h) bias
-#   - mlp layer 1 (4h to h) weight
-#   - mlp layer 1 (4h to h) bias
-# - final layer norm weight
-# - final layer norm bias
-# - if present (i.e. for BERT):
-#   - "pooler"
-#   - LM Pooler weight
-#   - LM Pooler bias
-#   - "lm head"
-#   - LM head dense weight
-#   - LM head dense bias
-#   - LM head layernorm weight
-#   - LM head layernorm bias
-#   - "binary head"
-#   - BERT Binary head weight
-#   - BERT Binary head bias
+# - messages
+#   {
+#     "name": "embeddings"
+#     "position embeddings"
+#     "word embeddings"
+#   }
+#   (for each transformer layer):
+#   {
+#     "name": "transformer layer N"
+#     "input layernorm weight"
+#     "input layernorm bias"
+#     "qkv weight"
+#     "qkv bias"
+#     "dense weight"
+#     "dense bias"
+#     "post layernorm weight"
+#     "post layernorm bias"
+#     "mlp l0 weight"
+#     "mlp l0 bias"
+#     "mlp l1 weight"
+#     "mlp l1 bias"
+#   }
+#   {
+#     "name": "final layernorm"
+#     "weight"
+#     "bias"
+#   }
+#   if present (i.e. for BERT):
+#   {
+#     "name": "pooler"
+#     "weight"
+#     "bias"
+#   }
+#   {
+#     "name": "lm head"
+#     "dense weight"
+#     "dense bias"
+#     "layernorm weight"
+#     "layernorm bias"
+#   }
+#   {
+#     "name": "binary head"
+#     "weight"
+#     "bias"
+#   }
 # - "done"

 def load_plugin(plugin_type, name):
@@ -103,6 +121,9 @@ def main():
                         help='Directory to save model checkpoint to')
     parser.add_argument('--max-queue-size', type=int, default=50,
                         help='Maximum number of tensors in the queue')
+    parser.add_argument('--no-checking', action='store_false',
+                        help='Do not perform checking on the name and ordering of weights',
+                        dest='checking')
     known_args, _ = parser.parse_known_args()
     loader = load_plugin('loader', known_args.loader)
...
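One detail of the new flag: it is registered with `action='store_false'` and `dest='checking'`, so the code reads the positive `args.checking` while users pass the negative `--no-checking`, and checking stays on by default. A self-contained sketch of that argparse pattern:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-checking', action='store_false', dest='checking',
                    help='Do not perform checking on the name and ordering of weights')

print(parser.parse_args([]).checking)                  # True: checking on by default
print(parser.parse_args(['--no-checking']).checking)   # False: explicitly disabled
```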