Unverified commit 0f7853c6 authored by Xuechen Li, committed by GitHub

enable loading hf llama checkpoints for training (#446)

* prelim.

* add hf conversion fn.

* mlp.

* change name.

* fix bug.

* inverse permute.

* change comment.

* revert style changes.

* fix.

* add doc.

* revert.

* enable load safe.

* fix safe load.

* fix import.

* fix typing-related lints.

* fix ckpt loading logic.

* make single gpu work.

* test with parallel.

* ckpt format.

* enable pretrained state dict.

* remove unused imports.

* remove unused.

* mark idea related.
parent c60851a8
.gitignore
@@ -19,3 +19,6 @@ var/
*.egg-info/
.installed.cfg
*.egg
# IDE-related
.idea/
flash_attn/models/llama.py
# Copyright (c) 2023, Tri Dao.

import json
import math
import os
import re
from collections import OrderedDict
from pathlib import Path
from typing import Union

import torch
import torch.nn.functional as F

from transformers import GPT2Config, LlamaConfig
@@ -74,10 +74,91 @@ def remap_state_dict_meta_llama(state_dict, config):
                      r'transformer.layers.\1.mixer.out_proj.', key)
    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    state_dict.pop("transformer.rope.freqs", None)
    return state_dict


def remap_state_dict_hf_llama(state_dict, config):
    # Embedding
    def key_mapping_emb(key):
        return re.sub(r'^model.embed_tokens.', 'transformer.embeddings.word_embeddings.', key)
    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
                  * pad_vocab_size_multiple)
    state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )

    # LM head
    if getattr(config, 'tie_word_embeddings'):
        state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
    else:
        output_embeddings = state_dict.pop('lm_head.weight')
        # Need to recompute vocab_size since LLaMa shards the word embeddings and output
        # embeddings differently.
        vocab_size = (math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
                      * pad_vocab_size_multiple)
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict['lm_head.weight'] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )
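    # Worked example (illustrative numbers, not from any actual checkpoint): with
    # output_embeddings.shape[0] == 32001 and pad_vocab_size_multiple == 8,
    # vocab_size = ceil(32001 / 8) * 8 = 32008, so F.pad receives (0, 0, 0, 7) and
    # appends 7 zero rows along dim 0 (the last pair in the pad spec pads dim 0 on
    # the high side for a 2D tensor).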

    # MLP
    for l in range(config.n_layer):
        # Fusing weights this way based on difference in the following:
        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/modeling_llama.py#L220
        # https://github.com/Dao-AILab/flash-attention/blob/c60851a8253257eb970e06a022c82517a8033e8c/flash_attn/modules/mlp.py#L115
        w1 = state_dict.pop(f'model.layers.{l}.mlp.gate_proj.weight')
        w3 = state_dict.pop(f'model.layers.{l}.mlp.up_proj.weight')
        state_dict[f'transformer.layers.{l}.mlp.fc1.weight'] = torch.cat([w3, w1], dim=0)
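        # Explanatory note (a reading of the two references above, not verified here):
        # HF computes down_proj(act_fn(gate_proj(x)) * up_proj(x)), while the fused
        # GatedMlp splits fc1's output into two halves and applies the activation to
        # the second half. Stacking up_proj (w3) first and gate_proj (w1) second makes
        # that second half the gate, so the fused MLP reproduces the HF computation.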

    def key_mapping_mlp(key):
        return re.sub(r'^model.layers.(\d+).mlp.down_proj.',
                      r'transformer.layers.\1.mlp.fc2.', key)
    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r'^model.norm.', r'transformer.ln_f.', key)
        key = re.sub(r'^model.layers.(\d+).input_layernorm.', r'transformer.layers.\1.norm1.', key)
        key = re.sub(r'^model.layers.(\d+).post_attention_layernorm.', r'transformer.layers.\1.norm2.', key)
        return key
    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    def inv_permute(w):
        # Inverse of permute implemented in:
        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
        return w.reshape(
            config.n_head, 2, config.n_embd // config.n_head // 2, config.n_embd
        ).transpose(1, 2).reshape(config.n_embd, config.n_embd)

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f'model.layers.{l}.self_attn.q_proj.weight')
        Wk = state_dict.pop(f'model.layers.{l}.self_attn.k_proj.weight')
        Wv = state_dict.pop(f'model.layers.{l}.self_attn.v_proj.weight')
        state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat(
            [inv_permute(Wq), inv_permute(Wk), Wv], dim=0
        )
        # We don't store these
        state_dict.pop(f'model.layers.{l}.self_attn.rotary_emb.inv_freq', None)

    def key_mapping_attn(key):
        return re.sub(r'^model.layers.(\d+).self_attn.o_proj.',
                      r'transformer.layers.\1.mixer.out_proj.', key)
    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    return state_dict
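
Aside (illustrative sketch, not part of the diff): a minimal round-trip check showing that inv_permute above undoes the permute used in the referenced HF conversion script, with hypothetical small shapes:

import torch

n_head, n_embd = 4, 32  # hypothetical sizes; head_dim = 8

def permute(w):
    # As in convert_llama_weights_to_hf.py: regroups the two rotary halves of each head.
    return w.view(n_head, n_embd // n_head // 2, 2, n_embd).transpose(1, 2).reshape(n_embd, n_embd)

def inv_permute(w):
    # Same as inv_permute above, with config.n_head / config.n_embd replaced by locals.
    return w.reshape(n_head, 2, n_embd // n_head // 2, n_embd).transpose(1, 2).reshape(n_embd, n_embd)

w = torch.randn(n_embd, n_embd)
assert torch.equal(inv_permute(permute(w)), w)  # inv_permute exactly undoes permute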


def config_from_meta_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> LlamaConfig:
    """Load a LlamaConfig from a checkpoint path."""
    with open(Path(checkpoint_path) / model_name / 'params.json') as f:
        params = json.load(f)
@@ -88,7 +169,20 @@ def config_from_checkpoint(checkpoint_path: str, model_name: str) -> LlamaConfig
    return config


def config_from_hf_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> LlamaConfig:
    return LlamaConfig.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf' / "config.json")


def config_from_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str, checkpoint_format="meta"
) -> LlamaConfig:
    if checkpoint_format == "meta":
        return config_from_meta_checkpoint(checkpoint_path, model_name)
    else:
        return config_from_hf_checkpoint(checkpoint_path, model_name)


def state_dicts_from_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> list[dict]:
    # Need to sort, otherwise we mess up the ordering and the weights are wrong
    return [torch.load(path, map_location='cpu')
            for path in sorted((Path(checkpoint_path) / model_name).glob('consolidated.*.pth'))]
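
Aside (illustrative sketch, not part of the diff): a usage example of the loaders above. The directory layout is an assumption inferred from the paths these functions build (<checkpoint_path>/<model_name>/params.json and consolidated.*.pth for the Meta format; <checkpoint_path>/<model_name>-hf/config.json plus HF weight files for the converted format); all function names come from this diff.

from pathlib import Path

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.llama import (
    config_from_checkpoint, llama_config_to_gpt2_config, remap_state_dict_hf_llama)
from flash_attn.utils.pretrained import state_dict_from_pretrained

checkpoint_path, model_name = "checkpoints/llama", "7B"  # hypothetical location
config = llama_config_to_gpt2_config(
    config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="hf"))
state_dict = remap_state_dict_hf_llama(
    state_dict_from_pretrained(Path(checkpoint_path) / f"{model_name}-hf"), config)
model = GPTLMHeadModel(config)
model.load_state_dict(state_dict)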

flash_attn/utils/pretrained.py
import os
from functools import partial

import torch
from safetensors.torch import load_file as safe_load_file

from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
from transformers.utils import is_remote_url
from transformers.modeling_utils import load_state_dict
from transformers.utils.hub import cached_file, get_checkpoint_shard_files
@@ -10,15 +11,39 @@ def state_dict_from_pretrained(model_name, device=None, dtype=None):
    # If not fp32, then we don't want to load directly to the GPU
    mapped_device = 'cpu' if dtype not in [torch.float32, None] else device
    is_sharded = False
    load_safe = False
    resolved_archive_file = None

    weights_path = os.path.join(model_name, WEIGHTS_NAME)
    weights_index_path = os.path.join(model_name, WEIGHTS_INDEX_NAME)
    safe_weights_path = os.path.join(model_name, SAFE_WEIGHTS_NAME)
    safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)

    if os.path.isfile(weights_path):
        resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
                                            _raise_exceptions_for_missing_entries=False)
    elif os.path.isfile(weights_index_path):
        resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME,
                                            _raise_exceptions_for_missing_entries=False)
        is_sharded = True
    elif os.path.isfile(safe_weights_path):
        resolved_archive_file = cached_file(model_name, SAFE_WEIGHTS_NAME,
                                            _raise_exceptions_for_missing_entries=False)
        load_safe = True
    elif os.path.isfile(safe_weights_index_path):
        resolved_archive_file = cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME,
                                            _raise_exceptions_for_missing_entries=False)
        is_sharded = True
        load_safe = True

    if resolved_archive_file is None:
        raise EnvironmentError(f"Model name {model_name} was not found.")

    if load_safe:
        loader = partial(safe_load_file, device=mapped_device)
    else:
        loader = partial(torch.load, map_location=mapped_device)

    if is_sharded:
        # resolved_archive_file becomes a list of files that point to the different
        # checkpoint shards in this case.
@@ -27,9 +52,9 @@ def state_dict_from_pretrained(model_name, device=None, dtype=None):
        )
        state_dict = {}
        for sharded_file in resolved_archive_file:
            state_dict.update(loader(sharded_file))
    else:
        state_dict = loader(resolved_archive_file)
    # Convert dtype before moving to GPU to save memory
    if dtype is not None:
        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
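
Aside (illustrative sketch, not part of the diff): a brief usage example with a hypothetical checkpoint directory. With a non-fp32 dtype the weights are loaded on CPU first and converted before any move, as the comments above describe.

import torch
from flash_attn.utils.pretrained import state_dict_from_pretrained

# Picks .bin vs .safetensors and single-file vs sharded layouts automatically.
sd = state_dict_from_pretrained("checkpoints/llama/7B-hf", device="cpu", dtype=torch.float16)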

tests/models/test_llama.py
@@ -8,6 +8,7 @@
import os
import time
from pathlib import Path
current_dir = Path(__file__).parent.absolute()
import torch
@@ -15,17 +16,28 @@ import pytest
from einops import rearrange

from transformers import LlamaTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
from flash_attn.models.llama import remap_state_dict_meta_llama, llama_config_to_gpt2_config, remap_state_dict_hf_llama
from flash_attn.models.llama import config_from_checkpoint, state_dicts_from_checkpoint
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import update_graph_cache


def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, checkpoint_format):
    if checkpoint_format == "meta":
        # Meta checkpoints ship one consolidated.*.pth per tensor-parallel rank:
        # remap each shard, then combine them into a single state dict.
        ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
        pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
        pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
    else:
        # HF checkpoints are a single (possibly file-sharded) state dict; just remap the keys.
        pretrained_state_dict = state_dict_from_pretrained(Path(checkpoint_path) / f'{model_name}-hf')
        pretrained_state_dict = remap_state_dict_hf_llama(pretrained_state_dict, config)
    return pretrained_state_dict
@pytest.mark.parametrize('model_name', ["7B"])
def test_llama_state_dict(model_name):
checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
......@@ -41,8 +53,8 @@ def test_llama_state_dict(model_name):
@pytest.mark.parametrize('model_name', ["7B", "13B"])
# @pytest.mark.parametrize('model_name', ["7B"])
def test_llama_optimized(model_name):
@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
def test_llama_optimized(model_name, checkpoint_format):
"""Check that our implementation of LLaMa (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
@@ -52,16 +64,17 @@ def test_llama_optimized(model_name):
    dtype = torch.float16
    device = 'cuda'
    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
    config = llama_config_to_gpt2_config(config)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
        checkpoint_path, model_name, config, checkpoint_format
    )
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(pretrained_state_dict)
    model.eval()
@@ -111,7 +124,8 @@ def test_llama_optimized(model_name):
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel"
@pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('model_name', ["13B"])
@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
def test_llama_parallel(model_name, world_size, checkpoint_format):
    """Check that our implementation of LLaMa (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
@@ -122,7 +136,8 @@ def test_llama_parallel(model_name, world_size):
                                          current_dir.parent.parent / 'checkpoints')) / 'llama'
    dtype = torch.float16
    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
    config = llama_config_to_gpt2_config(config)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
@@ -137,10 +152,9 @@ def test_llama_parallel(model_name, world_size):
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
        checkpoint_path, model_name, config, checkpoint_format
    )
    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()
@@ -196,13 +210,15 @@ def test_llama_parallel(model_name, world_size):
# @pytest.mark.parametrize('model_name', ["7B", "13B"])
@pytest.mark.parametrize('model_name', ["7B"])
@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
def test_llama_generation(model_name, checkpoint_format):
    checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
                                          current_dir.parent.parent / 'checkpoints')) / 'llama'
    dtype = torch.float16
    device = 'cuda'
    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
    config = llama_config_to_gpt2_config(config)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
@@ -239,9 +255,10 @@ def test_llama_generation(model_name):
    logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1].to(device=device)
    del model_ref

    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
        checkpoint_path, model_name, config, checkpoint_format
    )
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(pretrained_state_dict)
    model.eval()
@@ -291,7 +308,8 @@ def test_llama_generation(model_name):
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "llama_parallel_generation"
@pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('model_name', ["13B"])
@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
    """Check that our implementation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
@@ -302,7 +320,8 @@ def test_llama_parallel_generation(model_name, world_size):
                                          current_dir.parent.parent / 'checkpoints')) / 'llama'
    dtype = torch.float16
    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
    config = llama_config_to_gpt2_config(config)
    config.use_flash_attn = False
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
@@ -331,10 +350,9 @@ def test_llama_parallel_generation(model_name, world_size):
    # GPU0 and GPU1 and things would hang
    torch.cuda.set_device(device)

    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
        checkpoint_path, model_name, config, checkpoint_format
    )
    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()
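
Aside (not part of the diff): one possible way to run just the new HF-format variants of these tests, written in the same comment style as the torchrun hints above. The checkpoint location is a placeholder; the -k expressions match the parametrize ids introduced in this diff.

# Single-GPU test, HF checkpoint format only:
#   CHECKPOINT_DIR=/path/to/checkpoints pytest -q -s tests/models/test_llama.py -k "llama_optimized and hf"
# Tensor-parallel tests on 2 GPUs (cf. the torchrun comments above):
#   CHECKPOINT_DIR=/path/to/checkpoints torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "llama_parallel and hf"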