gaoqiong / flash-attention

Unverified commit 42832575, authored Sep 19, 2023 by Kevin Hu, committed by GitHub on Sep 19, 2023
Fix Llama GQA/MQA (#546)

* Fix llama MQA
* Fix permute shape
* Update llama.py

parent dfe29f5e
Showing 2 changed files with 121 additions and 21 deletions:

* flash_attn/models/llama.py (+55, -21)
* tests/models/test_llama.py (+66, -0)
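For context (a sketch, not taken from the diff): in a GQA/MQA Llama checkpoint the K and V projections have fewer output rows than Q, so the conversion code can no longer assume every attention weight is a square (n_embd, n_embd) matrix. A minimal shape sketch with hypothetical head counts, chosen only for illustration:

    import torch

    # Hypothetical GQA sizes: 8 query heads sharing 2 KV heads.
    n_embd, n_head, n_head_kv = 128, 8, 2
    head_dim = n_embd // n_head  # 16

    Wq = torch.randn(n_head * head_dim, n_embd)     # (128, 128): one slice per query head
    Wk = torch.randn(n_head_kv * head_dim, n_embd)  # (32, 128): only n_head_kv slices
    Wv = torch.randn(n_head_kv * head_dim, n_embd)  # (32, 128)

    # The packed Wqkv used by the GPT-style mixer stacks Q rows, then K, then V.
    Wqkv = torch.cat([Wq, Wk, Wv], dim=0)
    assert Wqkv.shape == ((n_head + 2 * n_head_kv) * head_dim, n_embd)  # (192, 128)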
flash_attn/models/llama.py
@@ -26,10 +26,13 @@ def remap_state_dict_meta_llama(
         return f"transformer.{key}" if not key.startswith("output.") else key

     state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

     # Word embedding
     def key_mapping_emb(key):
         return re.sub(
-            r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key
+            r"^transformer.tok_embeddings.",
+            "transformer.embeddings.word_embeddings.",
+            key,
         )

     state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
@@ -61,7 +64,9 @@ def remap_state_dict_meta_llama(
     def key_mapping_ln(key):
         key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key)
         key = re.sub(
-            r"^transformer.layers.(\d+).attention_norm.", r"transformer.layers.\1.norm1.", key
+            r"^transformer.layers.(\d+).attention_norm.",
+            r"transformer.layers.\1.norm1.",
+            key,
         )
         key = re.sub(r"^transformer.layers.(\d+).ffn_norm.", r"transformer.layers.\1.norm2.", key)
         return key
@@ -77,7 +82,9 @@ def remap_state_dict_meta_llama(
     def key_mapping_mlp(key):
         return re.sub(
-            r"^transformer.layers.(\d+).feed_forward.w2.", r"transformer.layers.\1.mlp.fc2.", key
+            r"^transformer.layers.(\d+).feed_forward.w2.",
+            r"transformer.layers.\1.mlp.fc2.",
+            key,
         )

     state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
@@ -106,12 +113,13 @@ def remap_state_dict_meta_llama(
 def remap_state_dict_hf_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
+    state_dict: dict[str, torch.Tensor], config: GPT2Config, multi_query: bool = False
 ) -> dict[str, torch.Tensor]:
     """Convert the state_dict in Hugging Face format to standard GPT format.

     This function modifies state_dict in place.
     """

     # Embedding
     def key_mapping_emb(key):
         return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key)
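As a quick illustration of the key-mapping pattern used throughout this file (a sketch, not taken from the diff), the embedding rule above rewrites HF parameter names like so:

    import re

    key = "model.embed_tokens.weight"
    new_key = re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key)
    assert new_key == "transformer.embeddings.word_embeddings.weight"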
@@ -153,28 +161,38 @@ def remap_state_dict_hf_llama(
         state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)

     def key_mapping_mlp(key):
-        return re.sub(r"^model.layers.(\d+).mlp.down_proj.", r"transformer.layers.\1.mlp.fc2.", key)
+        return re.sub(
+            r"^model.layers.(\d+).mlp.down_proj.",
+            r"transformer.layers.\1.mlp.fc2.",
+            key,
+        )

     state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

     # LayerNorm
     def key_mapping_ln(key):
         key = re.sub(r"^model.norm.", r"transformer.ln_f.", key)
-        key = re.sub(r"^model.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key)
-        key = re.sub(
-            r"^model.layers.(\d+).post_attention_layernorm.", r"transformer.layers.\1.norm2.", key
-        )
+        key = re.sub(
+            r"^model.layers.(\d+).input_layernorm.",
+            r"transformer.layers.\1.norm1.",
+            key,
+        )
+        key = re.sub(
+            r"^model.layers.(\d+).post_attention_layernorm.",
+            r"transformer.layers.\1.norm2.",
+            key,
+        )
         return key

     state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

-    def inv_permute(w):
+    def inv_permute(w, first_dim=None):
         # Inverse of permute implemented in:
         # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
         return (
-            w.reshape(config.n_head, 2, config.n_embd // config.n_head // 2, config.n_embd)
+            w.reshape(first_dim or config.n_head, 2, -1, config.n_embd)
             .transpose(1, 2)
-            .reshape(config.n_embd, config.n_embd)
+            .reshape(-1, config.n_embd)
         )

     # Attention
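Why first_dim matters (a sketch reusing the hypothetical GQA sizes from above, not part of the diff): the old inv_permute hard-coded config.n_head and a square output, which cannot even reshape a GQA key projection; passing the KV head count keeps the reshape consistent.

    import torch

    n_embd, n_head, n_head_kv = 128, 8, 2  # hypothetical GQA sizes
    head_dim = n_embd // n_head

    def inv_permute(w, first_dim=None):
        # Mirrors the new helper: undo the per-head rotary weight permutation.
        return w.reshape(first_dim or n_head, 2, -1, n_embd).transpose(1, 2).reshape(-1, n_embd)

    Wk = torch.randn(n_head_kv * head_dim, n_embd)  # (32, 128) GQA key projection

    # Old behavior: reshape(n_head, 2, head_dim // 2, n_embd) expects 128 * 128 elements,
    # but Wk only has 32 * 128, so the reshape raises a RuntimeError.
    try:
        Wk.reshape(n_head, 2, head_dim // 2, n_embd)
    except RuntimeError as exc:
        print("old reshape fails:", exc)

    # New behavior: pass the KV head count and the shapes stay consistent.
    assert inv_permute(Wk, n_head_kv).shape == Wk.shape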
@@ -182,15 +200,19 @@ def remap_state_dict_hf_llama(
         Wq = state_dict.pop(f"model.layers.{l}.self_attn.q_proj.weight")
         Wk = state_dict.pop(f"model.layers.{l}.self_attn.k_proj.weight")
         Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight")
         state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat(
-            [inv_permute(Wq), inv_permute(Wk), Wv], dim=0
+            (inv_permute(Wq), inv_permute(Wk, getattr(config, "n_head_kv")), Wv),
+            dim=0,
         )
         # We don't store these
         state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None)

     def key_mapping_attn(key):
         return re.sub(
-            r"^model.layers.(\d+).self_attn.o_proj.", r"transformer.layers.\1.mixer.out_proj.", key
+            r"^model.layers.(\d+).self_attn.o_proj.",
+            r"transformer.layers.\1.mixer.out_proj.",
+            key,
         )

     state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
@@ -198,7 +220,7 @@ def remap_state_dict_hf_llama(
 def inv_remap_state_dict_hf_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
+    state_dict: dict[str, torch.Tensor], config: GPT2Config, multi_query: bool = False
 ) -> dict[str, torch.Tensor]:
     """Convert the state_dict in standard GPT format to Hugging Face format.
@@ -246,26 +268,36 @@ def inv_remap_state_dict_hf_llama(
         state_dict[f"model.layers.{l}.mlp.up_proj.weight"] = w3

     def key_mapping_mlp(key):
-        return re.sub(r"^transformer.layers.(\d+).mlp.fc2.", r"model.layers.\1.mlp.down_proj.", key)
+        return re.sub(
+            r"^transformer.layers.(\d+).mlp.fc2.",
+            r"model.layers.\1.mlp.down_proj.",
+            key,
+        )

     state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

     # LayerNorm
     def key_mapping_ln(key):
         key = re.sub(r"^transformer.ln_f.", r"model.norm.", key)
-        key = re.sub(r"^transformer.layers.(\d+).norm1.", r"model.layers.\1.input_layernorm.", key)
-        key = re.sub(
-            r"^transformer.layers.(\d+).norm2.", r"model.layers.\1.post_attention_layernorm.", key
-        )
+        key = re.sub(
+            r"^transformer.layers.(\d+).norm1.",
+            r"model.layers.\1.input_layernorm.",
+            key,
+        )
+        key = re.sub(
+            r"^transformer.layers.(\d+).norm2.",
+            r"model.layers.\1.post_attention_layernorm.",
+            key,
+        )
         return key

     state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

-    def permute(w):
+    def permute(w, first_dim=None):
         return (
-            w.view(config.n_head, config.n_embd // config.n_head // 2, 2, config.n_embd)
+            w.view(first_dim or config.n_head, -1, 2, config.n_embd)
             .transpose(1, 2)
-            .reshape(config.n_embd, config.n_embd)
+            .reshape(-1, config.n_embd)
         )

     n_head = config.n_head
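A round-trip sanity check (a sketch with the same hypothetical sizes, not part of the commit): this permute should invert the inv_permute from remap_state_dict_hf_llama, both for a full Q projection and for a smaller GQA K projection when the KV head count is passed as first_dim.

    import torch

    n_embd, n_head, n_head_kv = 128, 8, 2  # hypothetical sizes
    head_dim = n_embd // n_head

    def inv_permute(w, first_dim=None):
        return w.reshape(first_dim or n_head, 2, -1, n_embd).transpose(1, 2).reshape(-1, n_embd)

    def permute(w, first_dim=None):
        return w.view(first_dim or n_head, -1, 2, n_embd).transpose(1, 2).reshape(-1, n_embd)

    Wq = torch.randn(n_embd, n_embd)
    Wk = torch.randn(n_head_kv * head_dim, n_embd)

    # Both helpers only reindex elements, so the round trip is exact.
    assert torch.equal(permute(inv_permute(Wq)), Wq)
    assert torch.equal(permute(inv_permute(Wk, n_head_kv), n_head_kv), Wk)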
@@ -284,13 +316,15 @@ def inv_remap_state_dict_hf_llama(
         Wk = Wqkv[q_dim : q_dim + k_dim]
         Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim]
         state_dict[f"model.layers.{l}.self_attn.q_proj.weight"] = permute(Wq)
-        state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk)
+        state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk, n_head_kv)
         state_dict[f"model.layers.{l}.self_attn.v_proj.weight"] = Wv
         state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)

     def key_mapping_attn(key):
         return re.sub(
-            r"^transformer.layers.(\d+).mixer.out_proj.", r"model.layers.\1.self_attn.o_proj.", key
+            r"^transformer.layers.(\d+).mixer.out_proj.",
+            r"model.layers.\1.self_attn.o_proj.",
+            key,
         )

     state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
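For the slicing above (a sketch; q_dim, k_dim and v_dim are computed outside this hunk, and the values below are an assumption consistent with the slices shown): with GQA the packed Wqkv splits into n_head * head_dim query rows followed by two blocks of n_head_kv * head_dim rows.

    import torch

    n_embd, n_head, n_head_kv = 128, 8, 2  # hypothetical sizes
    head_dim = n_embd // n_head
    q_dim, k_dim, v_dim = n_head * head_dim, n_head_kv * head_dim, n_head_kv * head_dim

    Wqkv = torch.randn(q_dim + k_dim + v_dim, n_embd)  # (192, 128) packed mixer weight
    Wq = Wqkv[:q_dim]                                  # (128, 128)
    Wk = Wqkv[q_dim : q_dim + k_dim]                   # (32, 128)
    Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim]   # (32, 128)
    assert Wk.shape[0] == Wv.shape[0] == n_head_kv * head_dim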
tests/models/test_llama.py
@@ -135,6 +135,72 @@ def test_llama_optimized(model_name, checkpoint_format):
     ).abs().max().item()


+@pytest.mark.parametrize("model_name", ["PY007/TinyLlama-1.1B-step-50K-105b"])
+def test_mqa_optimized(model_name):
+    """Check that our implementation of Llama with MQA/GQA (with all optimizations enabled) matches the
+    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
+    forward pass in fp16, when compared to the HF forward pass in fp32.
+    """
+    dtype = torch.float16
+    device = "cuda"
+    config = llama_config_to_gpt2_config(LlamaConfig.from_pretrained(model_name))
+    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
+    config.fused_bias_fc = True
+    config.fused_mlp = False
+    config.fused_dropout_add_ln = True
+    config.residual_in_fp32 = True
+
+    # Without device_map, the model is loaded on the CPU, which is very slow
+    model_ref = LlamaForCausalLM.from_pretrained(model_name, device_map={"": device})
+    model_ref.eval()
+
+    model = GPTLMHeadModel(config, device=device, dtype=dtype)
+    model.load_state_dict(remap_state_dict_hf_llama(model_ref.state_dict(), config))
+    model.eval()
+
+    torch.manual_seed(0)
+    batch_size = 2
+    max_seqlen = 256
+    input_ids = torch.randint(
+        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
+    )
+    with torch.no_grad():
+        out = model.transformer(input_ids)
+        logits = model(input_ids).logits
+    del model
+
+    with torch.no_grad():
+        out_ref = model_ref.model(input_ids).last_hidden_state
+        logits_ref = model_ref(input_ids).logits
+    del model_ref
+
+    model_hf = LlamaForCausalLM.from_pretrained(
+        model_name, torch_dtype=dtype, device_map={"": device}
+    )
+    model_hf.eval()
+    out_hf = model_hf.model(input_ids).last_hidden_state
+    logits_hf = model_hf(input_ids).logits
+    del model_hf
+
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
+    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
+    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
+
+    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
+    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
+    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
+    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
+    assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
+
+
 # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel"
 @pytest.mark.parametrize("world_size", [2])
 @pytest.mark.parametrize("model_name", ["13B"])
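A small companion check (a sketch, not in the commit; it assumes the checkpoint exposes num_key_value_heads in its config, as current transformers LlamaConfig does, and that this TinyLlama checkpoint is in fact a GQA model) to confirm the chosen model exercises the GQA path:

    from transformers import LlamaConfig

    cfg = LlamaConfig.from_pretrained("PY007/TinyLlama-1.1B-step-50K-105b")
    # GQA/MQA means fewer KV heads than query heads; otherwise this test would
    # not cover the code paths fixed in this commit.
    assert cfg.num_key_value_heads < cfg.num_attention_heads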