Unverified Commit e0b09891 authored by Aman Gupta Karmani's avatar Aman Gupta Karmani Committed by GitHub
Browse files

add llama support to GPTPreTrainedModel.from_pretrained (#479)

parent 6711b3bc
......@@ -16,6 +16,7 @@ from transformers import GPT2Config
from flash_attn.models.falcon import remap_state_dict_hf_falcon
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
from flash_attn.models.gptj import remap_state_dict_hf_gptj
from flash_attn.models.llama import remap_state_dict_hf_llama
from flash_attn.models.opt import remap_state_dict_hf_opt
from flash_attn.modules.block import Block, ParallelBlock
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
......@@ -349,6 +350,8 @@ class GPTPreTrainedModel(nn.Module):
state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
elif model_name.startswith("tiiuae/falcon-"):
state_dict = remap_state_dict_hf_falcon(state_dict, config)
elif model_name.startswith("meta-llama/Llama-"):
state_dict = remap_state_dict_hf_llama(state_dict, config)
else:
raise NotImplementedError(f"Model {model_name} not supported")
if world_size > 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment