"examples/textual_inversion/textual_inversion_sdxl.py" did not exist on "029fb41695a7940c213d914471fb41a1df67aa17"
Unverified Commit c9d4a816 authored by dan_the_3rd's avatar dan_the_3rd Committed by GitHub
Browse files

Support LLaMa2 and CodeLLaMa (#491)

Co-authored-by: danthe3rd <danthe3rd>
parent 011ec323
......@@ -10,6 +10,7 @@ from typing import Union
import torch
import torch.nn.functional as F
from sentencepiece import SentencePieceProcessor
from transformers import GPT2Config, LlamaConfig
......@@ -308,7 +309,30 @@ def config_from_meta_checkpoint(
num_attention_heads=params["n_heads"],
num_hidden_layers=params["n_layers"],
rms_norm_eps=params["norm_eps"],
num_key_value_heads=params.get("n_kv_heads", None),
)
multiple_of = params.get("multiple_of", 1)
ffn_dim_multiplier = params.get("ffn_dim_multiplier", None)
# Compute the hidden dimension of the MLP
# https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L224
intermediate_size = 4 * config.hidden_size
# https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L195-L199
intermediate_size = int(2 * intermediate_size / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
intermediate_size = int(ffn_dim_multiplier * intermediate_size)
intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
config.intermediate_size = intermediate_size
if "rope_theta" in params:
config.rotary_emb_base = params["rope_theta"]
config.vocab_size = 32000
# some CodeLLaMa have vocab_size 32000, some 32016
# Sadly it's not specified in the `params.json` file :(
tokenizer = Path(checkpoint_path) / model_name / "tokenizer.model"
if tokenizer.is_file():
config.vocab_size = SentencePieceProcessor(str(tokenizer)).vocab_size()
return config
......@@ -364,4 +388,6 @@ def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
out_proj_bias=False,
mlp_fc1_bias=False,
mlp_fc2_bias=False,
rotary_emb_base=getattr(llama_config, "rotary_emb_base", 10000.0),
n_head_kv=llama_config.num_key_value_heads,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment