Unverified Commit 0f7853c6 authored by Xuechen Li, committed by GitHub

enable loading hf llama checkpoints for training (#446)

* prelim.

* add hf conversion fn.

* mlp.

* change name.

* fix bug.

* inverse permute.

* change comment.

* revert style changes.

* fix.

* add doc.

* revert.

* enable load safe.

* fix safe load.

* fix import.

* fix typing-related lints.

* fix ckpt loading logic.

* make single gpu work.

* test with parallel.

* ckpt format.

* enable pretrained state dict.

* remove unused imports.

* remove unused.

* mark idea related.
parent c60851a8
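What this commit enables, roughly: read an HF-format LLaMA checkpoint, translate its config and state dict into the layout flash-attn's GPT model expects, and train or fine-tune from there. A minimal sketch of that path, assuming a local layout like checkpoints/llama/7B-hf/ (the directory names are illustrative; the function names come from the diff below):

from pathlib import Path

import torch

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.llama import (
    config_from_checkpoint, llama_config_to_gpt2_config, remap_state_dict_hf_llama,
)
from flash_attn.utils.pretrained import state_dict_from_pretrained

checkpoint_path = Path("checkpoints/llama")  # illustrative; the tests use $CHECKPOINT_DIR/llama
model_name = "7B"

# Read the HF LlamaConfig from <checkpoint_path>/7B-hf/config.json and convert it to the
# GPT2Config-style config used by flash-attn.
config = llama_config_to_gpt2_config(
    config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="hf")
)

# Load the HF weights (.bin or .safetensors, sharded or not) and remap keys/layouts.
state_dict = state_dict_from_pretrained(checkpoint_path / f"{model_name}-hf")
state_dict = remap_state_dict_hf_llama(state_dict, config)

model = GPTLMHeadModel(config, device="cuda", dtype=torch.float16)
model.load_state_dict(state_dict)
model.train()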
.gitignore
@@ -19,3 +19,6 @@ var/
 *.egg-info/
 .installed.cfg
 *.egg
+
+# IDE-related
+.idea/
flash_attn/models/llama.py
 # Copyright (c) 2023, Tri Dao.

-import math
 import json
+import math
+import os
 import re
-from pathlib import Path
 from collections import OrderedDict
+from pathlib import Path
+from typing import Union

 import torch
 import torch.nn.functional as F
 from transformers import GPT2Config, LlamaConfig
@@ -74,10 +74,91 @@ def remap_state_dict_meta_llama(state_dict, config):
                       r'transformer.layers.\1.mixer.out_proj.', key)
     state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
+    state_dict.pop("transformer.rope.freqs", None)
+    return state_dict
+
+
+def remap_state_dict_hf_llama(state_dict, config):
+    # Embedding
+    def key_mapping_emb(key):
+        return re.sub(r'^model.embed_tokens.', 'transformer.embeddings.word_embeddings.', key)
+    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
+    word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
+    # It's possible that vocab_size is padded to be a multiple of 8, for example.
+    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
+    vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
+                  * pad_vocab_size_multiple)
+    state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
+        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
+    )
+
+    # LM head
+    if getattr(config, 'tie_word_embeddings'):
+        state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
+    else:
+        output_embeddings = state_dict.pop('lm_head.weight')
+        # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
+        # differently.
+        vocab_size = (math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
+                      * pad_vocab_size_multiple)
+        # It's possible that vocab_size is padded to be a multiple of 8, for example.
+        state_dict['lm_head.weight'] = F.pad(
+            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
+        )
+
+    # MLP
+    for l in range(config.n_layer):
+        # Fusing weights this way based on difference in the following:
+        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/modeling_llama.py#L220
+        # https://github.com/Dao-AILab/flash-attention/blob/c60851a8253257eb970e06a022c82517a8033e8c/flash_attn/modules/mlp.py#L115
+        w1 = state_dict.pop(f'model.layers.{l}.mlp.gate_proj.weight')
+        w3 = state_dict.pop(f'model.layers.{l}.mlp.up_proj.weight')
+        state_dict[f'transformer.layers.{l}.mlp.fc1.weight'] = torch.cat([w3, w1], dim=0)
+
+    def key_mapping_mlp(key):
+        return re.sub(r'^model.layers.(\d+).mlp.down_proj.',
+                      r'transformer.layers.\1.mlp.fc2.', key)
+    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
+
+    # LayerNorm
+    def key_mapping_ln(key):
+        key = re.sub(r'^model.norm.', r'transformer.ln_f.', key)
+        key = re.sub(r'^model.layers.(\d+).input_layernorm.', r'transformer.layers.\1.norm1.', key)
+        key = re.sub(r'^model.layers.(\d+).post_attention_layernorm.', r'transformer.layers.\1.norm2.', key)
+        return key
+    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
+
+    def inv_permute(w):
+        # Inverse of permute implemented in:
+        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
+        return w.reshape(
+            config.n_head, 2, config.n_embd // config.n_head // 2, config.n_embd
+        ).transpose(1, 2).reshape(config.n_embd, config.n_embd)
+
+    # Attention
+    for l in range(config.n_layer):
+        Wq = state_dict.pop(f'model.layers.{l}.self_attn.q_proj.weight')
+        Wk = state_dict.pop(f'model.layers.{l}.self_attn.k_proj.weight')
+        Wv = state_dict.pop(f'model.layers.{l}.self_attn.v_proj.weight')
+        state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat(
+            [inv_permute(Wq), inv_permute(Wk), Wv], dim=0
+        )
+        # We don't store these
+        state_dict.pop(f'model.layers.{l}.self_attn.rotary_emb.inv_freq', None)
+
+    def key_mapping_attn(key):
+        return re.sub(r'^model.layers.(\d+).self_attn.o_proj.',
+                      r'transformer.layers.\1.mixer.out_proj.', key)
+    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
+
     return state_dict


-def config_from_checkpoint(checkpoint_path: str, model_name: str) -> LlamaConfig:
+def config_from_meta_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> LlamaConfig:
     """Load a LlamaConfig from a checkpoint path."""
     with open(Path(checkpoint_path) / model_name / 'params.json') as f:
         params = json.load(f)
@@ -88,7 +169,20 @@ def config_from_checkpoint(checkpoint_path: str, model_name: str) -> LlamaConfig
     return config


-def state_dicts_from_checkpoint(checkpoint_path: str, model_name: str) -> dict:
+def config_from_hf_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> LlamaConfig:
+    return LlamaConfig.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf' / "config.json")
+
+
+def config_from_checkpoint(
+    checkpoint_path: Union[str, os.PathLike], model_name: str, checkpoint_format="meta"
+) -> LlamaConfig:
+    if checkpoint_format == "meta":
+        return config_from_meta_checkpoint(checkpoint_path, model_name)
+    else:
+        return config_from_hf_checkpoint(checkpoint_path, model_name)
+
+
+def state_dicts_from_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> list[dict]:
     # Need to sort, otherwise we mess up the ordering and the weights are wrong
     return [torch.load(path, map_location='cpu')
             for path in sorted((Path(checkpoint_path) / model_name).glob('consolidated.*.pth'))]
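The subtlest part of remap_state_dict_hf_llama is inv_permute: HF's conversion script re-orders the rows of the q/k projection weights to change the rotary-embedding layout, and the remap has to undo that before concatenating into the fused Wqkv. A quick round-trip sketch, with the HF-side permute re-created here from the script linked in the comment above (small made-up dimensions, not part of the commit):

import torch

n_head, head_dim = 4, 8
n_embd = n_head * head_dim


def permute(w):
    # Head-interleaving reshuffle along the lines of convert_llama_weights_to_hf.py.
    return w.view(n_head, n_embd // n_head // 2, 2, n_embd).transpose(1, 2).reshape(n_embd, n_embd)


def inv_permute(w):
    # Same reshape as the inv_permute added in this commit, with config fields inlined.
    return w.reshape(
        n_head, 2, n_embd // n_head // 2, n_embd
    ).transpose(1, 2).reshape(n_embd, n_embd)


w = torch.randn(n_embd, n_embd)
assert torch.equal(inv_permute(permute(w)), w)  # pure row permutation, so it round-trips exactly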
flash_attn/utils/pretrained.py
-import torch
-from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
-from transformers.utils import is_remote_url
-from transformers.modeling_utils import load_state_dict
+import os
+from functools import partial
+
+import torch
+from safetensors.torch import load_file as safe_load_file
+from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
 from transformers.utils.hub import cached_file, get_checkpoint_shard_files
@@ -10,15 +11,39 @@ def state_dict_from_pretrained(model_name, device=None, dtype=None):
     # If not fp32, then we don't want to load directly to the GPU
     mapped_device = 'cpu' if dtype not in [torch.float32, None] else device
     is_sharded = False
-    resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
-                                        _raise_exceptions_for_missing_entries=False)
-    if resolved_archive_file is None:
-        resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME,
-                                            _raise_exceptions_for_missing_entries=False)
-        if resolved_archive_file is not None:
-            is_sharded = True
+    load_safe = False
+    resolved_archive_file = None
+
+    weights_path = os.path.join(model_name, WEIGHTS_NAME)
+    weights_index_path = os.path.join(model_name, WEIGHTS_INDEX_NAME)
+    safe_weights_path = os.path.join(model_name, SAFE_WEIGHTS_NAME)
+    safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)
+
+    if os.path.isfile(weights_path):
+        resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
+                                            _raise_exceptions_for_missing_entries=False)
+    elif os.path.isfile(weights_index_path):
+        resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME,
+                                            _raise_exceptions_for_missing_entries=False)
+        is_sharded = True
+    elif os.path.isfile(safe_weights_path):
+        resolved_archive_file = cached_file(model_name, SAFE_WEIGHTS_NAME,
+                                            _raise_exceptions_for_missing_entries=False)
+        load_safe = True
+    elif os.path.isfile(safe_weights_index_path):
+        resolved_archive_file = cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME,
+                                            _raise_exceptions_for_missing_entries=False)
+        is_sharded = True
+        load_safe = True
+
     if resolved_archive_file is None:
         raise EnvironmentError(f"Model name {model_name} was not found.")
+
+    if load_safe:
+        loader = partial(safe_load_file, device=mapped_device)
+    else:
+        loader = partial(torch.load, map_location=mapped_device)
+
     if is_sharded:
         # resolved_archive_file becomes a list of files that point to the different
         # checkpoint shards in this case.
@@ -27,9 +52,9 @@ def state_dict_from_pretrained(model_name, device=None, dtype=None):
         )
         state_dict = {}
         for sharded_file in resolved_archive_file:
-            state_dict.update(torch.load(sharded_file, map_location=mapped_device))
+            state_dict.update(loader(sharded_file))
     else:
-        state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
+        state_dict = loader(resolved_archive_file)
     # Convert dtype before moving to GPU to save memory
     if dtype is not None:
         state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
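With these changes state_dict_from_pretrained can read either pytorch_model*.bin or model*.safetensors checkpoints, sharded or not, and merges shards into a single dict. A minimal usage sketch; the local directory name is illustrative and mirrors the {model_name}-hf layout used by the tests below:

import torch

from flash_attn.utils.pretrained import state_dict_from_pretrained

# Directory containing config.json plus pytorch_model.bin(.index.json) or
# model.safetensors(.index.json); safetensors shards go through safetensors.torch.load_file,
# .bin shards through torch.load.
state_dict = state_dict_from_pretrained("checkpoints/llama/7B-hf", dtype=torch.float16)
print(len(state_dict), "tensors loaded")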
tests/models/test_llama.py
@@ -8,6 +8,7 @@
 import os
 import time
 from pathlib import Path

 current_dir = Path(__file__).parent.absolute()

 import torch
@@ -15,17 +16,28 @@ import pytest
 from einops import rearrange
-from transformers import LlamaConfig, LlamaTokenizer
+from transformers import LlamaTokenizer
 from transformers.models.llama.modeling_llama import LlamaForCausalLM

 from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
-from flash_attn.models.llama import remap_state_dict_meta_llama, llama_config_to_gpt2_config
+from flash_attn.models.llama import remap_state_dict_meta_llama, llama_config_to_gpt2_config, remap_state_dict_hf_llama
 from flash_attn.models.llama import config_from_checkpoint, state_dicts_from_checkpoint
 from flash_attn.utils.distributed import all_gather_raw
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.generation import update_graph_cache


+def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, checkpoint_format):
+    if checkpoint_format == "meta":
+        ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
+        pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
+        pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+    else:
+        pretrained_state_dict = state_dict_from_pretrained(Path(checkpoint_path) / f'{model_name}-hf')
+        pretrained_state_dict = remap_state_dict_hf_llama(pretrained_state_dict, config)
+    return pretrained_state_dict
+
+
 @pytest.mark.parametrize('model_name', ["7B"])
 def test_llama_state_dict(model_name):
     checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
@@ -41,8 +53,8 @@ def test_llama_state_dict(model_name):
 @pytest.mark.parametrize('model_name', ["7B", "13B"])
-# @pytest.mark.parametrize('model_name', ["7B"])
-def test_llama_optimized(model_name):
+@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
+def test_llama_optimized(model_name, checkpoint_format):
     """Check that our implementation of LLaMa (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
@@ -52,16 +64,17 @@ def test_llama_optimized(model_name):
     dtype = torch.float16
     device = 'cuda'
-    config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
+    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
+    config = llama_config_to_gpt2_config(config)
     config.use_flash_attn = True
     config.fused_bias_fc = True
     config.fused_mlp = False  # We don't have fused GatedMLP yet
     config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True

-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
-    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
+        checkpoint_path, model_name, config, checkpoint_format
+    )
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
@@ -111,7 +124,8 @@ def test_llama_optimized(model_name):
 # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel"
 @pytest.mark.parametrize('world_size', [2])
 @pytest.mark.parametrize('model_name', ["13B"])
-def test_llama_parallel(model_name, world_size):
+@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
+def test_llama_parallel(model_name, world_size, checkpoint_format):
     """Check that our implementation of LLaMa (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
@@ -122,7 +136,8 @@ def test_llama_parallel(model_name, world_size):
                                           current_dir.parent.parent / 'checkpoints')) / 'llama'
     dtype = torch.float16
-    config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
+    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
+    config = llama_config_to_gpt2_config(config)
     config.use_flash_attn = True
     config.fused_bias_fc = True
     config.fused_mlp = False  # We don't have fused GatedMLP yet
@@ -137,10 +152,9 @@ def test_llama_parallel(model_name, world_size):
     rank = parallel_state.get_tensor_model_parallel_rank()
     process_group = parallel_state.get_tensor_model_parallel_group()

-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
-    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
+        checkpoint_path, model_name, config, checkpoint_format
+    )
     model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
     model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
     model.eval()
@@ -196,13 +210,15 @@ def test_llama_parallel(model_name, world_size):
 # @pytest.mark.parametrize('model_name', ["7B", "13B"])
 @pytest.mark.parametrize('model_name', ["7B"])
-def test_llama_generation(model_name):
+@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
+def test_llama_generation(model_name, checkpoint_format):
     checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
                                           current_dir.parent.parent / 'checkpoints')) / 'llama'
     dtype = torch.float16
     device = 'cuda'
-    config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
+    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
+    config = llama_config_to_gpt2_config(config)
     config.use_flash_attn = True
     config.fused_bias_fc = True
     config.fused_mlp = False  # We don't have fused GatedMLP yet
@@ -239,9 +255,10 @@ def test_llama_generation(model_name):
     logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1].to(device=device)
     del model_ref

-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
-    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
+        checkpoint_path, model_name, config, checkpoint_format
+    )
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
@@ -291,7 +308,8 @@ def test_llama_generation(model_name):
 # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "llama_parallel_generation"
 @pytest.mark.parametrize('world_size', [2])
 @pytest.mark.parametrize('model_name', ["13B"])
-def test_llama_parallel_generation(model_name, world_size):
+@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
+def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
     """Check that our implementation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
     the HF scores in fp32.
@@ -302,7 +320,8 @@ def test_llama_parallel_generation(model_name, world_size):
                                           current_dir.parent.parent / 'checkpoints')) / 'llama'
     dtype = torch.float16
-    config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
+    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
+    config = llama_config_to_gpt2_config(config)
     config.use_flash_attn = False
     config.fused_bias_fc = True
     config.fused_mlp = False  # We don't have fused GatedMLP yet
@@ -331,10 +350,9 @@ def test_llama_parallel_generation(model_name, world_size):
     # GPU0 and GPU1 and things would hang
     torch.cuda.set_device(device)

-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
-    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
+        checkpoint_path, model_name, config, checkpoint_format
+    )
     model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
     model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
     model.eval()
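Each test now runs twice, once per checkpoint_format, so both loading paths exercise the same model and tolerances. A further sanity check one could add (hypothetical, not part of the commit) is to compare the two paths directly: with checkpoint_path and model_name set up as in the tests above, the remapped Meta shards and the remapped HF checkpoint should contain the same tensors.

# Hypothetical cross-check, reusing the helpers defined earlier in this test module.
config = llama_config_to_gpt2_config(
    config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
)
sd_meta = _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, "meta")
sd_hf = _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, "hf")
assert sd_meta.keys() == sd_hf.keys()
for key in sd_meta:
    assert torch.allclose(sd_meta[key].float(), sd_hf[key].float(), atol=1e-6), key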