Commit 63c300ba authored by wxj

Update loader_llama_mistral.py

parent be4dda7b
Pipeline #2651 passed
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import json
import os
import sys
import torch
try:
import transformers
except ImportError:
raise ImportError("The 'transformers' package is not installed.")
import gc
import shutil
from tqdm import tqdm
import types
def add_arguments(parser):
group = parser.add_argument_group(title='Llama/Mistral loader.')
# TODO(jbarker): Need assertion to make sure *exactly* one of these is used
parser.add_argument('--model-size', type=str, required=True,
choices=['llama2-7B', 'llama2-13B', 'llama2-70B', 'llama2-7Bf', 'llama2-13Bf', 'llama2-70Bf', 'llama3', 'mistral', 'yi-34B', 'qwen2.5'],
help='Select model size/type')
parser.add_argument('--checkpoint-type', type=str, required=True,
choices=['meta', 'hf'],
help='Type of checkpoint to convert, options are "meta" or "hf"')
parser.add_argument('--bf16', action='store_true', help='Whether to load weights in bf16.')
parser.add_argument('--fp16', action='store_true', help='Whether to load weights in fp16.')
group.add_argument('--true-vocab-size', type=int, default=None,
                   help='Original size of the vocabulary; if specified, trims padding from the embedding table.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file. If specified will use this to get vocab size and '
'trim padding from the embedding table.')
group.add_argument('--tokenizer-model', required=True,
help='Tokenizer model file.')
group.add_argument('--megatron-path', type=str, default=None,
help='Base directory of Megatron repository')
group.add_argument("--make-vocab-size-divisible-by", type=int, default=None, help="Make vocab size divisible by")
group.add_argument('--loader-transformer-impl', default='local',
choices=['local', 'transformer_engine'],
help='Which Transformer implementation to use.')
def verify_transformers_version():
    # Compare (major, minor) as a tuple so that e.g. 5.0 also passes, and
    # ignore any suffix such as ".dev0" on the patch component.
    major, minor = map(int, transformers.__version__.split('.')[:2])
    assert (major, minor) >= (4, 31), \
        f"transformers >= 4.31 is required, found {transformers.__version__}"
NUM_SHARDS = {
"llama2-7B": 1,
"llama2-7Bf": 1,
"llama2-13B": 2,
"llama2-13Bf": 2,
"llama2-70B": 8,
"llama2-70Bf": 8,
}
def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
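# Worked example: for Llama-2 7B (dim=4096) with the defaults
# ffn_dim_multiplier=1 and multiple_of=256, int(8 * 4096 / 3) = 10922, which
# rounds up to the next multiple of 256, giving the published
# intermediate_size of 11008.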
def read_json(path):
with open(path, "r") as f:
return json.load(f)
def write_json(text, path):
with open(path, "w") as f:
json.dump(text, f)
# This conversion is adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
def convert_to_hf(model_path, input_base_path, model_size, tokenizer_path):
if "llama2" in model_size:
from transformers import LlamaConfig as ModelConfig
from transformers import LlamaTokenizer, LlamaTokenizerFast
else:
raise NotImplementedError(f"converting {model_size} is only supported using HuggingFace weights")
    # For backward compatibility: previously the repo needed to be called `my_repo/model_size`.
if not os.path.isfile(os.path.join(input_base_path, "params.json")):
input_base_path = os.path.join(input_base_path, model_size)
os.makedirs(model_path, exist_ok=True)
params = read_json(os.path.join(input_base_path, "params.json"))
num_shards = NUM_SHARDS[model_size]
params = params.get("model", params)
n_layers = params["n_layers"]
n_heads = params["n_heads"]
n_heads_per_shard = n_heads // num_shards
dim = params["dim"]
dims_per_head = dim // n_heads
base = params.get("rope_theta", 10000.0)
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
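    # Standard RoPE frequency schedule: inv_freq[i] = base ** (-2i / head_dim),
    # i.e. geometrically spaced rotation frequencies per pair of channels.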
if base > 10000.0:
max_position_embeddings = 32768 if "mistral" in model_size else 16384
else:
max_position_embeddings = 4096
if "llama2" in model_size:
tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
else:
raise AttributeError(f"model_size={model_size} not supported")
    if tokenizer_path is not None:
        if "llama2" in model_size:
            tokenizer = tokenizer_class(tokenizer_path)
            tokenizer.save_pretrained(model_path)
            vocab_size = tokenizer.vocab_size
        else:
            raise AttributeError(f"model_size={model_size} is not supported")
    else:
        # Fall back to the Llama-2 default vocabulary size.
        vocab_size = 32000
if params.get("n_kv_heads", None) is not None:
num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
key_value_dim = dim // num_key_value_heads
else: # compatibility with other checkpoints
num_key_value_heads = n_heads
num_local_key_value_heads = n_heads_per_shard
key_value_dim = dim
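    # With grouped-query attention (e.g. Llama-2 70B: n_heads=64, n_kv_heads=8),
    # each K/V head is shared by n_heads // n_kv_heads query heads, so the K/V
    # projections are correspondingly narrower than the Q projection.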
# permute for sliced rotary
def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
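    # This converts Meta's interleaved rotary layout to HF's half-split layout:
    # within each head, rows (0, 1, 2, 3, 4, 5, 6, 7) are reordered to
    # (0, 2, 4, 6, 1, 3, 5, 7), so that HF's rotate-half RoPE implementation
    # reproduces Meta's pairwise rotation.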
print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
# Load weights
if num_shards == 1:
# Not sharded
# (The sharded implementation would also work, but this is simpler.)
loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
else:
# Sharded
loaded = [
torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
for i in range(num_shards)
]
param_count = 0
index_dict = {"weight_map": {}}
for layer_i in range(n_layers):
filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
if num_shards == 1:
# Unsharded
q_proj = loaded[f"layers.{layer_i}.attention.wq.weight"]
k_proj = loaded[f"layers.{layer_i}.attention.wk.weight"]
if ("llama2" in model_size) or ("mistral" in model_size):
q_proj = permute(q_proj)
k_proj = permute(k_proj)
state_dict = {
f"model.layers.{layer_i}.self_attn.q_proj.weight": q_proj,
f"model.layers.{layer_i}.self_attn.k_proj.weight": k_proj,
f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"],
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"],
}
else:
# Sharded
            # Note that attention.w{q,k,v,o}, feed_forward.w[1,2,3], attention_norm.weight
            # and ffn_norm.weight share the same storage object, so saving attention_norm
            # and ffn_norm would save the other weights too. That would be redundant, since
            # the other weights are stitched together from multiple shards, so they are cloned.
state_dict = {
f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][
f"layers.{layer_i}.attention_norm.weight"
].clone(),
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][
f"layers.{layer_i}.ffn_norm.weight"
].clone(),
}
state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
torch.cat(
[
loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
for i in range(num_shards)
],
dim=0,
).reshape(dim, dim)
)
state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
torch.cat(
[
loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(
num_local_key_value_heads, dims_per_head, dim
)
for i in range(num_shards)
],
dim=0,
).reshape(key_value_dim, dim),
num_key_value_heads,
key_value_dim,
dim,
)
state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
[
loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(
num_local_key_value_heads, dims_per_head, dim
)
for i in range(num_shards)
],
dim=0,
).reshape(key_value_dim, dim)
state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
)
state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
)
state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
)
state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
)
state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
for k, v in state_dict.items():
index_dict["weight_map"][k] = filename
param_count += v.numel()
torch.save(state_dict, os.path.join(model_path, filename))
filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
if num_shards == 1:
# Unsharded
state_dict = {
"model.embed_tokens.weight": loaded["tok_embeddings.weight"],
"model.norm.weight": loaded["norm.weight"],
"lm_head.weight": loaded["output.weight"],
}
else:
d = 0 if "llama3" in model_size else 1
state_dict = {
"model.norm.weight": loaded[0]["norm.weight"],
"model.embed_tokens.weight": torch.cat(
[loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=d
),
"lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
}
for k, v in state_dict.items():
index_dict["weight_map"][k] = filename
param_count += v.numel()
torch.save(state_dict, os.path.join(model_path, filename))
# Write configs
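    # total_size assumes 2 bytes per parameter (fp16/bf16 weights).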
index_dict["metadata"] = {"total_size": param_count * 2}
write_json(index_dict, os.path.join(model_path, "pytorch_model.bin.index.json"))
ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1
multiple_of = params["multiple_of"] if "multiple_of" in params else 256
config = ModelConfig(
hidden_size=dim,
intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
num_attention_heads=params["n_heads"],
num_hidden_layers=params["n_layers"],
rms_norm_eps=params["norm_eps"],
num_key_value_heads=num_key_value_heads,
vocab_size=vocab_size,
rope_theta=base,
max_position_embeddings=max_position_embeddings,
)
config.save_pretrained(model_path)
# Make space so we can load the model properly now.
del state_dict
del loaded
gc.collect()
return model_path
def load_args_from_checkpoint(args, model_size):
# Read Llama args.
model_args_path = os.path.join(args.load, "config.json")
with open(model_args_path) as f:
model_args = json.load(f)
# Update Megatron args.
args.seq_length = 4096
if "llama2" in model_size:
# Correct bug in earlier conversion script.
args.max_position_embeddings = 4096
else:
args.max_position_embeddings = model_args["max_position_embeddings"]
args.hidden_size = model_args["hidden_size"]
args.num_attention_heads = model_args["num_attention_heads"]
args.num_layers = model_args["num_hidden_layers"]
args.global_batch_size = 1024
args.norm_epsilon = model_args["rms_norm_eps"]
args.iteration = 1 # '0', 'release' don't work
args.position_embedding_type = "rope"
args.swiglu = True
args.normalization = "RMSNorm"
args.add_bias_linear = False
args.untie_embeddings_and_output_weights = not model_args.get("tie_word_embeddings", False)
args.vocab_size = model_args["vocab_size"]
args.padded_vocab_size = model_args["vocab_size"]
args.ffn_hidden_size = model_args["intermediate_size"]
if "num_key_value_heads" in model_args:
args.group_query_attention = True
args.num_query_groups = model_args["num_key_value_heads"]
def set_preprocess_state(args, model, hf_model):
'''Set embedding params.'''
model.language_model.embedding.word_embeddings.weight.data.copy_(
hf_model.model.embed_tokens.weight)
def set_postprocess_state(args, model, hf_model):
'''Set output layer & norm params.'''
model.language_model.encoder.final_norm.weight.data.copy_(hf_model.model.norm.weight)
if args.untie_embeddings_and_output_weights:
model.language_model.output_layer.weight.data.copy_(hf_model.lm_head.weight)
def set_attn_state(args, layer, hf_layer):
'''Set self-attention params.'''
# Get attention layer & state.
attn = layer.self_attention
hf_attn = hf_layer.self_attn
# Reshape loaded weights.
tp = args.tensor_model_parallel_size
nh = args.num_attention_heads // tp
ng = (args.num_query_groups if args.group_query_attention \
else args.num_attention_heads) // tp
dim = args.kv_channels
assert nh % ng == 0
# Copy weights (re-order dimensions for Megatron).
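    # Megatron's fused QKV layout groups tensors per query group: for each of
    # the ng groups, the group's nh // ng query heads come first, followed by
    # that group's single key head and single value head.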
attn.query_key_value.weight.data.copy_(torch.cat([
hf_attn.q_proj.weight.reshape((ng, dim*nh//ng, -1)),
hf_attn.k_proj.weight.reshape((ng, dim, -1)),
hf_attn.v_proj.weight.reshape((ng, dim, -1)),
], dim=1).reshape((-1, args.hidden_size)))
if args.add_qkv_bias:
attn.query_key_value.bias.data.copy_(torch.cat([
hf_attn.q_proj.bias.reshape((ng, dim*nh//ng)),
hf_attn.k_proj.bias.reshape((ng, dim)),
hf_attn.v_proj.bias.reshape((ng, dim)),
], dim=1).reshape(-1))
attn.dense.weight.data.copy_(hf_attn.o_proj.weight)
def set_mlp_state(args, layer, hf_layer):
'''Set MLP params.'''
mlp = layer.mlp
hf_mlp = hf_layer.mlp
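    # Megatron's fused dense_h_to_4h stores SwiGLU's two projections stacked
    # as [gate_proj; up_proj]; they are split apart again with torch.chunk
    # (see the swiglu handling in _load_checkpoint below).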
mlp.dense_h_to_4h.weight.data.copy_(torch.cat([
hf_mlp.gate_proj.weight,
hf_mlp.up_proj.weight,
], dim=0))
mlp.dense_4h_to_h.weight.data.copy_(hf_mlp.down_proj.weight)
def set_layer_state(args, model, hf_model, layer_idx):
'''Set transformer layer params.'''
layer = model.language_model.encoder.layers[layer_idx]
hf_layer = hf_model.model.layers[layer_idx]
set_attn_state(args, layer, hf_layer)
set_mlp_state(args, layer, hf_layer)
layer.input_norm.weight.data.copy_(hf_layer.input_layernorm.weight)
layer.post_attention_norm.weight.data.copy_(hf_layer.post_attention_layernorm.weight)
def load_checkpoint_to_model(args):
'''Set model params.'''
from pretrain_gpt import model_provider
from transformers import AutoModelForCausalLM
# Load Huggingface model.
hf_model = AutoModelForCausalLM.from_pretrained(args.load, torch_dtype=args.params_dtype, low_cpu_mem_usage=True, device_map="cpu")
# Init Megatron model.
model = model_provider(True, True).to(args.params_dtype)
# Set model state.
set_preprocess_state(args, model, hf_model)
set_postprocess_state(args, model, hf_model)
for layer_idx in tqdm(range(args.num_layers), "set layer states"):
set_layer_state(args, model, hf_model, layer_idx)
return model
def _load_checkpoint(queue, args):
verify_transformers_version()
    # Add the directory two levels above this file to sys.path (to find Megatron).
sys.path.append(os.path.abspath(
os.path.join(os.path.dirname(__file__),
os.path.pardir,
os.path.pardir)))
if args.megatron_path is not None:
sys.path.insert(0, args.megatron_path)
# Convert Meta checkpoint to HF format as an intermediate step
if args.checkpoint_type == "meta":
model_tmp_path = convert_to_hf(model_path=os.path.join(args.save_dir, 'tmp'), input_base_path=args.load_dir, model_size=args.model_size, tokenizer_path=args.tokenizer_model)
args.load_dir = model_tmp_path
args.tokenizer_model = model_tmp_path # point to HF tokenizer model
try:
from megatron.training.arguments import parse_args, validate_args
from megatron.training.global_vars import set_args, set_global_variables
from megatron.legacy.model import module
from megatron.core import mpu
from megatron.core.enums import ModelType
from megatron.legacy import fused_kernels
except ModuleNotFoundError:
print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
queue.put("exit")
exit(1)
# We want all arguments to come from us.
sys.argv = ['script.py',
'--no-masked-softmax-fusion',
'--no-bias-gelu-fusion',
'--no-bias-dropout-fusion',
'--no-async-tensor-model-parallel-allreduce',
'--use-cpu-initialization',
'--micro-batch-size', '1',
'--no-load-optim',
'--no-load-rng',
'--no-save-optim',
'--no-save-rng',
'--mock-data', # To pass the "blend data checks" in arguments.py
'--no-initialization',
'--load', args.load_dir,
'--no-one-logger',
]
if args.make_vocab_size_divisible_by is not None:
sys.argv.extend(["--make-vocab-size-divisible-by", str(args.make_vocab_size_divisible_by)])
margs = parse_args()
margs.tokenizer_model = args.tokenizer_model
load_args_from_checkpoint(margs, args.model_size)
if "llama2" in args.model_size:
margs.tokenizer_type = "Llama2Tokenizer"
elif "yi" in args.model_size:
margs.tokenizer_type = "HuggingFaceTokenizer"
elif "llama3" in args.model_size:
margs.tokenizer_type = "HuggingFaceTokenizer"
elif "mistral" in args.model_size:
margs.tokenizer_type = "HuggingFaceTokenizer"
elif "qwen2.5" in args.model_size:
margs.tokenizer_type = "HuggingFaceTokenizer"
margs.add_qkv_bias = True
    # Argument validation sanity-checks the world size, which does not matter
    # here, so trick it into thinking the expected number of processes is running.
margs.world_size = margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size
margs = validate_args(margs)
margs.use_legacy_models = True
margs.transformer_impl = args.loader_transformer_impl
margs.position_embedding_type = "rope"
def check_for_arg(arg_name, default=None):
if getattr(margs, arg_name, None) is None:
if default is not None:
setattr(margs, arg_name, default)
else:
print(f"Checkpoint does not specify the argument {arg_name}. Exiting.")
print(f"Arguments: {margs}")
queue.put("exit")
exit(1)
check_for_arg('tensor_model_parallel_size')
check_for_arg('pipeline_model_parallel_size')
check_for_arg('num_layers')
check_for_arg('hidden_size')
check_for_arg('seq_length')
check_for_arg('num_attention_heads')
check_for_arg('max_position_embeddings')
check_for_arg('position_embedding_type')
check_for_arg('iteration')
check_for_arg('bert_binary_head')
check_for_arg('disable_bias_linear', False)
check_for_arg('params_dtype')
check_for_arg('swiglu', False)
# Determine how to make our models.
assert args.model_type == 'GPT', 'Llama-2, Llama-3 and Mistral are GPT models.'
margs.model_type = ModelType.encoder_or_decoder
margs.params_dtype = torch.bfloat16 if args.bf16 else torch.float16 if args.fp16 else torch.float32
# Suppress warning about torch.distributed not being initialized.
module.MegatronModule.embedding_warning_printed = True
set_global_variables(margs, build_tokenizer=False)
mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
mpu.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
mpu.set_virtual_pipeline_model_parallel_world_size(margs.virtual_pipeline_model_parallel_size)
# fused_kernels.load(margs)
# Short aliases.
tp_size = margs.tensor_model_parallel_size
pp_size = margs.pipeline_model_parallel_size
vp_size = margs.virtual_pipeline_model_parallel_size
if vp_size is None:
vp_size = 1
# Metadata.
md = types.SimpleNamespace()
md.model_type = args.model_type
md.num_layers = margs.num_layers
md.hidden_size = margs.hidden_size
md.seq_length = margs.seq_length
md.num_attention_heads = margs.num_attention_heads
md.max_position_embeddings = margs.max_position_embeddings
md.tokenizer_type = margs.tokenizer_type
md.iteration = margs.iteration
md.params_dtype = margs.params_dtype
md.bert_binary_head = margs.bert_binary_head
md.output_layer = margs.untie_embeddings_and_output_weights
md.position_embedding_type = margs.position_embedding_type
md.linear_bias = margs.add_bias_linear
md.qkv_bias = margs.add_qkv_bias
md.norm_has_bias = False
md.swiglu = margs.swiglu
md.previous_tensor_parallel_size = margs.tensor_model_parallel_size
md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size
md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by
md.checkpoint_args = margs
md.consumed_train_samples = 0
md.consumed_valid_samples = 0
margs.model_size = args.model_size
# Get true (non-padded) vocab size
tokenizer = transformers.AutoTokenizer.from_pretrained(margs.tokenizer_model)
md.true_vocab_size = tokenizer._tokenizer.get_vocab_size(with_added_tokens=True)
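    # Note: `_tokenizer` is the private Rust-backed tokenizer of an HF "fast"
    # tokenizer; get_vocab_size(with_added_tokens=True) includes added special
    # tokens, so this path assumes a fast tokenizer is available.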
# Get first pipe stage.
mpu.set_tensor_model_parallel_rank(0)
mpu.set_pipeline_model_parallel_rank(0)
model = load_checkpoint_to_model(margs)
queue.put(md)
def queue_put(name, msg):
print(f"sending {name}")
msg["name"] = name
queue.put(msg)
# Send embeddings.
message = {
"word embeddings": model.language_model.embedding.word_embeddings.weight.data
}
if md.position_embedding_type == 'learned_absolute':
message["position embeddings"] = model.language_model.embedding.position_embeddings.weight.data
else:
assert not hasattr(model.language_model.embedding, 'position_embeddings')
queue_put("embeddings", message)
for layer_num in range(margs.num_layers):
message = {}
# Get non-parallel tensors from tp_rank 0.
layer = model.language_model.encoder.layers[layer_num]
message["input norm weight"] = layer.input_norm.weight.data
message["post norm weight"] = layer.post_attention_norm.weight.data
if md.linear_bias:
message["dense bias"] = layer.self_attention.dense.bias.data
message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data
# Grab all parallel tensors for this layer.
qkv_weight = []
qkv_bias = []
dense_weight = []
mlp_l0_weight = []
mlp_l0_bias = []
mlp_l1_weight = []
layer = model.language_model.encoder.layers[layer_num]
qkv_weight.append(layer.self_attention.query_key_value.weight.data)
dense_weight.append(layer.self_attention.dense.weight.data)
mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data)
mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)
if md.qkv_bias:
qkv_bias.append(layer.self_attention.query_key_value.bias.data)
if md.linear_bias:
mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)
# Handle gated linear units.
if md.swiglu:
# Concat all the first halves ('W's) and all the second halves ('V's).
for tp_rank in range(tp_size):
mlp_l0_weight[tp_rank] = torch.chunk(mlp_l0_weight[tp_rank], 2, dim=0)
message["mlp l0 weight W"] = torch.cat([w[0] for w in mlp_l0_weight], dim=0)
message["mlp l0 weight V"] = torch.cat([w[1] for w in mlp_l0_weight], dim=0)
else:
message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)
# Simple concat of the rest.
message["qkv weight"] = torch.cat(qkv_weight, dim=0)
message["dense weight"] = torch.cat(dense_weight, dim=1)
message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)
if md.qkv_bias:
message["qkv bias"] = torch.cat(qkv_bias, dim=0)
if md.linear_bias:
if md.swiglu:
for tp_rank in range(tp_size):
mlp_l0_bias[tp_rank] = torch.chunk(mlp_l0_bias[tp_rank], 2, dim=0)
message["mlp l0 bias W"] = torch.cat([b[0] for b in mlp_l0_bias],dim=0)
message["mlp l0 bias V"] = torch.cat([b[1] for b in mlp_l0_bias],dim=0)
else:
message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)
queue_put(f"transformer layer {layer_num}", message)
# Send final norm from tp_rank 0.
message = {
"weight": model.language_model.encoder.final_norm.weight.data,
}
queue_put("final norm", message)
if md.output_layer:
message = {
"weight": model.language_model.output_layer.weight.data
}
queue_put("output layer", message)
queue.put("done")
if args.checkpoint_type == "meta":
        # Clean up the temporary HF-format checkpoint created from the Meta weights.
        shutil.rmtree(args.load_dir)
def load_checkpoint(queue, args):
try:
_load_checkpoint(queue, args)
except Exception:
queue.put("exit")
raise
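# Typical invocation (a sketch; flag names follow this file's add_arguments and
# Megatron-LM's tools/checkpoint/convert.py, and the paths are placeholders):
#
#   python tools/checkpoint/convert.py \
#       --model-type GPT \
#       --loader llama_mistral \
#       --saver megatron \
#       --model-size llama2-7B \
#       --checkpoint-type hf \
#       --load-dir <hf_checkpoint_dir> \
#       --save-dir <megatron_checkpoint_dir> \
#       --tokenizer-model <tokenizer.model>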