Commit 0705d271 authored by Tri Dao

[Llama] Fix some tests, add tests for Llama 2 and CodeLlama

parent e0fbaa70
@@ -13,6 +13,8 @@ import torch.nn.functional as F
 from sentencepiece import SentencePieceProcessor
 from transformers import GPT2Config, LlamaConfig
+from einops import rearrange

 def remap_state_dict_meta_llama(
     state_dict: dict[str, torch.Tensor], config: GPT2Config
@@ -30,9 +32,7 @@ def remap_state_dict_meta_llama(
     # Word embedding
     def key_mapping_emb(key):
         return re.sub(
-            r"^transformer.tok_embeddings.",
-            "transformer.embeddings.word_embeddings.",
-            key,
+            r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key
         )

     state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
@@ -113,7 +113,7 @@ def remap_state_dict_meta_llama(

 def remap_state_dict_hf_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config, multi_query: bool = False
+    state_dict: dict[str, torch.Tensor], config: GPT2Config
 ) -> dict[str, torch.Tensor]:
     """Convert the state_dict in Hugging Face format to standard GPT format.
@@ -186,13 +186,11 @@ def remap_state_dict_hf_llama(
     state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

-    def inv_permute(w, first_dim=None):
+    def inv_permute(w):
         # Inverse of permute implemented in:
         # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
-        return (
-            w.reshape(first_dim or config.n_head, 2, -1, config.n_embd)
-            .transpose(1, 2)
-            .reshape(-1, config.n_embd)
+        return rearrange(
+            w, "(h two d) n -> (h d two) n", d=config.n_embd // config.n_head // 2, two=2
         )

     # Attention
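For reference, a minimal sanity check (not part of this diff; sizes are illustrative) showing that the einops rearrange above reproduces the old reshape/transpose inverse permutation, and that the first_dim argument is no longer needed: rearrange infers the number of heads from the weight shape, so the same inv_permute handles Q as well as K with fewer KV heads (GQA/MQA).

import torch
from einops import rearrange

n_embd, n_head, n_head_kv = 4096, 32, 8  # illustrative Llama-like sizes
head_dim = n_embd // n_head

def inv_permute_old(w, first_dim=None):
    # the implementation being removed above
    return (
        w.reshape(first_dim or n_head, 2, -1, n_embd).transpose(1, 2).reshape(-1, n_embd)
    )

def inv_permute_new(w):
    # the einops version being added above
    return rearrange(w, "(h two d) n -> (h d two) n", d=head_dim // 2, two=2)

Wq = torch.randn(n_head * head_dim, n_embd)
Wk = torch.randn(n_head_kv * head_dim, n_embd)  # grouped-query: fewer K heads
assert torch.equal(inv_permute_old(Wq), inv_permute_new(Wq))
assert torch.equal(inv_permute_old(Wk, first_dim=n_head_kv), inv_permute_new(Wk))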
@@ -202,8 +200,7 @@ def remap_state_dict_hf_llama(
         Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight")
         state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat(
-            (inv_permute(Wq), inv_permute(Wk, getattr(config, "n_head_kv")), Wv),
-            dim=0,
+            [inv_permute(Wq), inv_permute(Wk), Wv], dim=0
         )
         # We don't store these
         state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None)
@@ -220,7 +217,7 @@ def remap_state_dict_hf_llama(

 def inv_remap_state_dict_hf_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config, multi_query: bool = False
+    state_dict: dict[str, torch.Tensor], config: GPT2Config
 ) -> dict[str, torch.Tensor]:
     """Convert the state_dict in standard GPT format to Hugging Face format.
@@ -293,11 +290,9 @@ def inv_remap_state_dict_hf_llama(
     state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

-    def permute(w, first_dim=None):
-        return (
-            w.view(first_dim or config.n_head, -1, 2, config.n_embd)
-            .transpose(1, 2)
-            .reshape(-1, config.n_embd)
+    def permute(w):
+        return rearrange(
+            w, "(h d two) n -> (h two d) n", d=config.n_embd // config.n_head // 2, two=2
         )

     n_head = config.n_head
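The two rearrange patterns are inverse index permutations of each other, which is what the new test_inv_remap_state_dict_hf_llama below checks end-to-end on real checkpoints; a standalone round-trip sketch with made-up sizes:

import torch
from einops import rearrange

n_embd, n_head = 2048, 16  # illustrative sizes
d = n_embd // n_head // 2

def inv_permute(w):  # HF layout -> GPT layout (remap_state_dict_hf_llama)
    return rearrange(w, "(h two d) n -> (h d two) n", d=d, two=2)

def permute(w):  # GPT layout -> HF layout (inv_remap_state_dict_hf_llama)
    return rearrange(w, "(h d two) n -> (h two d) n", d=d, two=2)

w = torch.randn(n_embd, n_embd)
assert torch.equal(permute(inv_permute(w)), w)
assert torch.equal(inv_permute(permute(w)), w)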
@@ -316,7 +311,7 @@ def inv_remap_state_dict_hf_llama(
         Wk = Wqkv[q_dim : q_dim + k_dim]
         Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim]
         state_dict[f"model.layers.{l}.self_attn.q_proj.weight"] = permute(Wq)
-        state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk, n_head_kv)
+        state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk)
         state_dict[f"model.layers.{l}.self_attn.v_proj.weight"] = Wv
         state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)
......
@@ -725,7 +725,7 @@ class ParallelMHA(nn.Module):
             process_group,
             bias=qkv_proj_bias,
             sequence_parallel=sequence_parallel,
-            multiple_of=self.head_dim * (self.num_heads_per_rank + 2 * self.num_heads_kv_per_rank),
+            multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
             **factory_kwargs,
         )
         inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
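As a worked example of the new multiple_of (numbers are illustrative, not from the diff): head_dim * (num_heads // num_heads_kv + 2) is the width of one KV group in the fused QKV projection, i.e. a group of query heads plus its single K and V head, and the full QKV output dimension is exactly num_heads_kv such blocks, so partitioning the ColumnParallelLinear in these units appears to keep whole KV groups on a tensor-parallel rank.

head_dim, num_heads, num_heads_kv = 128, 64, 8  # e.g. a Llama-2-70B-like GQA shape
multiple_of = head_dim * (num_heads // num_heads_kv + 2)  # 128 * (8 + 2) = 1280
qkv_out_dim = head_dim * (num_heads + 2 * num_heads_kv)  # 128 * (64 + 16) = 10240
assert qkv_out_dim == num_heads_kv * multiple_of  # 8 blocks of 1280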
......
@@ -160,6 +160,7 @@ def test_tensor_parallel(model_name, rotary, world_size):
     assert torch.allclose(
         torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
     )
+    assert torch.equal(torch.stack(out.scores, dim=1), torch.stack(out_cg.scores, dim=1))
     if not rotary:
         assert torch.all(out.sequences == out_ref.sequences)
         assert torch.all(out.sequences == out_hf.sequences)
......
 # Copyright (c) 2023, Tri Dao.
-# To run the huggingface implementation, we first need to convert the weights:
+# To run the huggingface implementation of LLaMa (1), we first need to convert the weights:
 # https://github.com/huggingface/transformers/pull/21955
 # python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR/llama/7B-hf
 # and repeat for 13B, 30B, 65B
@@ -30,6 +30,7 @@ from flash_attn.utils.generation import update_graph_cache
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from transformers import LlamaConfig, LlamaTokenizer
 from transformers.models.llama.modeling_llama import LlamaForCausalLM
+from transformers import AutoConfig

 def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, checkpoint_format):
@@ -60,9 +61,38 @@ def test_llama_state_dict(model_name):
         assert state_dict[k].shape == pretrained_state_dict[k].shape


-@pytest.mark.parametrize("model_name", ["7B", "13B"])
-@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
-def test_llama_optimized(model_name, checkpoint_format):
+# TinyLlama-1.1B is to test MQA
+@pytest.mark.parametrize(
+    "model_name", ["meta-llama/Llama-2-7b-hf", "PY007/TinyLlama-1.1B-step-50K-105b"]
+)
+def test_inv_remap_state_dict_hf_llama(model_name):
+    config = llama_config_to_gpt2_config(
+        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+    )
+    state_dict = state_dict_from_pretrained(model_name)
+    # inv_remap_state_dict_hf_llama should be the inverse of remap_state_dict_hf_llama
+    state_dict = {key: val for key, val in state_dict.items() if "rotary_emb.inv_freq" not in key}
+    pretrained_state_dict = remap_state_dict_hf_llama(state_dict, config)
+    state_dict_recover = inv_remap_state_dict_hf_llama(pretrained_state_dict, config)
+    assert set(state_dict_recover.keys()) == set(state_dict.keys())
+    for key in state_dict_recover.keys():
+        torch.testing.assert_close(state_dict_recover[key], state_dict[key])
+
+
+# TinyLlama-1.1B is to test MQA
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "7B",  # Llama 1
+        "13B",  # Llama 1
+        "meta-llama/Llama-2-13b-hf",
+        "codellama/CodeLlama-7b-hf",
+        "codellama/CodeLlama-13b-hf",
+        "codellama/CodeLlama-34b-hf",
+        "PY007/TinyLlama-1.1B-step-50K-105b",
+    ],
+)
+def test_llama_optimized(model_name):
     """Check that our implementation of LLaMa (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
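The Hub-hosted parametrizations above download their weights on first use, and the meta-llama checkpoints are gated on Hugging Face, so they may require an authenticated token; the new round-trip test can be exercised on the small MQA model alone with, for example:

# pytest -q -s tests/models/test_llama.py -k "inv_remap and TinyLlama"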
@@ -73,17 +103,27 @@ def test_llama_optimized(model_name, checkpoint_format):
     dtype = torch.float16
     device = "cuda"
-    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
-    config = llama_config_to_gpt2_config(config)
+    if "/" in model_name:  # Download from HF
+        config = llama_config_to_gpt2_config(
+            AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+        )
+    else:
+        config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
+        config = llama_config_to_gpt2_config(config)
     config.use_flash_attn = True
     config.fused_bias_fc = True
     config.fused_mlp = False  # We don't have fused GatedMLP yet
     config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True

-    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
-        checkpoint_path, model_name, config, checkpoint_format
-    )
+    if "/" in model_name:  # Download from HF
+        pretrained_state_dict = remap_state_dict_hf_llama(
+            state_dict_from_pretrained(model_name), config
+        )
+    else:
+        pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
+            checkpoint_path, model_name, config, checkpoint_format="meta"
+        )
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
@@ -103,7 +143,8 @@ def test_llama_optimized(model_name, checkpoint_format):
     # Without device_map, the model is loaded on the CPU, which is very slow
     # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
     model_ref = LlamaForCausalLM.from_pretrained(
-        Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
+        model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
+        device_map="auto",
     )
     model_ref.eval()
     with torch.no_grad():
@@ -112,7 +153,9 @@ def test_llama_optimized(model_name, checkpoint_format):
     del model_ref

     model_hf = LlamaForCausalLM.from_pretrained(
-        Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device}
+        model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
+        torch_dtype=dtype,
+        device_map={"": device},
     )
     model_hf.eval()
     with torch.no_grad():
@@ -135,77 +178,12 @@ def test_llama_optimized(model_name, checkpoint_format):
     ).abs().max().item()


-@pytest.mark.parametrize("model_name", ["PY007/TinyLlama-1.1B-step-50K-105b"])
-def test_mqa_optimized(model_name):
-    """Check that our implementation of Llama with MQA/GQA (with all optimizations enabled) matches the
-    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
-    forward pass in fp16, when compared to the HF forward pass in fp32.
-    """
-    dtype = torch.float16
-    device = "cuda"
-    config = llama_config_to_gpt2_config(LlamaConfig.from_pretrained(model_name))
-    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
-    config.fused_bias_fc = True
-    config.fused_mlp = False
-    config.fused_dropout_add_ln = True
-    config.residual_in_fp32 = True
-
-    # Without device_map, the model is loaded on the CPU, which is very slow
-    model_ref = LlamaForCausalLM.from_pretrained(model_name, device_map={"": device})
-    model_ref.eval()
-
-    model = GPTLMHeadModel(config, device=device, dtype=dtype)
-    model.load_state_dict(remap_state_dict_hf_llama(model_ref.state_dict(), config))
-    model.eval()
-
-    torch.manual_seed(0)
-    batch_size = 2
-    max_seqlen = 256
-    input_ids = torch.randint(
-        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
-    )
-    with torch.no_grad():
-        out = model.transformer(input_ids)
-        logits = model(input_ids).logits
-    del model
-
-    with torch.no_grad():
-        out_ref = model_ref.model(input_ids).last_hidden_state
-        logits_ref = model_ref(input_ids).logits
-    del model_ref
-
-    model_hf = LlamaForCausalLM.from_pretrained(
-        model_name, torch_dtype=dtype, device_map={"": device}
-    )
-    model_hf.eval()
-    out_hf = model_hf.model(input_ids).last_hidden_state
-    logits_hf = model_hf(input_ids).logits
-    del model_hf
-
-    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
-    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
-    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
-    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
-    assert (out - out_ref).abs().max().item() < 3 * (
-        out_hf - out_ref
-    ).abs().max().item()
-
-    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
-    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
-    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
-    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
-    assert (logits - logits_ref).abs().max().item() < 3 * (
-        logits_hf - logits_ref
-    ).abs().max().item()
-
-
 # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel"
 @pytest.mark.parametrize("world_size", [2])
-@pytest.mark.parametrize("model_name", ["13B"])
-@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
-def test_llama_parallel(model_name, world_size, checkpoint_format):
+@pytest.mark.parametrize(
+    "model_name", ["13B", "meta-llama/Llama-2-13b-hf", "codellama/CodeLlama-34b-hf"]
+)
+def test_llama_parallel(model_name, world_size):
     """Check that our implementation of LLaMa (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
@@ -217,8 +195,13 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
     )
     dtype = torch.float16
-    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
-    config = llama_config_to_gpt2_config(config)
+    if "/" in model_name:  # Download from HF
+        config = llama_config_to_gpt2_config(
+            AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+        )
+    else:
+        config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
+        config = llama_config_to_gpt2_config(config)
     config.use_flash_attn = True
     config.fused_bias_fc = True
     config.fused_mlp = False  # We don't have fused GatedMLP yet
@@ -233,9 +216,14 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
     rank = parallel_state.get_tensor_model_parallel_rank()
     process_group = parallel_state.get_tensor_model_parallel_group()

-    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
-        checkpoint_path, model_name, config, checkpoint_format
-    )
+    if "/" in model_name:  # Download from HF
+        pretrained_state_dict = remap_state_dict_hf_llama(
+            state_dict_from_pretrained(model_name), config
+        )
+    else:
+        pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
+            checkpoint_path, model_name, config, checkpoint_format="meta"
+        )
     model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
     model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
     model.eval()
@@ -260,7 +248,8 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
     if rank == 0:
         # Without device_map, the model is loaded on the CPU, which is very slow
         model_ref = LlamaForCausalLM.from_pretrained(
-            Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
+            model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
+            device_map="auto",
         )
         model_ref.eval()
         with torch.no_grad():
@@ -269,7 +258,9 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
         del model_ref

         model_hf = LlamaForCausalLM.from_pretrained(
-            Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto"
+            model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
+            torch_dtype=dtype,
+            device_map="auto",
         )
         model_hf.eval()
         with torch.no_grad():
@@ -405,9 +396,10 @@ def test_llama_generation(model_name, checkpoint_format):

 # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "llama_parallel_generation"
 @pytest.mark.parametrize("world_size", [2])
-@pytest.mark.parametrize("model_name", ["13B"])
-@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
-def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
+@pytest.mark.parametrize(
+    "model_name", ["13B", "meta-llama/Llama-2-13b-hf", "codellama/CodeLlama-34b-hf"]
+)
+def test_llama_parallel_generation(model_name, world_size):
     """Check that our implementation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
     the HF scores in fp32.
@@ -419,12 +411,17 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
     )
     dtype = torch.float16
-    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
-    config = llama_config_to_gpt2_config(config)
-    config.use_flash_attn = False
+    if "/" in model_name:  # Download from HF
+        config = llama_config_to_gpt2_config(
+            AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+        )
+    else:
+        config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
+        config = llama_config_to_gpt2_config(config)
+    config.use_flash_attn = True
     config.fused_bias_fc = True
     config.fused_mlp = False  # We don't have fused GatedMLP yet
-    config.fused_dropout_add_ln = False
+    config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True
     config.pad_vocab_size_multiple = 8 * world_size
     config.sequence_parallel = False  # Need to set this to False for generation
@@ -450,9 +447,14 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
     # GPU0 and GPU1 and things would hang
     torch.cuda.set_device(device)

-    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
-        checkpoint_path, model_name, config, checkpoint_format
-    )
+    if "/" in model_name:  # Download from HF
+        pretrained_state_dict = remap_state_dict_hf_llama(
+            state_dict_from_pretrained(model_name), config
+        )
+    else:
+        pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
+            checkpoint_path, model_name, config, checkpoint_format="meta"
+        )
     model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
     model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
     model.eval()
@@ -490,7 +492,9 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
     if rank == 0:
         # Without device_map, the model is loaded on the CPU, which is very slow
         model_hf = LlamaForCausalLM.from_pretrained(
-            Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto"
+            model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
+            torch_dtype=dtype,
+            device_map="auto",
         )
         model_hf.eval()
         print("HF fp16")
@@ -508,7 +512,8 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
         del model_hf

         model_ref = LlamaForCausalLM.from_pretrained(
-            Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
+            model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
+            device_map="auto",
         )
         model_ref.eval()
         with torch.inference_mode():
@@ -594,15 +599,16 @@ def test_llama_parallel_uneven_num_heads(world_size):
     if rank == 0:
         model_ref = LlamaForCausalLM.from_pretrained(
-            Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
+            Path(checkpoint_path) / f"{model_name}-hf", device_map={"": device}
         )
-        model_ref = model_ref.to(device=device)
         model_ref.eval()
-        out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
-        logits_ref = model_ref(input_ids).logits.to(device=device)
+        out_ref = model_ref.model(input_ids).last_hidden_state
+        logits_ref = model_ref(input_ids).logits
         del model_ref
         model_hf = LlamaForCausalLM.from_pretrained(
-            Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto"
+            Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device}
         )
         model_hf.eval()
         out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
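On the device_map change above (a general transformers/accelerate behavior, not specific to this diff): device_map maps module-name prefixes to devices, and the empty prefix "" matches the whole model, so {"": device} places every weight on that single GPU, whereas "auto" may shard the reference model across GPUs and CPU, which is why the explicit .to(device=...) calls become unnecessary. A minimal sketch, assuming the TinyLlama checkpoint listed earlier:

import torch
from transformers import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained(
    "PY007/TinyLlama-1.1B-step-50K-105b",  # small checkpoint from the parametrize list above
    torch_dtype=torch.float16,
    device_map={"": "cuda:0"},  # "" = whole model -> one device
)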
@@ -625,42 +631,3 @@ def test_llama_parallel_uneven_num_heads(world_size):
     if os.path.exists(checkpoint_path / f"{model_name}-hf"):
         shutil.rmtree(checkpoint_path / f"{model_name}-hf")
-
-
-@torch.no_grad()
-def test_inv_remap_state_dict_hf_llama():
-    checkpoint_path = (
-        Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
-    )
-    model_name = f"teeny"
-    llama_config = LlamaConfig(
-        num_attention_heads=2,
-        hidden_size=256 * 2,
-        intermediate_size=256 * 2 * 4,
-        num_hidden_layers=4,
-    )
-    config = llama_config_to_gpt2_config(llama_config)
-    config.use_flash_attn = True
-    config.fused_bias_fc = True
-    config.fused_mlp = False  # We don't have fused GatedMLP yet
-    config.fused_dropout_add_ln = True
-    config.residual_in_fp32 = True
-
-    # Set up.
-    LlamaForCausalLM(config=llama_config).save_pretrained(checkpoint_path / f"{model_name}-hf")
-
-    # inv_remap_state_dict_hf_llama should be the inverse of remap_state_dict_hf_llama
-    state_dict = state_dict_from_pretrained(checkpoint_path / f"{model_name}-hf")
-    state_dict = {key: val for key, val in state_dict.items() if "rotary_emb.inv_freq" not in key}
-    pretrained_state_dict = remap_state_dict_hf_llama(state_dict, config)
-    state_dict_recover = inv_remap_state_dict_hf_llama(pretrained_state_dict, config)
-    assert set(state_dict_recover.keys()) == set(state_dict.keys())
-    for key in state_dict_recover.keys():
-        torch.testing.assert_close(state_dict_recover[key], state_dict[key])
-
-    # Tear down.
-    if os.path.exists(checkpoint_path / f"{model_name}-hf"):
-        shutil.rmtree(checkpoint_path / f"{model_name}-hf")