gaoqiong / flash-attention · Commits

Commit 0705d271
Authored Sep 20, 2023 by Tri Dao
[Llama] Fix some tests, add tests for Llama 2 and CodeLlama
Parent: e0fbaa70
Showing 4 changed files with 124 additions and 161 deletions
flash_attn/models/llama.py                      +13  -18
flash_attn/modules/mha.py                        +1   -1
tests/models/test_gpt_generation_parallel.py     +1   -0
tests/models/test_llama.py                     +109 -142
flash_attn/models/llama.py
@@ -13,6 +13,8 @@ import torch.nn.functional as F
 from sentencepiece import SentencePieceProcessor
 from transformers import GPT2Config, LlamaConfig
+from einops import rearrange
+
 
 
 def remap_state_dict_meta_llama(
     state_dict: dict[str, torch.Tensor], config: GPT2Config
@@ -30,9 +32,7 @@ def remap_state_dict_meta_llama(
     # Word embedding
     def key_mapping_emb(key):
-        return re.sub(
-            r"^transformer.tok_embeddings.",
-            "transformer.embeddings.word_embeddings.",
-            key,
-        )
+        return re.sub(
+            r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key
+        )
 
     state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
@@ -113,7 +113,7 @@ def remap_state_dict_meta_llama(
 
 
 def remap_state_dict_hf_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config, multi_query: bool = False
+    state_dict: dict[str, torch.Tensor], config: GPT2Config
 ) -> dict[str, torch.Tensor]:
     """Convert the state_dict in Hugging Face format to standard GPT format.
@@ -186,13 +186,11 @@ def remap_state_dict_hf_llama(
     state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
 
-    def inv_permute(w, first_dim=None):
+    def inv_permute(w):
         # Inverse of permute implemented in:
         # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
-        return (
-            w.reshape(first_dim or config.n_head, 2, -1, config.n_embd)
-            .transpose(1, 2)
-            .reshape(-1, config.n_embd)
-        )
+        return rearrange(
+            w, "(h two d) n -> (h d two) n", d=config.n_embd // config.n_head // 2, two=2
+        )
 
     # Attention
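The rearrange form is intended as a drop-in replacement for the old reshape/transpose implementation, with one practical difference: einops infers the number of heads from the weight's row count, so the same helper now handles GQA/MQA key projections without a first_dim argument. A minimal sketch of that equivalence, using Llama-7B-like sizes (32 heads, hidden size 4096) and an 8-KV-head key projection purely as an illustration; inv_permute_old and inv_permute_new are names invented here, not in the repo:

import torch
from einops import rearrange

n_head, n_embd = 32, 4096      # Llama-7B-like sizes (illustrative only)
head_dim = n_embd // n_head    # 128


def inv_permute_old(w, first_dim=None):
    # The previous implementation: explicit reshape/transpose, needs the head count.
    return w.reshape(first_dim or n_head, 2, -1, n_embd).transpose(1, 2).reshape(-1, n_embd)


def inv_permute_new(w):
    # The new implementation: the head count h is inferred from w's shape.
    return rearrange(w, "(h two d) n -> (h d two) n", d=head_dim // 2, two=2)


# Query projection: n_head heads.
Wq = torch.randn(n_head * head_dim, n_embd)
assert torch.equal(inv_permute_old(Wq), inv_permute_new(Wq))

# GQA-style key projection with 8 KV heads: no extra argument needed anymore.
n_head_kv = 8
Wk = torch.randn(n_head_kv * head_dim, n_embd)
assert torch.equal(inv_permute_old(Wk, first_dim=n_head_kv), inv_permute_new(Wk))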
@@ -202,8 +200,7 @@ def remap_state_dict_hf_llama(
         Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight")
         state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat(
-            (inv_permute(Wq), inv_permute(Wk, getattr(config, "n_head_kv")), Wv),
-            dim=0,
+            [inv_permute(Wq), inv_permute(Wk), Wv], dim=0
         )
         # We don't store these
         state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None)
@@ -220,7 +217,7 @@ def remap_state_dict_hf_llama(
 
 
 def inv_remap_state_dict_hf_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config, multi_query: bool = False
+    state_dict: dict[str, torch.Tensor], config: GPT2Config
 ) -> dict[str, torch.Tensor]:
     """Convert the state_dict in standard GPT format to Hugging Face format.
@@ -293,11 +290,9 @@ def inv_remap_state_dict_hf_llama(
     state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
 
-    def permute(w, first_dim=None):
-        return (
-            w.view(first_dim or config.n_head, -1, 2, config.n_embd)
-            .transpose(1, 2)
-            .reshape(-1, config.n_embd)
-        )
+    def permute(w):
+        return rearrange(
+            w, "(h d two) n -> (h two d) n", d=config.n_embd // config.n_head // 2, two=2
+        )
 
     n_head = config.n_head
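permute mirrors inv_permute in the opposite direction (GPT layout back to the Hugging Face layout), so the two are exact inverses; that property is what the new round-trip test further down in tests/models/test_llama.py checks on real checkpoints. A tiny sketch of the inverse property, with the config closure replaced by a hard-coded head_dim for illustration:

import torch
from einops import rearrange

head_dim = 128  # n_embd // n_head in the real code; hard-coded here for illustration


def inv_permute(w):
    # Hugging Face layout -> GPT layout (as in remap_state_dict_hf_llama)
    return rearrange(w, "(h two d) n -> (h d two) n", d=head_dim // 2, two=2)


def permute(w):
    # GPT layout -> Hugging Face layout (as in inv_remap_state_dict_hf_llama)
    return rearrange(w, "(h d two) n -> (h two d) n", d=head_dim // 2, two=2)


W = torch.randn(8 * head_dim, 4096)  # e.g. a key projection with 8 KV heads
assert torch.equal(permute(inv_permute(W)), W)
assert torch.equal(inv_permute(permute(W)), W)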
@@ -316,7 +311,7 @@ def inv_remap_state_dict_hf_llama(
         Wk = Wqkv[q_dim : q_dim + k_dim]
         Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim]
         state_dict[f"model.layers.{l}.self_attn.q_proj.weight"] = permute(Wq)
-        state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk, n_head_kv)
+        state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk)
         state_dict[f"model.layers.{l}.self_attn.v_proj.weight"] = Wv
         state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)
flash_attn/modules/mha.py
@@ -725,7 +725,7 @@ class ParallelMHA(nn.Module):
             process_group,
             bias=qkv_proj_bias,
             sequence_parallel=sequence_parallel,
-            multiple_of=self.head_dim * (self.num_heads_per_rank + 2 * self.num_heads_kv_per_rank),
+            multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
             **factory_kwargs,
         )
         inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
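The fix to multiple_of changes the granularity at which the fused Wqkv projection is split across tensor-parallel ranks. Reading the new expression, the shard unit is one GQA group: the query heads that share a single KV head, plus that K and V head, computed from the global head counts rather than the per-rank ones that the old expression used (the per-rank form breaks down when heads do not divide evenly across ranks, the case exercised by test_llama_parallel_uneven_num_heads). A worked example of the arithmetic, using Llama-2-70B-style head counts purely as an illustration:

# Illustrative numbers only (Llama-2-70B-style GQA): 64 query heads sharing 8 KV heads.
head_dim = 128
num_heads = 64
num_heads_kv = 8

# One GQA group = the q heads attached to one KV head, plus that K head and V head.
group = head_dim * (num_heads // num_heads_kv + 2)   # 128 * (8 + 2) = 1280

# Total output dim of the fused QKV projection.
qkv_dim = head_dim * (num_heads + 2 * num_heads_kv)  # 128 * (64 + 16) = 10240

# The fused projection is a whole number of groups, so a split that respects
# multiple_of=group can hand every tensor-parallel rank complete groups.
assert qkv_dim % group == 0
print(qkv_dim // group)  # 8 groups, e.g. one per rank at world_size=8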
tests/models/test_gpt_generation_parallel.py
@@ -160,6 +160,7 @@ def test_tensor_parallel(model_name, rotary, world_size):
     assert torch.allclose(
         torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
     )
+    assert torch.equal(torch.stack(out.scores, dim=1), torch.stack(out_cg.scores, dim=1))
 
     if not rotary:
         assert torch.all(out.sequences == out_ref.sequences)
         assert torch.all(out.sequences == out_hf.sequences)
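The added assertion compares the scores in out_cg (judging by the name, the CUDA-graph decoding path) against the regular decoding output with torch.equal rather than torch.allclose: the graph-captured run is expected to reproduce the eager run element for element, not merely within tolerance. A tiny reminder of the difference between the two checks:

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = a + 1e-6  # a tiny numerical perturbation

assert torch.allclose(a, b, rtol=1e-3, atol=1e-3)  # passes: within tolerance
assert not torch.equal(a, b)                       # fails: values are not identical
assert torch.equal(a, a.clone())                   # equal requires same shape and same values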
tests/models/test_llama.py
 # Copyright (c) 2023, Tri Dao.
-# To run the huggingface implementation, we first need to convert the weights:
+# To run the huggingface implementation of LLaMa (1), we first need to convert the weights:
 # https://github.com/huggingface/transformers/pull/21955
 # python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR/llama/7B-hf
 # and repeat for 13B, 30B, 65B
@@ -30,6 +30,7 @@ from flash_attn.utils.generation import update_graph_cache
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from transformers import LlamaConfig, LlamaTokenizer
 from transformers.models.llama.modeling_llama import LlamaForCausalLM
+from transformers import AutoConfig
 
 
 def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, checkpoint_format):
@@ -60,9 +61,38 @@ def test_llama_state_dict(model_name):
         assert state_dict[k].shape == pretrained_state_dict[k].shape
 
 
-@pytest.mark.parametrize("model_name", ["7B", "13B"])
-@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
-def test_llama_optimized(model_name, checkpoint_format):
+# TinyLlama-1.1B is to test MQA
+@pytest.mark.parametrize(
+    "model_name", ["meta-llama/Llama-2-7b-hf", "PY007/TinyLlama-1.1B-step-50K-105b"]
+)
+def test_inv_remap_state_dict_hf_llama(model_name):
+    config = llama_config_to_gpt2_config(
+        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+    )
+    state_dict = state_dict_from_pretrained(model_name)
+    # inv_remap_state_dict_hf_llama should be the inverse of remap_state_dict_hf_llama
+    state_dict = {key: val for key, val in state_dict.items() if "rotary_emb.inv_freq" not in key}
+    pretrained_state_dict = remap_state_dict_hf_llama(state_dict, config)
+    state_dict_recover = inv_remap_state_dict_hf_llama(pretrained_state_dict, config)
+    assert set(state_dict_recover.keys()) == set(state_dict.keys())
+    for key in state_dict_recover.keys():
+        torch.testing.assert_close(state_dict_recover[key], state_dict[key])
+
+
+# TinyLlama-1.1B is to test MQA
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "7B",  # Llama 1
+        "13B",  # Llama 1
+        "meta-llama/Llama-2-13b-hf",
+        "codellama/CodeLlama-7b-hf",
+        "codellama/CodeLlama-13b-hf",
+        "codellama/CodeLlama-34b-hf",
+        "PY007/TinyLlama-1.1B-step-50K-105b",
+    ],
+)
+def test_llama_optimized(model_name):
     """Check that our implementation of LLaMa (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
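The new test_inv_remap_state_dict_hf_llama pulls real checkpoints (Llama-2-7B and the small TinyLlama MQA model) from the Hugging Face Hub and checks that inv_remap_state_dict_hf_llama exactly undoes remap_state_dict_hf_llama; it replaces the older variant at the bottom of this file that built a throwaway "teeny" checkpoint locally. Possible invocations, in the same spirit as the command comments already in this file (they assume Hub access, and meta-llama/Llama-2-7b-hf is a gated repo, so an authorized token may be required):

# Run both parametrizations of the round-trip test:
#   pytest -q -s tests/models/test_llama.py -k "inv_remap_state_dict_hf_llama"
#
# Run only the public (non-gated) TinyLlama case:
#   pytest -q -s "tests/models/test_llama.py::test_inv_remap_state_dict_hf_llama[PY007/TinyLlama-1.1B-step-50K-105b]"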
@@ -73,7 +103,12 @@ def test_llama_optimized(model_name, checkpoint_format):
     dtype = torch.float16
     device = "cuda"
-    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
-    config = llama_config_to_gpt2_config(config)
+    if "/" in model_name:  # Download from HF
+        config = llama_config_to_gpt2_config(
+            AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+        )
+    else:
+        config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
+        config = llama_config_to_gpt2_config(config)
     config.use_flash_attn = True
     config.fused_bias_fc = True
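This hunk introduces the convention used throughout the rewritten tests: a model name containing "/" is treated as a Hugging Face Hub ID whose config (and later its weights) is downloaded, while a bare name such as "7B" or "13B" still refers to a local Meta-format checkpoint under $CHECKPOINT_DIR/llama. A condensed sketch of the Hub branch, assuming flash_attn is installed and the Hub is reachable (the local-checkpoint branch, which relies on the test module's config_from_checkpoint helper, is omitted here):

from transformers import AutoConfig
from flash_attn.models.llama import llama_config_to_gpt2_config

model_name = "PY007/TinyLlama-1.1B-step-50K-105b"  # small, ungated Hub model
if "/" in model_name:  # Hub ID, e.g. "meta-llama/Llama-2-13b-hf"
    config = llama_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    # The converted GPT2Config carries the usual fields (n_head, n_embd, ...).
    print(config.n_head, config.n_embd)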
@@ -81,8 +116,13 @@ def test_llama_optimized(model_name, checkpoint_format):
     config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True
 
-    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
-        checkpoint_path, model_name, config, checkpoint_format
-    )
+    if "/" in model_name:  # Download from HF
+        pretrained_state_dict = remap_state_dict_hf_llama(
+            state_dict_from_pretrained(model_name), config
+        )
+    else:
+        pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
+            checkpoint_path, model_name, config, checkpoint_format="meta"
+        )
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
@@ -103,7 +143,8 @@ def test_llama_optimized(model_name, checkpoint_format):
     # Without device_map, the model is loaded on the CPU, which is very slow
     # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
     model_ref = LlamaForCausalLM.from_pretrained(
-        Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
+        model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
+        device_map="auto",
     )
     model_ref.eval()
     with torch.no_grad():
@@ -112,7 +153,9 @@ def test_llama_optimized(model_name, checkpoint_format):
     del model_ref
 
     model_hf = LlamaForCausalLM.from_pretrained(
-        Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device}
+        model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
+        torch_dtype=dtype,
+        device_map={"": device},
     )
     model_hf.eval()
     with torch.no_grad():
@@ -135,77 +178,12 @@ def test_llama_optimized(model_name, checkpoint_format):
     ).abs().max().item()
 
 
-@pytest.mark.parametrize("model_name", ["PY007/TinyLlama-1.1B-step-50K-105b"])
-def test_mqa_optimized(model_name):
-    """Check that our implementation of Llama with MQA/GQA (with all optimizations enabled) matches the
-    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
-    forward pass in fp16, when compared to the HF forward pass in fp32.
-    """
-    dtype = torch.float16
-    device = "cuda"
-    config = llama_config_to_gpt2_config(LlamaConfig.from_pretrained(model_name))
-    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
-    config.fused_bias_fc = True
-    config.fused_mlp = False
-    config.fused_dropout_add_ln = True
-    config.residual_in_fp32 = True
-
-    # Without device_map, the model is loaded on the CPU, which is very slow
-    model_ref = LlamaForCausalLM.from_pretrained(model_name, device_map={"": device})
-    model_ref.eval()
-
-    model = GPTLMHeadModel(config, device=device, dtype=dtype)
-    model.load_state_dict(remap_state_dict_hf_llama(model_ref.state_dict(), config))
-    model.eval()
-
-    torch.manual_seed(0)
-    batch_size = 2
-    max_seqlen = 256
-    input_ids = torch.randint(
-        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
-    )
-    with torch.no_grad():
-        out = model.transformer(input_ids)
-        logits = model(input_ids).logits
-    del model
-
-    with torch.no_grad():
-        out_ref = model_ref.model(input_ids).last_hidden_state
-        logits_ref = model_ref(input_ids).logits
-    del model_ref
-
-    model_hf = LlamaForCausalLM.from_pretrained(
-        model_name, torch_dtype=dtype, device_map={"": device}
-    )
-    model_hf.eval()
-    out_hf = model_hf.model(input_ids).last_hidden_state
-    logits_hf = model_hf(input_ids).logits
-    del model_hf
-
-    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
-    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
-    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
-    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
-    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
-
-    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
-    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
-    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
-    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
-    assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
-
-
 # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel"
 @pytest.mark.parametrize("world_size", [2])
-@pytest.mark.parametrize("model_name", ["13B"])
-@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
-def test_llama_parallel(model_name, world_size, checkpoint_format):
+@pytest.mark.parametrize(
+    "model_name", ["13B", "meta-llama/Llama-2-13b-hf", "codellama/CodeLlama-34b-hf"]
+)
+def test_llama_parallel(model_name, world_size):
     """Check that our implementation of LLaMa (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
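test_mqa_optimized disappears because its single case, PY007/TinyLlama-1.1B (MQA/GQA), is now just another parametrization of test_llama_optimized, and test_llama_parallel gains Llama-2-13B and CodeLlama-34B checkpoints loaded straight from the Hub. The accuracy criterion is unchanged everywhere: our fp16 output may deviate from the HF fp32 reference by at most a small multiple of the HF fp16 implementation's own deviation. A compact sketch of that criterion, with a helper name invented here for illustration:

import torch


def assert_within_hf_fp16_error(out, out_ref, out_hf, mult=3.0):
    # Illustrative helper (not in the repo): `out` (our fp16 result) must stay within
    # `mult` times the error that the HF fp16 result `out_hf` makes against the fp32
    # reference `out_ref`.
    ours = (out - out_ref).abs().max().item()
    theirs = (out_hf - out_ref).abs().max().item()
    assert ours < mult * theirs, (ours, theirs)


# Toy tensors standing in for model outputs:
out_ref = torch.zeros(4)
assert_within_hf_fp16_error(out_ref + 1e-4, out_ref, out_ref + 1e-3)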
@@ -217,7 +195,12 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
     )
     dtype = torch.float16
-    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
-    config = llama_config_to_gpt2_config(config)
+    if "/" in model_name:  # Download from HF
+        config = llama_config_to_gpt2_config(
+            AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+        )
+    else:
+        config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
+        config = llama_config_to_gpt2_config(config)
     config.use_flash_attn = True
     config.fused_bias_fc = True
@@ -233,8 +216,13 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
     rank = parallel_state.get_tensor_model_parallel_rank()
     process_group = parallel_state.get_tensor_model_parallel_group()
 
-    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
-        checkpoint_path, model_name, config, checkpoint_format
-    )
+    if "/" in model_name:  # Download from HF
+        pretrained_state_dict = remap_state_dict_hf_llama(
+            state_dict_from_pretrained(model_name), config
+        )
+    else:
+        pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
+            checkpoint_path, model_name, config, checkpoint_format="meta"
+        )
     model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
     model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
@@ -260,7 +248,8 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
     if rank == 0:
         # Without device_map, the model is loaded on the CPU, which is very slow
         model_ref = LlamaForCausalLM.from_pretrained(
-            Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
+            model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
+            device_map="auto",
         )
         model_ref.eval()
         with torch.no_grad():
@@ -269,7 +258,9 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
         del model_ref
 
         model_hf = LlamaForCausalLM.from_pretrained(
-            Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto"
+            model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
+            torch_dtype=dtype,
+            device_map="auto",
        )
         model_hf.eval()
         with torch.no_grad():
@@ -405,9 +396,10 @@ def test_llama_generation(model_name, checkpoint_format):
 # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "llama_parallel_generation"
 @pytest.mark.parametrize("world_size", [2])
-@pytest.mark.parametrize("model_name", ["13B"])
-@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
-def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
+@pytest.mark.parametrize(
+    "model_name", ["13B", "meta-llama/Llama-2-13b-hf", "codellama/CodeLlama-34b-hf"]
+)
+def test_llama_parallel_generation(model_name, world_size):
     """Check that our implementation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
     the HF scores in fp32.
@@ -419,12 +411,17 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
     )
     dtype = torch.float16
-    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
-    config = llama_config_to_gpt2_config(config)
-    config.use_flash_attn = False
+    if "/" in model_name:  # Download from HF
+        config = llama_config_to_gpt2_config(
+            AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+        )
+    else:
+        config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
+        config = llama_config_to_gpt2_config(config)
+    config.use_flash_attn = True
     config.fused_bias_fc = True
     config.fused_mlp = False  # We don't have fused GatedMLP yet
-    config.fused_dropout_add_ln = False
+    config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True
     config.pad_vocab_size_multiple = 8 * world_size
     config.sequence_parallel = False  # Need to set this to False for generation
@@ -450,8 +447,13 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
     # GPU0 and GPU1 and things would hang
     torch.cuda.set_device(device)
 
-    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
-        checkpoint_path, model_name, config, checkpoint_format
-    )
+    if "/" in model_name:  # Download from HF
+        pretrained_state_dict = remap_state_dict_hf_llama(
+            state_dict_from_pretrained(model_name), config
+        )
+    else:
+        pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
+            checkpoint_path, model_name, config, checkpoint_format="meta"
+        )
     model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
     model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
@@ -490,7 +492,9 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
     if rank == 0:
         # Without device_map, the model is loaded on the CPU, which is very slow
         model_hf = LlamaForCausalLM.from_pretrained(
-            Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto"
+            model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
+            torch_dtype=dtype,
+            device_map="auto",
         )
         model_hf.eval()
         print("HF fp16")
@@ -508,7 +512,8 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
         del model_hf
         model_ref = LlamaForCausalLM.from_pretrained(
-            Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
+            model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
+            device_map="auto",
         )
         model_ref.eval()
         with torch.inference_mode():
@@ -594,15 +599,16 @@ def test_llama_parallel_uneven_num_heads(world_size):
     if rank == 0:
         model_ref = LlamaForCausalLM.from_pretrained(
-            Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
+            Path(checkpoint_path) / f"{model_name}-hf", device_map={"": device}
         )
+        model_ref = model_ref.to(device=device)
        model_ref.eval()
-        out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
-        logits_ref = model_ref(input_ids).logits.to(device=device)
+        out_ref = model_ref.model(input_ids).last_hidden_state
+        logits_ref = model_ref(input_ids).logits
         del model_ref
 
         model_hf = LlamaForCausalLM.from_pretrained(
-            Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto"
+            Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device}
         )
         model_hf.eval()
         out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
@@ -625,42 +631,3 @@ def test_llama_parallel_uneven_num_heads(world_size):
         if os.path.exists(checkpoint_path / f"{model_name}-hf"):
             shutil.rmtree(checkpoint_path / f"{model_name}-hf")
-
-
-@torch.no_grad()
-def test_inv_remap_state_dict_hf_llama():
-    checkpoint_path = (
-        Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
-    )
-    model_name = f"teeny"
-    llama_config = LlamaConfig(
-        num_attention_heads=2,
-        hidden_size=256 * 2,
-        intermediate_size=256 * 2 * 4,
-        num_hidden_layers=4,
-    )
-    config = llama_config_to_gpt2_config(llama_config)
-    config.use_flash_attn = True
-    config.fused_bias_fc = True
-    config.fused_mlp = False  # We don't have fused GatedMLP yet
-    config.fused_dropout_add_ln = True
-    config.residual_in_fp32 = True
-
-    # Set up.
-    LlamaForCausalLM(config=llama_config).save_pretrained(checkpoint_path / f"{model_name}-hf")
-
-    # inv_remap_state_dict_hf_llama should be the inverse of remap_state_dict_hf_llama
-    state_dict = state_dict_from_pretrained(checkpoint_path / f"{model_name}-hf")
-    state_dict = {key: val for key, val in state_dict.items() if "rotary_emb.inv_freq" not in key}
-    pretrained_state_dict = remap_state_dict_hf_llama(state_dict, config)
-    state_dict_recover = inv_remap_state_dict_hf_llama(pretrained_state_dict, config)
-    assert set(state_dict_recover.keys()) == set(state_dict.keys())
-    for key in state_dict_recover.keys():
-        torch.testing.assert_close(state_dict_recover[key], state_dict[key])
-
-    # Tear down.
-    if os.path.exists(checkpoint_path / f"{model_name}-hf"):
-        shutil.rmtree(checkpoint_path / f"{model_name}-hf")