Unverified commit 42832575, authored by Kevin Hu, committed by GitHub

Fix Llama GQA/MQA (#546)

* Fix llama MQA

* Fix permute shape

* Update llama.py
parent dfe29f5e
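
For context on what the fix addresses: with grouped-query / multi-query attention (GQA/MQA), the key and value projections use fewer heads than the query projection, so the fused Wqkv weight that the remapping below builds is no longer square. The following sketch is illustrative only (the sizes are made-up GQA numbers, not read from any checkpoint); it just shows the shape bookkeeping the remapping has to respect.

import torch

# Hypothetical GQA sizes: 32 query heads sharing 4 KV heads
n_embd, n_head, n_head_kv = 2048, 32, 4
head_dim = n_embd // n_head

Wq = torch.randn(n_head * head_dim, n_embd)     # (2048, 2048)
Wk = torch.randn(n_head_kv * head_dim, n_embd)  # (256, 2048), smaller than Wq
Wv = torch.randn(n_head_kv * head_dim, n_embd)  # (256, 2048)

# Fused projection: (q_dim + k_dim + v_dim, n_embd) = (2560, 2048),
# not the square (n_embd, n_embd) per projection that the old reshape assumed.
Wqkv = torch.cat((Wq, Wk, Wv), dim=0)
print(Wqkv.shape)  # torch.Size([2560, 2048])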
llama.py
@@ -26,10 +26,13 @@ def remap_state_dict_meta_llama(
        return f"transformer.{key}" if not key.startswith("output.") else key

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(
-            r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key
+            r"^transformer.tok_embeddings.",
+            "transformer.embeddings.word_embeddings.",
+            key,
        )

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
@@ -61,7 +64,9 @@ def remap_state_dict_meta_llama(
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key)
        key = re.sub(
-            r"^transformer.layers.(\d+).attention_norm.", r"transformer.layers.\1.norm1.", key
+            r"^transformer.layers.(\d+).attention_norm.",
+            r"transformer.layers.\1.norm1.",
+            key,
        )
        key = re.sub(r"^transformer.layers.(\d+).ffn_norm.", r"transformer.layers.\1.norm2.", key)
        return key
@@ -77,7 +82,9 @@ def remap_state_dict_meta_llama(
    def key_mapping_mlp(key):
        return re.sub(
-            r"^transformer.layers.(\d+).feed_forward.w2.", r"transformer.layers.\1.mlp.fc2.", key
+            r"^transformer.layers.(\d+).feed_forward.w2.",
+            r"transformer.layers.\1.mlp.fc2.",
+            key,
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
@@ -106,12 +113,13 @@ def remap_state_dict_meta_llama(
def remap_state_dict_hf_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
+    state_dict: dict[str, torch.Tensor], config: GPT2Config, multi_query: bool = False
) -> dict[str, torch.Tensor]:
    """Convert the state_dict in Hugging Face format to standard GPT format.

    This function modifies state_dict in place.
    """
    # Embedding
    def key_mapping_emb(key):
        return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key)
@@ -153,28 +161,38 @@ def remap_state_dict_hf_llama(
        state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)

    def key_mapping_mlp(key):
-        return re.sub(r"^model.layers.(\d+).mlp.down_proj.", r"transformer.layers.\1.mlp.fc2.", key)
+        return re.sub(
+            r"^model.layers.(\d+).mlp.down_proj.",
+            r"transformer.layers.\1.mlp.fc2.",
+            key,
+        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^model.norm.", r"transformer.ln_f.", key)
-        key = re.sub(r"^model.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key)
-        key = re.sub(
-            r"^model.layers.(\d+).post_attention_layernorm.", r"transformer.layers.\1.norm2.", key
+        key = re.sub(
+            r"^model.layers.(\d+).input_layernorm.",
+            r"transformer.layers.\1.norm1.",
+            key,
+        )
+        key = re.sub(
+            r"^model.layers.(\d+).post_attention_layernorm.",
+            r"transformer.layers.\1.norm2.",
+            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

-    def inv_permute(w):
+    def inv_permute(w, first_dim=None):
        # Inverse of permute implemented in:
        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
        return (
-            w.reshape(config.n_head, 2, config.n_embd // config.n_head // 2, config.n_embd)
+            w.reshape(first_dim or config.n_head, 2, -1, config.n_embd)
            .transpose(1, 2)
-            .reshape(config.n_embd, config.n_embd)
+            .reshape(-1, config.n_embd)
        )

    # Attention
@@ -182,15 +200,19 @@ def remap_state_dict_hf_llama(
        Wq = state_dict.pop(f"model.layers.{l}.self_attn.q_proj.weight")
        Wk = state_dict.pop(f"model.layers.{l}.self_attn.k_proj.weight")
        Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat(
-            [inv_permute(Wq), inv_permute(Wk), Wv], dim=0
+            (inv_permute(Wq), inv_permute(Wk, getattr(config, "n_head_kv")), Wv),
+            dim=0,
        )
        # We don't store these
        state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None)

    def key_mapping_attn(key):
        return re.sub(
-            r"^model.layers.(\d+).self_attn.o_proj.", r"transformer.layers.\1.mixer.out_proj.", key
+            r"^model.layers.(\d+).self_attn.o_proj.",
+            r"transformer.layers.\1.mixer.out_proj.",
+            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
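
To see why inv_permute now takes a per-tensor head count, here is a standalone sketch (my own helpers mirroring the diff, not the functions defined in llama.py): the HF conversion script interleaves the two rotary halves within each head, and undoing that for the K projection must use the number of KV heads rather than n_head.

import torch

n_embd, n_head, n_head_kv = 2048, 32, 4  # hypothetical GQA sizes
head_dim = n_embd // n_head

def permute(w, first_dim):
    # HF-style interleave of the two rotary halves, done per head
    return w.view(first_dim, head_dim // 2, 2, n_embd).transpose(1, 2).reshape(-1, n_embd)

def inv_permute(w, first_dim):
    # Undo the interleave; first_dim must match the tensor's own head count
    return w.reshape(first_dim, 2, head_dim // 2, n_embd).transpose(1, 2).reshape(-1, n_embd)

Wk = torch.randn(n_head_kv * head_dim, n_embd)  # K projection has only 4 heads here
assert torch.equal(inv_permute(permute(Wk, n_head_kv), n_head_kv), Wk)
# Passing n_head (32) instead would mis-group the rows and break the round trip.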
@@ -198,7 +220,7 @@ def remap_state_dict_hf_llama(
def inv_remap_state_dict_hf_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
+    state_dict: dict[str, torch.Tensor], config: GPT2Config, multi_query: bool = False
) -> dict[str, torch.Tensor]:
    """Convert the state_dict in standard GPT format to Hugging Face format.
@@ -246,26 +268,36 @@ def inv_remap_state_dict_hf_llama(
        state_dict[f"model.layers.{l}.mlp.up_proj.weight"] = w3

    def key_mapping_mlp(key):
-        return re.sub(r"^transformer.layers.(\d+).mlp.fc2.", r"model.layers.\1.mlp.down_proj.", key)
+        return re.sub(
+            r"^transformer.layers.(\d+).mlp.fc2.",
+            r"model.layers.\1.mlp.down_proj.",
+            key,
+        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.ln_f.", r"model.norm.", key)
-        key = re.sub(r"^transformer.layers.(\d+).norm1.", r"model.layers.\1.input_layernorm.", key)
-        key = re.sub(
-            r"^transformer.layers.(\d+).norm2.", r"model.layers.\1.post_attention_layernorm.", key
+        key = re.sub(
+            r"^transformer.layers.(\d+).norm1.",
+            r"model.layers.\1.input_layernorm.",
+            key,
+        )
+        key = re.sub(
+            r"^transformer.layers.(\d+).norm2.",
+            r"model.layers.\1.post_attention_layernorm.",
+            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

-    def permute(w):
+    def permute(w, first_dim=None):
        return (
-            w.view(config.n_head, config.n_embd // config.n_head // 2, 2, config.n_embd)
+            w.view(first_dim or config.n_head, -1, 2, config.n_embd)
            .transpose(1, 2)
-            .reshape(config.n_embd, config.n_embd)
+            .reshape(-1, config.n_embd)
        )

    n_head = config.n_head
@@ -284,13 +316,15 @@ def inv_remap_state_dict_hf_llama(
        Wk = Wqkv[q_dim : q_dim + k_dim]
        Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim]
        state_dict[f"model.layers.{l}.self_attn.q_proj.weight"] = permute(Wq)
-        state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk)
+        state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk, n_head_kv)
        state_dict[f"model.layers.{l}.self_attn.v_proj.weight"] = Wv
        state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)

    def key_mapping_attn(key):
        return re.sub(
-            r"^transformer.layers.(\d+).mixer.out_proj.", r"model.layers.\1.self_attn.o_proj.", key
+            r"^transformer.layers.(\d+).mixer.out_proj.",
+            r"model.layers.\1.self_attn.o_proj.",
+            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
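
For the reverse mapping, the split of Wqkv back into Q/K/V relies on per-projection row counts. A small worked example of the arithmetic (concrete numbers are hypothetical; q_dim, k_dim, v_dim and n_head_kv are assumed to be computed like this in the elided context above):

n_embd, n_head, n_head_kv = 2048, 32, 4
head_dim = n_embd // n_head            # 64

q_dim = n_head * head_dim              # 2048 rows for Q
k_dim = v_dim = n_head_kv * head_dim   # 256 rows each for K and V

# Wqkv has (q_dim + k_dim + v_dim) = 2560 rows, so the slices above are
#   Wq = Wqkv[:q_dim]
#   Wk = Wqkv[q_dim : q_dim + k_dim]              # re-permuted with n_head_kv heads
#   Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim]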
tests/models/test_llama.py
@@ -135,6 +135,72 @@ def test_llama_optimized(model_name, checkpoint_format):
    ).abs().max().item()

+
+@pytest.mark.parametrize("model_name", ["PY007/TinyLlama-1.1B-step-50K-105b"])
+def test_mqa_optimized(model_name):
+    """Check that our implementation of Llama with MQA/GQA (with all optimizations enabled) matches
+    the HF implementation: the output of our forward pass in fp16 should be around the same as the
+    HF forward pass in fp16, when compared to the HF forward pass in fp32.
+    """
+    dtype = torch.float16
+    device = "cuda"
+    config = llama_config_to_gpt2_config(LlamaConfig.from_pretrained(model_name))
+    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
+    config.fused_bias_fc = True
+    config.fused_mlp = False
+    config.fused_dropout_add_ln = True
+    config.residual_in_fp32 = True
+
+    # Without device_map, the model is loaded on the CPU, which is very slow
+    model_ref = LlamaForCausalLM.from_pretrained(model_name, device_map={"": device})
+    model_ref.eval()
+
+    model = GPTLMHeadModel(config, device=device, dtype=dtype)
+    model.load_state_dict(remap_state_dict_hf_llama(model_ref.state_dict(), config))
+    model.eval()
+
+    torch.manual_seed(0)
+    batch_size = 2
+    max_seqlen = 256
+    input_ids = torch.randint(
+        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
+    )
+    with torch.no_grad():
+        out = model.transformer(input_ids)
+        logits = model(input_ids).logits
+    del model
+
+    with torch.no_grad():
+        out_ref = model_ref.model(input_ids).last_hidden_state
+        logits_ref = model_ref(input_ids).logits
+    del model_ref
+
+    model_hf = LlamaForCausalLM.from_pretrained(
+        model_name, torch_dtype=dtype, device_map={"": device}
+    )
+    model_hf.eval()
+    out_hf = model_hf.model(input_ids).last_hidden_state
+    logits_hf = model_hf(input_ids).logits
+    del model_hf
+
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
+    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
+    assert (out - out_ref).abs().max().item() < 3 * (
+        out_hf - out_ref
+    ).abs().max().item()
+
+    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
+    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
+    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
+    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
+    assert (logits - logits_ref).abs().max().item() < 3 * (
+        logits_hf - logits_ref
+    ).abs().max().item()
+
+
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel"
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("model_name", ["13B"])