flash-attention · Commit c9d4a816 (Unverified)
Authored Aug 30, 2023 by dan_the_3rd, committed by GitHub on Aug 30, 2023
Support LLaMa2 and CodeLLaMa (#491)
Co-authored-by: danthe3rd <danthe3rd>
Parent: 011ec323
Showing 1 changed file with 26 additions and 0 deletions.

flash_attn/models/llama.py (+26, -0)
@@ -10,6 +10,7 @@ from typing import Union
 import torch
 import torch.nn.functional as F
+from sentencepiece import SentencePieceProcessor
 from transformers import GPT2Config, LlamaConfig
@@ -308,7 +309,30 @@ def config_from_meta_checkpoint(
         num_attention_heads=params["n_heads"],
         num_hidden_layers=params["n_layers"],
         rms_norm_eps=params["norm_eps"],
+        num_key_value_heads=params.get("n_kv_heads", None),
     )
+    multiple_of = params.get("multiple_of", 1)
+    ffn_dim_multiplier = params.get("ffn_dim_multiplier", None)
+    # Compute the hidden dimension of the MLP
+    # https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L224
+    intermediate_size = 4 * config.hidden_size
+    # https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L195-L199
+    intermediate_size = int(2 * intermediate_size / 3)
+    # custom dim factor multiplier
+    if ffn_dim_multiplier is not None:
+        intermediate_size = int(ffn_dim_multiplier * intermediate_size)
+    intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
+    config.intermediate_size = intermediate_size
+    if "rope_theta" in params:
+        config.rotary_emb_base = params["rope_theta"]
+    config.vocab_size = 32000
+    # some CodeLLaMa have vocab_size 32000, some 32016
+    # Sadly it's not specified in the `params.json` file :(
+    tokenizer = Path(checkpoint_path) / model_name / "tokenizer.model"
+    if tokenizer.is_file():
+        config.vocab_size = SentencePieceProcessor(str(tokenizer)).vocab_size()
     return config
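
For reference, here is a minimal standalone sketch of the MLP sizing rule added in this hunk. The helper name `llama_intermediate_size` is made up for illustration; the 7B/70B parameter values and the expected outputs (11008 and 28672) are quoted from Meta's published LLaMa-2 configs, so verify them against your own `params.json`.

```python
# Illustrative only: the helper name is hypothetical, but the arithmetic mirrors
# the lines added above (and Meta's llama/model.py FeedForward sizing).
def llama_intermediate_size(hidden_size: int, multiple_of: int = 1, ffn_dim_multiplier=None) -> int:
    intermediate_size = 4 * hidden_size
    intermediate_size = int(2 * intermediate_size / 3)
    if ffn_dim_multiplier is not None:
        intermediate_size = int(ffn_dim_multiplier * intermediate_size)
    # Round up to the nearest multiple of `multiple_of`.
    return multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)

# LLaMa-2-7B params.json: dim=4096, multiple_of=256, no ffn_dim_multiplier.
print(llama_intermediate_size(4096, multiple_of=256))                           # 11008
# LLaMa-2-70B params.json: dim=8192, multiple_of=4096, ffn_dim_multiplier=1.3.
print(llama_intermediate_size(8192, multiple_of=4096, ffn_dim_multiplier=1.3))  # 28672
```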
@@ -364,4 +388,6 @@ def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
         out_proj_bias=False,
         mlp_fc1_bias=False,
         mlp_fc2_bias=False,
+        rotary_emb_base=getattr(llama_config, "rotary_emb_base", 10000.0),
+        n_head_kv=llama_config.num_key_value_heads,
     )
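
As a rough usage sketch (not part of the diff), the two functions touched by this commit would typically be chained as below when loading a Meta-format LLaMa-2 or CodeLLaMa checkpoint. The checkpoint directory and model name are placeholders, and the call assumes `config_from_meta_checkpoint` takes `(checkpoint_path, model_name)` as used in this file.

```python
from flash_attn.models.llama import config_from_meta_checkpoint, llama_config_to_gpt2_config

# Placeholder layout: <checkpoint_path>/<model_name>/params.json and, optionally,
# <checkpoint_path>/<model_name>/tokenizer.model as released by Meta.
checkpoint_path = "/path/to/llama-checkpoints"
model_name = "llama-2-7b"

# Build a transformers LlamaConfig from params.json; if tokenizer.model is present,
# the vocab size (32000 vs. 32016 for some CodeLLaMa variants) is read from it.
llama_config = config_from_meta_checkpoint(checkpoint_path, model_name)

# Convert to the GPT2Config flavor consumed by flash_attn's GPT model, carrying over
# rotary_emb_base (rope_theta) and n_head_kv (grouped-query attention).
gpt2_config = llama_config_to_gpt2_config(llama_config)
print(gpt2_config.n_head_kv, gpt2_config.rotary_emb_base)
```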