Commit 748a9a7d authored by Jared Casper

Add error checking by wrapping all weights and biases into named messages with named tensors.

parent 2755bcb8
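For readers unfamiliar with the converter's queue protocol, the sketch below (an editor's illustration, not part of the commit) shows the convention the diff introduces: the loader tags each dict of named tensors with a "name" entry before enqueuing it, and the saver checks that name on receipt. It uses a plain in-process queue.Queue and placeholder tensors, whereas the real loader and saver run in separate processes and pass actual model weights.

# Illustrative sketch only (not part of the commit).
import queue
import torch

q = queue.Queue()  # the real tool uses a multiprocessing queue between processes

def queue_put(name, msg):
    # Loader side: name the message and send it.
    print(f"sending {name}")
    msg["name"] = name
    q.put(msg)

def queue_get(name=None):
    # Saver side: receive and, when a name is given, check that it matches.
    val = q.get()
    if name is not None and val["name"] != name:
        raise RuntimeError(f'expected "{name}", got "{val["name"]}"')
    return val

# Example round trip with placeholder tensors.
queue_put("embeddings", {
    "position embeddings": torch.zeros(1024, 768),
    "word embeddings": torch.zeros(50304, 768),
})
msg = queue_get("embeddings")
pos_embed = msg.pop("position embeddings")
word_embed = msg.pop("word embeddings")
assert not [k for k in msg if k != "name"]  # nothing unexpected left over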
@@ -170,18 +170,20 @@ def _load_checkpoint(queue, args):
     md.consumed_valid_samples = consumed_valid_samples
     queue.put(md)
 
-    # Send embeddings
-    word_embed = []
-    for tp_rank in range(tp_size):
-        if tp_rank == 0:
-            print("Sending position embeddings")
-            queue.put(models[tp_rank].language_model.embedding.position_embeddings.weight.data)
-        word_embed.append(models[tp_rank].language_model.embedding.word_embeddings.weight.data)
-    full_word_embed = torch.cat(word_embed, dim=0)
-    print("Sending word embeddings")
-    queue.put(full_word_embed)
+    def queue_put(name, msg):
+        print(f"sending {name}")
+        msg["name"] = name
+        queue.put(msg)
+
+    # Send embeddings
+    message = {
+        "position embeddings": models[0].language_model.embedding.position_embeddings.weight.data,
+        "word embeddings": torch.cat(
+            [models[tp_rank].language_model.embedding.word_embeddings.weight.data for tp_rank in range(tp_size)],
+            dim = 0)
+    }
+    queue_put("embeddings", message)
 
     total_layer_num = 0
     for pp_rank in range(pp_size):
@@ -190,23 +192,24 @@ def _load_checkpoint(queue, args):
         post_process = pp_rank == pp_size - 1
         models = get_models(tp_size, md.params_dtype, False, post_process)
         for layer_num in range(len(models[0].language_model.encoder.layers)):
+            message = {}
+
+            # Get non-parallel tensors from tp_rank 0
+            layer = models[0].language_model.encoder.layers[layer_num]
+            message["input layernorm weight"] = layer.input_layernorm.weight.data
+            message["input layernorm bias"] = layer.input_layernorm.bias.data
+            message["dense bias"] = layer.self_attention.dense.bias.data
+            message["post layernorm weight"] = layer.post_attention_layernorm.weight.data
+            message["post layernorm bias"] = layer.post_attention_layernorm.bias.data
+            message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data
+
+            # Grab all parallel tensors for this layer
             qkv_weight = []
             qkv_bias = []
             dense_weight = []
             mlp_l0_weight = []
             mlp_l0_bias = []
             mlp_l1_weight = []
-
-            # Get non-parallel tensors from tp_rank 0
-            layer = models[0].language_model.encoder.layers[layer_num]
-            input_layernorm_weight = layer.input_layernorm.weight.data
-            input_layernorm_bias = layer.input_layernorm.bias.data
-            dense_bias = layer.self_attention.dense.bias.data
-            post_layernorm_weight = layer.post_attention_layernorm.weight.data
-            post_layernorm_bias = layer.post_attention_layernorm.bias.data
-            mlp_l1_bias = layer.mlp.dense_4h_to_h.bias.data
-
-            # Grab all parallel tensors for this layer
             for tp_rank, model in enumerate(models):
                 layer = model.language_model.encoder.layers[layer_num]
                 qkv_weight.append(layer.self_attention.query_key_value.weight.data)
@@ -216,47 +219,50 @@ def _load_checkpoint(queue, args):
                 mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)
                 mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)
 
-            # send everything in order while concatenating them
-            print(f"Sending layer {layer_num} of pipeline rank {pp_rank} (total layer {total_layer_num})")
-            queue.put(input_layernorm_weight)
-            queue.put(input_layernorm_bias)
-            queue.put(torch.cat(qkv_weight, dim=0))
-            queue.put(torch.cat(qkv_bias, dim=0))
-            queue.put(torch.cat(dense_weight, dim=1))
-            queue.put(dense_bias)
-            queue.put(post_layernorm_weight)
-            queue.put(post_layernorm_bias)
-            queue.put(torch.cat(mlp_l0_weight, dim=0))
-            queue.put(torch.cat(mlp_l0_bias, dim=0))
-            queue.put(torch.cat(mlp_l1_weight, dim=1))
-            queue.put(mlp_l1_bias)
+            # concat them
+            message["qkv weight"] = torch.cat(qkv_weight, dim=0)
+            message["qkv bias"] = torch.cat(qkv_bias, dim=0)
+            message["dense weight"] = torch.cat(dense_weight, dim=1)
+            message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)
+            message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)
+            message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)
+
+            queue_put(f"transformer layer {total_layer_num}", message)
 
             total_layer_num = total_layer_num + 1
 
     # Send final layernorm from tp_rank 0
-    print("Sending final layernorm")
-    queue.put(models[0].language_model.encoder.final_layernorm.weight.data)
-    queue.put(models[0].language_model.encoder.final_layernorm.bias.data)
+    message = {
+        "weight": models[0].language_model.encoder.final_layernorm.weight.data,
+        "bias": models[0].language_model.encoder.final_layernorm.bias.data
+    }
+    queue_put("final layernorm", message)
 
     # Send BERT lm head and binary head if it exists
     if md.model_type == 'BERT':
         print("Sending LM Pooler")
-        queue.put("pooler")
-        queue.put(models[0].language_model.pooler.dense.weight.data)
-        queue.put(models[0].language_model.pooler.dense.bias.data)
+        message = {
+            "weight": models[0].language_model.pooler.dense.weight.data,
+            "bias": models[0].language_model.pooler.dense.bias.data
+        }
+        queue_put("pooler", message)
 
-        print("Sending BERT LM head")
-        queue.put("lm head")
-        queue.put(models[0].lm_head.dense.weight.data)
-        queue.put(models[0].lm_head.dense.bias.data)
-        queue.put(models[0].lm_head.layernorm.weight.data)
-        queue.put(models[0].lm_head.layernorm.bias.data)
+        message = {
+            "dense weight": models[0].lm_head.dense.weight.data,
+            "dense bias": models[0].lm_head.dense.bias.data,
+            "layernorm weight": models[0].lm_head.layernorm.weight.data,
+            "layernorm bias": models[0].lm_head.layernorm.bias.data
+        }
+        queue_put("lm head", message)
 
         if md.bert_binary_head:
             print("Sending BERT Binary head")
             queue.put("binary head")
-            queue.put(models[0].binary_head.weight.data)
-            queue.put(models[0].binary_head.bias.data)
+            message = {
+                "weight": models[0].binary_head.weight.data,
+                "bias": models[0].binary_head.bias.data
+            }
+            queue_put("binary head", message)
 
     queue.put("done")
 
 def load_checkpoint(queue, args):
...
 import argparse
+from collections.abc import Mapping
 import concurrent.futures
 import os
 import sys
@@ -38,13 +39,31 @@ def save_checkpoint(queue, args):
         print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
         exit(1)
 
-    def queue_get():
+    def queue_get(name=None):
         val = queue.get()
         if val == "exit":
             print("Loader exited, exiting saver")
             exit(1)
+        if name is not None and args.checking and val["name"] != name:
+            val_name = val["name"]
+            print(f'Unexpected message. Expecting "{name}" but got "{val_name}". Exiting saver.')
+            exit(1)
+        if name is not None:
+            print(f"received {name}")
         return val
 
+    def check_message(msg):
+        if not args.checking:
+            return
+        msg_name = msg.pop("name")
+        if len(msg.keys()) > 0:
+            print(f"Unexpected values in {msg_name}:")
+            for key in msg.keys():
+                print(f"   {key}")
+            print(f"Exiting. If you want to ignore this, use the argument --no-checking.")
+            exit(1)
+
     md = queue_get()
 
     if args.target_tensor_parallel_size is None:
@@ -141,8 +160,11 @@ def save_checkpoint(queue, args):
     # Embeddings
     #-----------
-    pos_embed = queue_get()
-    orig_word_embed = queue_get()
+    embeddings_msg = queue_get("embeddings")
+
+    pos_embed = embeddings_msg.pop("position embeddings")
+    orig_word_embed = embeddings_msg.pop("word embeddings")
+    check_message(embeddings_msg)
 
     # Deal with padding
     if md.true_vocab_size is not None:
@@ -185,6 +207,7 @@ def save_checkpoint(queue, args):
     # Transformer layers
     #-------------------
+    total_layer_num = 0
     for pp_rank in range(args.target_pipeline_parallel_size):
         # For later pipeline parallel ranks, make the new models
         if pp_rank > 0:
@@ -193,47 +216,47 @@ def save_checkpoint(queue, args):
         models = get_models(args.target_tensor_parallel_size, md.params_dtype, False, post_process)
         for layer in range(len(models[0].language_model.encoder.layers)):
-            # get full tensors
-            input_layernorm_weight = queue_get()
-            input_layernorm_bias = queue_get()
-            full_qkv_weight = queue_get()
-            full_qkv_bias = queue_get()
-            full_dense_weight = queue_get()
-            dense_bias = queue_get()
-            post_layernorm_weight = queue_get()
-            post_layernorm_bias = queue_get()
-            full_mlp_l0_weight = queue_get()
-            full_mlp_l0_bias = queue_get()
-            full_mlp_l1_weight = queue_get()
-            mlp_l1_bias = queue_get()
+            msg = queue_get(f"transformer layer {total_layer_num}")
+
+            # duplicated tensors
+            input_layernorm_weight = msg.pop("input layernorm weight")
+            input_layernorm_bias = msg.pop("input layernorm bias")
+            dense_bias = msg.pop("dense bias")
+            post_layernorm_weight = msg.pop("post layernorm weight")
+            post_layernorm_bias = msg.pop("post layernorm bias")
+            mlp_l1_bias = msg.pop("mlp l1 bias")
 
             # Split up the parallel tensors
-            out_qkv_weight = torch.chunk(full_qkv_weight, args.target_tensor_parallel_size, dim=0)
-            out_qkv_bias = torch.chunk(full_qkv_bias, args.target_tensor_parallel_size, dim=0)
-            out_dense_weight = torch.chunk(full_dense_weight, args.target_tensor_parallel_size, dim=1)
-            out_mlp_l0_weight = torch.chunk(full_mlp_l0_weight, args.target_tensor_parallel_size, dim=0)
-            out_mlp_l0_bias = torch.chunk(full_mlp_l0_bias, args.target_tensor_parallel_size, dim=0)
-            out_mlp_l1_weight = torch.chunk(full_mlp_l1_weight, args.target_tensor_parallel_size, dim=1)
+            qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0)
+            qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
+            dense_weight = torch.chunk(msg.pop("dense weight"), args.target_tensor_parallel_size, dim=1)
+            mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0)
+            mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0)
+            mlp_l1_weight = torch.chunk(msg.pop("mlp l1 weight"), args.target_tensor_parallel_size, dim=1)
 
             # Save them to the model
             for tp_rank in range(args.target_tensor_parallel_size):
                 l = models[tp_rank].language_model.encoder.layers[layer]
                 l.input_layernorm.weight.data.copy_(input_layernorm_weight)
                 l.input_layernorm.bias.data.copy_(input_layernorm_bias)
-                l.self_attention.query_key_value.weight.data.copy_(out_qkv_weight[tp_rank])
-                l.self_attention.query_key_value.bias.data.copy_(out_qkv_bias[tp_rank])
-                l.self_attention.dense.weight.data.copy_(out_dense_weight[tp_rank])
+                l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
+                l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
+                l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank])
                 l.self_attention.dense.bias.data.copy_(dense_bias)
                 l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight)
                 l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias)
-                l.mlp.dense_h_to_4h.weight.data.copy_(out_mlp_l0_weight[tp_rank])
-                l.mlp.dense_h_to_4h.bias.data.copy_(out_mlp_l0_bias[tp_rank])
-                l.mlp.dense_4h_to_h.weight.data.copy_(out_mlp_l1_weight[tp_rank])
+                l.mlp.dense_h_to_4h.weight.data.copy_(mlp_l0_weight[tp_rank])
+                l.mlp.dense_h_to_4h.bias.data.copy_(mlp_l0_bias[tp_rank])
+                l.mlp.dense_4h_to_h.weight.data.copy_(mlp_l1_weight[tp_rank])
                 l.mlp.dense_4h_to_h.bias.data.copy_(mlp_l1_bias)
 
+            total_layer_num = total_layer_num + 1
+            check_message(msg)
+
         if post_process:
-            final_layernorm_weight = queue_get()
-            final_layernorm_bias = queue_get()
+            msg = queue_get("final layernorm")
+            final_layernorm_weight = msg.pop("weight")
+            final_layernorm_bias = msg.pop("bias")
             for tp_rank in range(args.target_tensor_parallel_size):
                 models[tp_rank].language_model.encoder.final_layernorm.weight.data.copy_(final_layernorm_weight)
                 models[tp_rank].language_model.encoder.final_layernorm.bias.data.copy_(final_layernorm_bias)
@@ -242,49 +265,56 @@ def save_checkpoint(queue, args):
                     models[tp_rank].word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
             del final_layernorm_weight
             del final_layernorm_bias
+            check_message(msg)
 
-            name = queue_get()
-            if name == "pooler":
+            msg = queue_get()
+            if msg != "done" and msg["name"] == "pooler":
                 if not hasattr(models[0].language_model, 'pooler'):
                     print("ERROR: got a pooler, but model does not have one")
                     exit(1)
-                pooler_weight = queue_get()
-                pooler_bias = queue_get()
+                print("received pooler")
+                pooler_weight = msg.pop("weight")
+                pooler_bias = msg.pop("bias")
                 for tp_rank in range(args.target_tensor_parallel_size):
                     models[tp_rank].language_model.pooler.dense.weight.data.copy_(pooler_weight)
                     models[tp_rank].language_model.pooler.dense.bias.data.copy_(pooler_bias)
-                name = queue_get()
                 del pooler_weight
                 del pooler_bias
+                check_message(msg)
+                msg = queue_get()
 
-            if name == "lm head":
+            if msg != "done" and msg["name"] == "lm head":
                 if not hasattr(models[0], 'lm_head'):
                     print("ERROR: got an lm head, but model does not have one")
                     exit(1)
-                lm_head_dense_weight = queue_get()
-                lm_head_dense_bias = queue_get()
-                lm_head_layernorm_weight = queue_get()
-                lm_head_layernorm_bias = queue_get()
+                print("received lm head")
+                lm_head_dense_weight = msg.pop("dense weight")
+                lm_head_dense_bias = msg.pop("dense bias")
+                lm_head_layernorm_weight = msg.pop("layernorm weight")
+                lm_head_layernorm_bias = msg.pop("layernorm bias")
                 for tp_rank in range(args.target_tensor_parallel_size):
                     models[tp_rank].lm_head.dense.weight.data.copy_(lm_head_dense_weight)
                     models[tp_rank].lm_head.dense.bias.data.copy_(lm_head_dense_bias)
                     models[tp_rank].lm_head.layernorm.weight.data.copy_(lm_head_layernorm_weight)
                     models[tp_rank].lm_head.layernorm.bias.data.copy_(lm_head_layernorm_bias)
-                name = queue_get()
+                check_message(msg)
+                msg = queue_get()
 
-            if name == "binary head":
+            if msg != "done" and msg["name"] == "binary head":
                 if not hasattr(models[0], 'binary_head'):
                     print("ERROR: got a binary head, but model does not have one")
                     exit(1)
-                binary_head_weight = queue_get()
-                binary_head_bias = queue_get()
+                print("received binary head")
+                binary_head_weight = msg.pop("weight")
+                binary_head_bias = msg.pop("bias")
                 for tp_rank in range(args.target_tensor_parallel_size):
                     models[tp_rank].binary_head.weight.data.copy_(binary_head_weight)
                     models[tp_rank].binary_head.bias.data.copy_(binary_head_bias)
-                name = queue_get()
+                check_message(msg)
+                msg = queue_get()
 
-            if name != "done":
-                print("ERROR: got some more data but were expecting to be done")
+            if msg != "done":
+                print("ERROR: got some more data but was expecting to be done")
 
         for tp_rank in range(args.target_tensor_parallel_size):
             mpu.initialize.set_tensor_model_parallel_rank(tp_rank)
...

@@ -12,10 +12,12 @@ import sys
 # load_checkpoint
 # The loader and saver process are each given a queue, the loader
-# should load the checkpoint and send the weights in the following
-# order, the saver should receive them in this order and save the
-# checkpoints. Note that the weight sent over the queue are the full
-# model weights, nothing split.
+# should load the checkpoint and send the weights in messages in the
+# following order, the saver should receive them in this order and
+# save the checkpoints. A message consists of a python dictionary with
+# a "name" for error checking and an entry for each tensor as
+# indicated below. Note that the weight sent over the queue are the
+# full model weights, nothing split.
 # If the loader ever sends "exit" to the queue, that means something
 # went wrong and it is exiting.
@@ -37,35 +39,51 @@ import sys
 # make_vocab_size_divisble_by
 # consumed_train_samples
 # consumed_valid_samples
-# - Position embeddings
-# - Word embeddings
-# - For each transformer layer:
-# - input layernorm weights
-# - input layernorm bias
-# - qkv weight
-# - qkv bias
-# - dense weight
-# - dense bias
-# - post attention layernorm weight
-# - post attention layernorm bias
-# - mlp layer 0 (h to 4h) weight
-# - mlp layer 0 (h to 4h) bias
-# - mlp layer 1 (4h to h) weight
-# - mlp layer 1 (4h to h) bias
-# - final layer norm weight
-# - final layer norm bias
-# - if present (i.e. for BERT):
-# - "pooler"
-# - LM Pooler weight
-# - LM Pooler bias
-# - "lm head"
-# - LM head dense weight
-# - LM head dense bias
-# - LM head layernorm weight
-# - LM head layernorm bias
-# - "binary head"
-# - BERT Binary head weight
-# - BERT Binary head bias
+# messages
+# {
+#   "name": "embeddings"
+#   "position embeddings"
+#   "word embeddings"
+# }
+# (for each transformer layer):
+# {
+#   "name": "transformer layer N"
+#   "input layernorm weight"
+#   "input layernorm bias"
+#   "qkv weight"
+#   "qkv bias"
+#   "dense weight"
+#   "dense bias"
+#   "post layernorm weight"
+#   "post layernorm bias"
+#   "mlp l0 weight"
+#   "mlp l0 bias"
+#   "mlp l1 weight"
+#   "mlp l1 bias"
+# }
+# {
+#   "name": "final layer norm"
+#   "weight"
+#   "bias"
+# }
+# if present (i.e. for BERT):
+# {
+#   "name": "pooler"
+#   "weight"
+#   "bias"
+# }
+# {
+#   "name": "lm head"
+#   "dense weight"
+#   "dense bias"
+#   "layernorm weight"
+#   "layernorm bias"
+# }
+# {
+#   "name": "binary head"
+#   "weight"
+#   "bias"
+# }
 # - "done"
 
 def load_plugin(plugin_type, name):

@@ -103,6 +121,9 @@ def main():
                         help='Directory to save model checkpoint to')
     parser.add_argument('--max-queue-size', type=int, default=50,
                         help='Maximum number of tensors in the queue')
+    parser.add_argument('--no-checking', action='store_false',
+                        help='Do not perform checking on the name and ordering of weights',
+                        dest='checking')
 
     known_args, _ = parser.parse_known_args()
     loader = load_plugin('loader', known_args.loader)

...
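One note on the new flag (editor's illustration, not part of the commit): because --no-checking is declared with action='store_false' and dest='checking', checking is enabled by default and the flag turns it off. The sketch below shows that default and the leftover-key check the saver applies after consuming a message, mirroring the argparse and check_message code in the diff above.

# Illustrative sketch of the default-on checking flag and the leftover-key check.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-checking', action='store_false', dest='checking',
                    help='Do not perform checking on the name and ordering of weights')

assert parser.parse_args([]).checking is True            # default: checking enabled
assert parser.parse_args(['--no-checking']).checking is False

args = parser.parse_args([])

def check_message(msg):
    # After the saver pops every tensor it expects, anything still left in
    # the dict (besides the "name" tag) means the loader and saver disagree
    # about the message contents.
    if not args.checking:
        return
    msg_name = msg.pop("name")
    if len(msg) > 0:
        print(f"Unexpected values in {msg_name}:")
        for key in msg:
            print(f"   {key}")
        raise SystemExit(1)

check_message({"name": "embeddings"})   # ok: nothing unexpected left over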