Unverified commit 7fcd3e6a authored by Xuechen Li, committed by GitHub

map custom model state_dict back to huggingface format (#465)

* fix name.

* set inv function.

* add map back function.

* handle gqa.

* add type annotation to avoid confusion.

* fix docstr.

* test inverse remap logic.
parent f1a73d07
@@ -785,8 +785,10 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):


 def combine_state_dicts_tp(state_dicts, config):
-    """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
-    with tensor parallel.
+    """Convert the state_dict of a GPT model with tensor parallel to the state_dict of a
+    standard GPT model.
+
+    This function is meant to be the "reverse" of shard_state_dict_tp.
     """
     world_size = len(state_dicts)
     keys = state_dicts[0].keys()
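For reference, a minimal sketch of the shard-then-combine round trip the corrected docstring describes. It assumes these helpers are imported from flash_attn.models.gpt, and that state_dict and config (hypothetical names here) hold a full, non-sharded GPT state_dict and its matching config:

import torch

from flash_attn.models.gpt import combine_state_dicts_tp, shard_state_dict_tp

world_size = 2
# Shard a fresh copy per rank, in case the helper modifies its input in place.
shards = [
    shard_state_dict_tp({k: v.clone() for k, v in state_dict.items()}, config, world_size, rank)
    for rank in range(world_size)
]
merged = combine_state_dicts_tp(shards, config)
# The round trip is expected to reproduce the input when the sharded dimensions
# (e.g. the vocab size) already divide evenly across ranks.
assert all(torch.equal(merged[k], state_dict[k]) for k in state_dict)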
@@ -13,7 +13,14 @@ import torch.nn.functional as F
 from transformers import GPT2Config, LlamaConfig


-def remap_state_dict_meta_llama(state_dict, config):
+def remap_state_dict_meta_llama(
+    state_dict: dict[str, torch.Tensor], config: GPT2Config
+) -> dict[str, torch.Tensor]:
+    """Convert the state_dict in Meta format to standard GPT format.
+
+    This function modifies state_dict in place.
+    """
+
     def key_mapping_layers(key):
         return f"transformer.{key}" if not key.startswith("output.") else key
@@ -97,7 +104,13 @@ def remap_state_dict_meta_llama(state_dict, config):
     return state_dict


-def remap_state_dict_hf_llama(state_dict, config):
+def remap_state_dict_hf_llama(
+    state_dict: dict[str, torch.Tensor], config: GPT2Config
+) -> dict[str, torch.Tensor]:
+    """Convert the state_dict in Hugging Face format to standard GPT format.
+
+    This function modifies state_dict in place.
+    """
     # Embedding
     def key_mapping_emb(key):
         return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key)
@@ -183,6 +196,106 @@ def remap_state_dict_hf_llama(state_dict, config):
     return state_dict


+def inv_remap_state_dict_hf_llama(
+    state_dict: dict[str, torch.Tensor], config: GPT2Config
+) -> dict[str, torch.Tensor]:
+    """Convert the state_dict in standard GPT format to Hugging Face format.
+
+    This function is meant to be the inverse of remap_state_dict_hf_llama, up to a
+    multiplier pad in the embedding and lm_head. That is, if the original embedding
+    isn't a multiple of pad_vocab_size_multiple, then
+    inv_remap_state_dict_hf_llama(remap_state_dict_hf_llama(state_dict)) != state_dict.
+
+    This function modifies state_dict in place.
+    """
+
+    # Embedding
+    def key_mapping_emb(key):
+        return re.sub(r"^transformer.embeddings.word_embeddings.", "model.embed_tokens.", key)
+
+    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
+    word_embeddings = state_dict.pop("model.embed_tokens.weight")
+    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
+    vocab_size = (
+        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
+    )
+    state_dict["model.embed_tokens.weight"] = F.pad(
+        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
+    )
+
+    # LM head
+    if getattr(config, "tie_word_embeddings"):
+        state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"]
+    else:
+        output_embeddings = state_dict.pop("lm_head.weight")
+        vocab_size = (
+            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
+            * pad_vocab_size_multiple
+        )
+        state_dict["lm_head.weight"] = F.pad(
+            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
+        )
+
+    # MLP
+    for l in range(config.n_layer):
+        w3, w1 = torch.chunk(
+            state_dict.pop(f"transformer.layers.{l}.mlp.fc1.weight"), chunks=2, dim=0
+        )
+        state_dict[f"model.layers.{l}.mlp.gate_proj.weight"] = w1
+        state_dict[f"model.layers.{l}.mlp.up_proj.weight"] = w3
+
+    def key_mapping_mlp(key):
+        return re.sub(r"^transformer.layers.(\d+).mlp.fc2.", r"model.layers.\1.mlp.down_proj.", key)
+
+    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
+
+    # LayerNorm
+    def key_mapping_ln(key):
+        key = re.sub(r"^transformer.ln_f.", r"model.norm.", key)
+        key = re.sub(r"^transformer.layers.(\d+).norm1.", r"model.layers.\1.input_layernorm.", key)
+        key = re.sub(
+            r"^transformer.layers.(\d+).norm2.", r"model.layers.\1.post_attention_layernorm.", key
+        )
+        return key
+
+    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
+
+    def permute(w):
+        return (
+            w.view(config.n_head, config.n_embd // config.n_head // 2, 2, config.n_embd)
+            .transpose(1, 2)
+            .reshape(config.n_embd, config.n_embd)
+        )
+
+    n_head = config.n_head
+    n_head_kv = getattr(config, "n_head_kv", n_head)
+
+    embed_dim = config.hidden_size
+    head_dim = embed_dim // n_head
+
+    q_dim = n_head * head_dim
+    k_dim = v_dim = n_head_kv * head_dim
+
+    # Attention
+    for l in range(config.n_layer):
+        Wqkv = state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight")
+        Wq = Wqkv[:q_dim]
+        Wk = Wqkv[q_dim : q_dim + k_dim]
+        Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim]
+        state_dict[f"model.layers.{l}.self_attn.q_proj.weight"] = permute(Wq)
+        state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk)
+        state_dict[f"model.layers.{l}.self_attn.v_proj.weight"] = Wv
+        state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)
+
+    def key_mapping_attn(key):
+        return re.sub(
+            r"^transformer.layers.(\d+).mixer.out_proj.", r"model.layers.\1.self_attn.o_proj.", key
+        )
+
+    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
+    return state_dict
+
+
 def config_from_meta_checkpoint(
     checkpoint_path: Union[str, os.PathLike], model_name: str
 ) -> LlamaConfig:
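To show how the new inverse is intended to be used, here is a hedged round-trip sketch. It assumes both remap functions (and llama_config_to_gpt2_config) are importable from flash_attn.models.llama, and that hf_state_dict and llama_config are hypothetical names for a Hugging Face Llama state_dict and its LlamaConfig:

import torch

from flash_attn.models.llama import (
    inv_remap_state_dict_hf_llama,
    llama_config_to_gpt2_config,
    remap_state_dict_hf_llama,
)

config = llama_config_to_gpt2_config(llama_config)
# Remap on a shallow copy, since both functions modify their input dict in place.
gpt_state_dict = remap_state_dict_hf_llama(dict(hf_state_dict), config)
hf_again = inv_remap_state_dict_hf_llama(gpt_state_dict, config)

# Per the docstring, equality holds only up to the pad_vocab_size_multiple padding of
# the embedding and lm_head rows; with no padding the round trip should match exactly.
for k in hf_again:
    torch.testing.assert_close(hf_again[k], hf_state_dict[k])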
This diff is collapsed.