gaoqiong / flash-attention

Commit 798858f9, authored Sep 03, 2023 by Tri Dao
Fix test_baichuan
Parent: 7b33743a

Showing 4 changed files with 46 additions and 86 deletions
flash_attn/models/gpt.py        +2   -1
flash_attn/modules/mha.py       +2   -2
tests/models/test_baichuan.py   +41  -83
tests/models/test_falcon.py     +1   -0
flash_attn/models/gpt.py
@@ -633,7 +633,8 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         hidden_states = self.transformer(
             input_ids, position_ids=position_ids, inference_params=inference_params
         )
-        assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
+        if inference_params is not None:
+            assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
         if num_last_tokens > 0:
             hidden_states = hidden_states[:, -num_last_tokens:]
         if self.project_out is not None:
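Note on the flash_attn/models/gpt.py hunk: the shape assertion used to run on every forward call, so a sequence-parallel forward pass (where hidden_states is presumably not 3-D) could trip it even when no generation was happening; guarding it behind inference_params is not None limits the check to generation. A minimal, self-contained sketch of the before/after behavior (simplified stand-in functions, not the real GPTLMHeadModel.forward):

import torch

def check_old(hidden_states, inference_params=None):
    # Old behavior: the shape check fired unconditionally, even for plain forward passes.
    assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"

def check_new(hidden_states, inference_params=None):
    # New behavior: only enforce the 3-D shape when inference_params is supplied, i.e. during generation.
    if inference_params is not None:
        assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"

x2d = torch.randn(16, 512)                              # e.g. hidden states flattened to (tokens, hidden)
check_new(x2d)                                          # passes now; check_old(x2d) raises AssertionError
check_new(x2d.unsqueeze(0), inference_params=object())  # 3-D input during generation: still checked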
flash_attn/modules/mha.py
@@ -607,7 +607,7 @@ class MHA(nn.Module):
             )
         seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
         rotary_max_seqlen = (
-            inference_params.max_sequene_len if inference_params is not None else None
+            inference_params.max_sequence_len if inference_params is not None else None
         )
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
@@ -859,7 +859,7 @@ class ParallelMHA(nn.Module):
             qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
         seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
         rotary_max_seqlen = (
-            inference_params.max_sequene_len if inference_params is not None else None
+            inference_params.max_sequence_len if inference_params is not None else None
         )
         if self.num_heads_kv == self.num_heads:
             qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
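Note on the flash_attn/modules/mha.py hunks: both are the same one-character fix, renaming the misspelled attribute max_sequene_len to max_sequence_len in MHA and ParallelMHA; with the typo, any generation call that reached this line with an inference_params object would presumably fail with an AttributeError. A minimal sketch of the corrected expression, using a stand-in dataclass rather than the library's actual inference-params class:

from dataclasses import dataclass
from typing import Optional

@dataclass
class InferenceParams:            # stand-in for illustration; field names mirror the attributes used above
    max_sequence_len: int         # sequence length the KV cache was sized for
    sequence_len_offset: int = 0  # number of tokens already cached

def rotary_args(inference_params: Optional[InferenceParams]):
    seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
    rotary_max_seqlen = (
        inference_params.max_sequence_len if inference_params is not None else None
    )
    return seqlen_offset, rotary_max_seqlen

print(rotary_args(None))                                                            # (0, None)
print(rotary_args(InferenceParams(max_sequence_len=2048, sequence_len_offset=5)))   # (5, 2048)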
tests/models/test_baichuan.py
@@ -29,19 +29,15 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.generation import update_graph_cache


-@pytest.mark.parametrize("model_name", ["Baichuan-7B"])
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
 def test_baichuan_state_dict(model_name):
-    checkpoint_path = Path(
-        os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")
-    )
     config = baichuan_config_to_gpt2_config(
-        config_from_checkpoint(checkpoint_path, model_name)
+        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dict = remap_state_dict_hf_baichuan(ckpt_state_dicts[0], config)
+    pretrained_state_dict = remap_state_dict_hf_baichuan(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device="meta")  # Without device='meta' init is very slow
     state_dict = model.state_dict()
     assert len(state_dict.keys()) == len(pretrained_state_dict.keys())
     assert state_dict.keys() == pretrained_state_dict.keys()
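The remaining hunks in this file repeat the same substitution: each test now resolves "baichuan-inc/Baichuan-7B" straight from the Hugging Face Hub instead of reading a local CHECKPOINT_DIR via config_from_checkpoint / state_dicts_from_checkpoint / combine_state_dicts_tp. Condensed, the loading flow the updated tests share looks like the sketch below (import paths for the Baichuan helpers are assumed from this test file's context; illustrative only, not an additional test):

from transformers import AutoConfig

from flash_attn.models.baichuan import baichuan_config_to_gpt2_config, remap_state_dict_hf_baichuan
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.utils.pretrained import state_dict_from_pretrained

model_name = "baichuan-inc/Baichuan-7B"

# Hub config -> flash-attn's GPT-2-style config (trust_remote_code because Baichuan ships custom code)
config = baichuan_config_to_gpt2_config(AutoConfig.from_pretrained(model_name, trust_remote_code=True))

# Hub weights, with keys renamed into the flash-attn GPT layout
pretrained_state_dict = remap_state_dict_hf_baichuan(state_dict_from_pretrained(model_name), config)

# device="meta" allocates no storage, so construction stays fast when only keys/shapes are compared
model = GPTLMHeadModel(config, device="meta")
assert model.state_dict().keys() == pretrained_state_dict.keys()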
@@ -49,20 +45,16 @@ def test_baichuan_state_dict(model_name):
         assert state_dict[k].shape == pretrained_state_dict[k].shape


-@pytest.mark.parametrize("model_name", ["Baichuan-7B"])
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
 def test_baichuan_optimized(model_name):
     """Check that our implementation of Baichuan (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
     """
-    checkpoint_path = Path(
-        os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")
-    )
     dtype = torch.float16
     device = "cuda"
     config = baichuan_config_to_gpt2_config(
-        config_from_checkpoint(checkpoint_path, model_name)
+        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
     config.use_flash_attn = True
     config.fused_bias_fc = True
@@ -70,11 +62,9 @@ def test_baichuan_optimized(model_name):
     config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True

-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dicts = [remap_state_dict_hf_baichuan(s, config) for s in ckpt_state_dicts]
-    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+    pretrained_state_dict = remap_state_dict_hf_baichuan(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
@@ -96,7 +86,7 @@ def test_baichuan_optimized(model_name):
     # Without device_map, the model is loaded on the CPU, which is very slow
     # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
     model_ref = AutoModelForCausalLM.from_pretrained(
-        Path(checkpoint_path) / model_name, device_map="auto", trust_remote_code=True
+        model_name, device_map="auto", trust_remote_code=True
     )
     model_ref.eval()
     with torch.no_grad():
@@ -105,10 +95,7 @@ def test_baichuan_optimized(model_name):
     del model_ref

     model_hf = AutoModelForCausalLM.from_pretrained(
-        Path(checkpoint_path) / model_name,
-        torch_dtype=dtype,
-        device_map={"": device},
-        trust_remote_code=True,
+        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True,
     )
     model_hf.eval()
     with torch.no_grad():
@@ -133,23 +120,19 @@ def test_baichuan_optimized(model_name):
     ).abs().max().item()


-# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "test_baichuan_parallel"
+# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "test_baichuan_parallel_forward"
 @pytest.mark.parametrize("world_size", [2])
-@pytest.mark.parametrize("model_name", ["Baichuan-7B"])
-def test_baichuan_parallel(model_name, world_size):
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
+def test_baichuan_parallel_forward(model_name, world_size):
     """Check that our implementation of Baichuan (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
     """
     from apex.transformer import parallel_state

-    checkpoint_path = Path(
-        os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")
-    )
     dtype = torch.float16
     config = baichuan_config_to_gpt2_config(
-        config_from_checkpoint(checkpoint_path, model_name)
+        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
     config.use_flash_attn = True
     config.fused_bias_fc = True
@@ -165,11 +148,12 @@ def test_baichuan_parallel(model_name, world_size):
     rank = parallel_state.get_tensor_model_parallel_rank()
     process_group = parallel_state.get_tensor_model_parallel_group()

-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dicts = [remap_state_dict_hf_baichuan(s, config) for s in ckpt_state_dicts]
-    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
     # Need this, otherwise the Triton kernel seems to launched from the wrong device.
     torch.cuda.set_device(device)
+    pretrained_state_dict = remap_state_dict_hf_baichuan(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype
@@ -197,13 +181,12 @@ def test_baichuan_parallel(model_name, world_size):
     logits, _ = all_gather_raw(logits, process_group)
     logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
     del model
     parallel_state.destroy_model_parallel()

     if rank == 0:
         # Without device_map, the model is loaded on the CPU, which is very slow
         model_ref = AutoModelForCausalLM.from_pretrained(
-            Path(checkpoint_path) / model_name, device_map="auto", trust_remote_code=True,
+            model_name, device_map="auto", trust_remote_code=True
         )
         model_ref.eval()
         with torch.no_grad():
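Unchanged by this commit but useful when reading the parallel hunks: each tensor-parallel rank produces logits for its own slice of the vocabulary, all_gather_raw stacks the rank outputs along the leading (batch) dimension, and the rearrange pattern above folds them back into a full-vocabulary tensor. A toy single-process illustration of that reshaping, assuming the leading-dimension gather order just described:

import torch
from einops import rearrange

batch_size, seqlen, vocab_per_rank, world_size = 2, 3, 4, 2
# Pretend each rank produced logits of shape (batch, seqlen, vocab / world_size) ...
shards = [torch.randn(batch_size, seqlen, vocab_per_rank) for _ in range(world_size)]
# ... and the all-gather concatenated them along the leading dimension, rank by rank:
gathered = torch.cat(shards, dim=0)   # (world_size * batch, seqlen, vocab_per_rank)
logits = rearrange(gathered, "(n b) ... d -> b ... (n d)", b=batch_size)
print(logits.shape)                   # torch.Size([2, 3, 8]): the full vocabulary is restored per position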
@@ -212,10 +195,7 @@ def test_baichuan_parallel(model_name, world_size):
         del model_ref

         model_hf = AutoModelForCausalLM.from_pretrained(
-            Path(checkpoint_path) / model_name,
-            torch_dtype=dtype,
-            device_map="auto",
-            trust_remote_code=True,
+            model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
         )
         model_hf.eval()
         with torch.no_grad():
@@ -240,16 +220,12 @@ def test_baichuan_parallel(model_name, world_size):
         ).abs().max().item()


-@pytest.mark.parametrize("model_name", ["Baichuan-7B"])
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
 def test_baichuan_generation(model_name):
-    checkpoint_path = Path(
-        os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")
-    )
     dtype = torch.float16
     device = "cuda"
     config = baichuan_config_to_gpt2_config(
-        config_from_checkpoint(checkpoint_path, model_name)
+        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
     config.use_flash_attn = True
     config.fused_bias_fc = True
@@ -257,9 +233,7 @@ def test_baichuan_generation(model_name):
     config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True

-    tokenizer = AutoTokenizer.from_pretrained(
-        Path(checkpoint_path) / model_name, trust_remote_code=True
-    )
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
     eos_token_id = tokenizer.eos_token_id

     torch.manual_seed(0)
@@ -271,10 +245,7 @@ def test_baichuan_generation(model_name):
     )
     model_hf = AutoModelForCausalLM.from_pretrained(
-        Path(checkpoint_path) / model_name,
-        torch_dtype=dtype,
-        device_map={"": device},
-        trust_remote_code=True,
+        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
     )
     model_hf.eval()
     print("HF fp16")
@@ -292,7 +263,7 @@ def test_baichuan_generation(model_name):
     # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
     model_ref = AutoModelForCausalLM.from_pretrained(
-        Path(checkpoint_path) / model_name, device_map="auto", trust_remote_code=True
+        model_name, device_map="auto", trust_remote_code=True
     )
     model_ref.eval()
     with torch.no_grad():
@@ -301,11 +272,9 @@ def test_baichuan_generation(model_name):
     )
     del model_ref

-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dicts = [remap_state_dict_hf_baichuan(s, config) for s in ckpt_state_dicts]
-    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+    pretrained_state_dict = remap_state_dict_hf_baichuan(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
@@ -368,7 +337,7 @@ def test_baichuan_generation(model_name):
 # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "baichuan_parallel_generation"
 @pytest.mark.parametrize("world_size", [2])
-@pytest.mark.parametrize("model_name", ["Baichuan-7B"])
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
 def test_baichuan_parallel_generation(model_name, world_size):
     """Check that our implementation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
@@ -376,13 +345,9 @@ def test_baichuan_parallel_generation(model_name, world_size):
"""
from
apex.transformer
import
parallel_state
checkpoint_path
=
Path
(
os
.
environ
.
get
(
"CHECKPOINT_DIR"
,
current_dir
.
parent
.
parent
/
"checkpoints"
)
)
dtype
=
torch
.
float16
config
=
baichuan_config_to_gpt2_config
(
c
onfig
_
from_
checkpoint
(
checkpoint_path
,
model_nam
e
)
AutoC
onfig
.
from_
pretrained
(
model_name
,
trust_remote_code
=
Tru
e
)
)
config
.
use_flash_attn
=
False
config
.
fused_bias_fc
=
True
...
...
@@ -413,11 +378,9 @@ def test_baichuan_parallel_generation(model_name, world_size):
     # GPU0 and GPU1 and things would hang
     torch.cuda.set_device(device)

-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dicts = [remap_state_dict_hf_baichuan(s, config) for s in ckpt_state_dicts]
-    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+    pretrained_state_dict = remap_state_dict_hf_baichuan(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype
@@ -464,10 +427,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
     if rank == 0:
         # Without device_map, the model is loaded on the CPU, which is very slow
         model_hf = AutoModelForCausalLM.from_pretrained(
-            Path(checkpoint_path) / model_name,
-            torch_dtype=dtype,
-            device_map="auto",
-            trust_remote_code=True,
+            model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
         )
         model_hf.eval()
         print("HF fp16")
@@ -487,9 +447,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
         del model_hf

         model_ref = AutoModelForCausalLM.from_pretrained(
-            Path(checkpoint_path) / model_name, device_map="auto", trust_remote_code=True,
+            model_name, device_map="auto", trust_remote_code=True
         )
         model_ref.eval()
         with torch.inference_mode():
tests/models/test_falcon.py
@@ -146,6 +146,7 @@ def test_falcon_parallel_forward(model_name, world_size):
     logits, _ = all_gather_raw(logits, process_group)
     logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
     del model
     parallel_state.destroy_model_parallel()

     if rank == 0:
         model_hf = AutoModelForCausalLM.from_pretrained(