gaoqiong / flash-attention

Commit 798858f9
Authored Sep 03, 2023 by Tri Dao
Parent: 7b33743a

    Fix test_baichuan
Showing 4 changed files with 46 additions and 86 deletions
flash_attn/models/gpt.py       +2  -1
flash_attn/modules/mha.py      +2  -2
tests/models/test_baichuan.py  +41 -83
tests/models/test_falcon.py    +1  -0
flash_attn/models/gpt.py
@@ -633,7 +633,8 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         hidden_states = self.transformer(
             input_ids, position_ids=position_ids, inference_params=inference_params
         )
-        assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
+        if inference_params is not None:
+            assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
         if num_last_tokens > 0:
             hidden_states = hidden_states[:, -num_last_tokens:]
         if self.project_out is not None:
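With this change the shape assert fires only when a KV cache (`inference_params`) is in use, i.e. in generation mode, which is what its message refers to; activations that arrive flattened to 2-D during ordinary forward passes no longer trip it. A standalone sketch of the gated check (the helper name and tensor shapes below are illustrative, not code from the repo):

import torch


def slice_last_tokens(hidden_states, inference_params=None, num_last_tokens=0):
    # Mirrors the gated assert from GPTLMHeadModel.forward: only generation
    # (inference_params set) requires the (batch, seqlen, hidden) layout.
    if inference_params is not None:
        assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
    if num_last_tokens > 0:
        hidden_states = hidden_states[:, -num_last_tokens:]
    return hidden_states


# Flattened training-style activations pass through untouched:
slice_last_tokens(torch.randn(16, 1024))
# Generation-style activations must still be 3-D:
slice_last_tokens(torch.randn(2, 8, 1024), inference_params=object(), num_last_tokens=1)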
flash_attn/modules/mha.py
@@ -607,7 +607,7 @@ class MHA(nn.Module):
         )
         seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
         rotary_max_seqlen = (
-            inference_params.max_sequene_len if inference_params is not None else None
+            inference_params.max_sequence_len if inference_params is not None else None
         )
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
@@ -859,7 +859,7 @@ class ParallelMHA(nn.Module):
             qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
         seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
         rotary_max_seqlen = (
-            inference_params.max_sequene_len if inference_params is not None else None
+            inference_params.max_sequence_len if inference_params is not None else None
         )
         if self.num_heads_kv == self.num_heads:
             qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
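Both hunks fix the same attribute typo (`max_sequene_len` to `max_sequence_len`) on the object passed as `inference_params`; before the fix, reading the misspelled attribute would raise AttributeError as soon as generation with a KV cache reached these lines. A minimal sketch of the derivation, using a stand-in dataclass rather than flash_attn's own inference-params class (field names follow the attributes the diff reads):

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeInferenceParams:
    # Stand-in for the inference-params object; only the two attributes the
    # diff references are modeled here.
    max_sequence_len: int
    sequence_len_offset: int = 0


def rotary_args(inference_params: Optional[FakeInferenceParams]):
    # Same derivation as in MHA.forward / ParallelMHA.forward after the fix.
    seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
    rotary_max_seqlen = (
        inference_params.max_sequence_len if inference_params is not None else None
    )
    return seqlen_offset, rotary_max_seqlen


print(rotary_args(None))  # (0, None) -- no KV cache
print(rotary_args(FakeInferenceParams(max_sequence_len=2048, sequence_len_offset=17)))  # (17, 2048)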
tests/models/test_baichuan.py
@@ -29,19 +29,15 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.generation import update_graph_cache
 
 
-@pytest.mark.parametrize("model_name", ["Baichuan-7B"])
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
 def test_baichuan_state_dict(model_name):
-    checkpoint_path = Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints"))
     config = baichuan_config_to_gpt2_config(
-        config_from_checkpoint(checkpoint_path, model_name)
+        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
     pretrained_state_dict = remap_state_dict_hf_baichuan(
-        ckpt_state_dicts[0], config
+        state_dict_from_pretrained(model_name), config
     )
-    model = GPTLMHeadModel(
-        config, device="meta"
-    )  # Without device='meta' init is very slow
+    model = GPTLMHeadModel(config, device="meta")  # Without device='meta' init is very slow
     state_dict = model.state_dict()
     assert len(state_dict.keys()) == len(pretrained_state_dict.keys())
     assert state_dict.keys() == pretrained_state_dict.keys()
@@ -49,20 +45,16 @@ def test_baichuan_state_dict(model_name):
         assert state_dict[k].shape == pretrained_state_dict[k].shape
 
 
-@pytest.mark.parametrize("model_name", ["Baichuan-7B"])
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
 def test_baichuan_optimized(model_name):
     """Check that our implementation of Baichuan (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
     """
-    checkpoint_path = Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints"))
     dtype = torch.float16
     device = "cuda"
     config = baichuan_config_to_gpt2_config(
-        config_from_checkpoint(checkpoint_path, model_name)
+        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
     config.use_flash_attn = True
     config.fused_bias_fc = True
@@ -70,11 +62,9 @@ def test_baichuan_optimized(model_name):
     config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True
 
-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dicts = [
-        remap_state_dict_hf_baichuan(s, config) for s in ckpt_state_dicts
-    ]
-    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+    pretrained_state_dict = remap_state_dict_hf_baichuan(
+        state_dict_from_pretrained(model_name), config
+    )
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
@@ -96,7 +86,7 @@ def test_baichuan_optimized(model_name):
     # Without device_map, the model is loaded on the CPU, which is very slow
     # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
     model_ref = AutoModelForCausalLM.from_pretrained(
-        Path(checkpoint_path) / model_name, device_map="auto", trust_remote_code=True
+        model_name, device_map="auto", trust_remote_code=True
     )
     model_ref.eval()
     with torch.no_grad():
@@ -105,10 +95,7 @@ def test_baichuan_optimized(model_name):
     del model_ref
 
     model_hf = AutoModelForCausalLM.from_pretrained(
-        Path(checkpoint_path) / model_name,
-        torch_dtype=dtype,
-        device_map={"": device},
-        trust_remote_code=True,
+        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
     )
     model_hf.eval()
     with torch.no_grad():
@@ -133,23 +120,19 @@ def test_baichuan_optimized(model_name):
     ).abs().max().item()
 
 
-# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "test_baichuan_parallel"
+# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "test_baichuan_parallel_forward"
 @pytest.mark.parametrize("world_size", [2])
-@pytest.mark.parametrize("model_name", ["Baichuan-7B"])
-def test_baichuan_parallel(model_name, world_size):
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
+def test_baichuan_parallel_forward(model_name, world_size):
     """Check that our implementation of Baichuan (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
     """
     from apex.transformer import parallel_state
 
-    checkpoint_path = Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints"))
     dtype = torch.float16
     config = baichuan_config_to_gpt2_config(
-        config_from_checkpoint(checkpoint_path, model_name)
+        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
     config.use_flash_attn = True
     config.fused_bias_fc = True
@@ -165,11 +148,12 @@ def test_baichuan_parallel(model_name, world_size):
     rank = parallel_state.get_tensor_model_parallel_rank()
     process_group = parallel_state.get_tensor_model_parallel_group()
 
-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dicts = [
-        remap_state_dict_hf_baichuan(s, config) for s in ckpt_state_dicts
-    ]
-    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+    # Need this, otherwise the Triton kernel seems to launched from the wrong device.
+    torch.cuda.set_device(device)
+    pretrained_state_dict = remap_state_dict_hf_baichuan(
+        state_dict_from_pretrained(model_name), config
+    )
     model = GPTLMHeadModel(
         config, process_group=process_group, device=device, dtype=dtype
@@ -197,13 +181,12 @@ def test_baichuan_parallel(model_name, world_size):
     logits, _ = all_gather_raw(logits, process_group)
     logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
     del model
+    parallel_state.destroy_model_parallel()
 
     if rank == 0:
         # Without device_map, the model is loaded on the CPU, which is very slow
         model_ref = AutoModelForCausalLM.from_pretrained(
-            Path(checkpoint_path) / model_name,
-            device_map="auto",
-            trust_remote_code=True,
+            model_name, device_map="auto", trust_remote_code=True
         )
         model_ref.eval()
         with torch.no_grad():
@@ -212,10 +195,7 @@ def test_baichuan_parallel(model_name, world_size):
         del model_ref
 
         model_hf = AutoModelForCausalLM.from_pretrained(
-            Path(checkpoint_path) / model_name,
-            torch_dtype=dtype,
-            device_map="auto",
-            trust_remote_code=True,
+            model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
        )
         model_hf.eval()
         with torch.no_grad():
@@ -240,16 +220,12 @@ def test_baichuan_parallel(model_name, world_size):
     ).abs().max().item()
 
 
-@pytest.mark.parametrize("model_name", ["Baichuan-7B"])
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
 def test_baichuan_generation(model_name):
-    checkpoint_path = Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints"))
     dtype = torch.float16
     device = "cuda"
     config = baichuan_config_to_gpt2_config(
-        config_from_checkpoint(checkpoint_path, model_name)
+        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
     config.use_flash_attn = True
     config.fused_bias_fc = True
@@ -257,9 +233,7 @@ def test_baichuan_generation(model_name):
     config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True
 
-    tokenizer = AutoTokenizer.from_pretrained(
-        Path(checkpoint_path) / model_name, trust_remote_code=True
-    )
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
     eos_token_id = tokenizer.eos_token_id
 
     torch.manual_seed(0)
@@ -271,10 +245,7 @@ def test_baichuan_generation(model_name):
     )
 
     model_hf = AutoModelForCausalLM.from_pretrained(
-        Path(checkpoint_path) / model_name,
-        torch_dtype=dtype,
-        device_map={"": device},
-        trust_remote_code=True,
+        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
     )
     model_hf.eval()
     print("HF fp16")
@@ -292,7 +263,7 @@ def test_baichuan_generation(model_name):
     # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
     model_ref = AutoModelForCausalLM.from_pretrained(
-        Path(checkpoint_path) / model_name, device_map="auto", trust_remote_code=True
+        model_name, device_map="auto", trust_remote_code=True
     )
     model_ref.eval()
     with torch.no_grad():
@@ -301,11 +272,9 @@ def test_baichuan_generation(model_name):
     )
     del model_ref
 
-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dicts = [
-        remap_state_dict_hf_baichuan(s, config) for s in ckpt_state_dicts
-    ]
-    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+    pretrained_state_dict = remap_state_dict_hf_baichuan(
+        state_dict_from_pretrained(model_name), config
+    )
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
@@ -368,7 +337,7 @@ def test_baichuan_generation(model_name):
 # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "baichuan_parallel_generation"
 @pytest.mark.parametrize("world_size", [2])
-@pytest.mark.parametrize("model_name", ["Baichuan-7B"])
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
 def test_baichuan_parallel_generation(model_name, world_size):
     """Check that our implementation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
@@ -376,13 +345,9 @@ def test_baichuan_parallel_generation(model_name, world_size):
     """
     from apex.transformer import parallel_state
 
-    checkpoint_path = Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints"))
     dtype = torch.float16
     config = baichuan_config_to_gpt2_config(
-        config_from_checkpoint(checkpoint_path, model_name)
+        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
     config.use_flash_attn = False
     config.fused_bias_fc = True
@@ -413,11 +378,9 @@ def test_baichuan_parallel_generation(model_name, world_size):
     # GPU0 and GPU1 and things would hang
     torch.cuda.set_device(device)
 
-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dicts = [
-        remap_state_dict_hf_baichuan(s, config) for s in ckpt_state_dicts
-    ]
-    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+    pretrained_state_dict = remap_state_dict_hf_baichuan(
+        state_dict_from_pretrained(model_name), config
+    )
     model = GPTLMHeadModel(
         config, process_group=process_group, device=device, dtype=dtype
@@ -464,10 +427,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
     if rank == 0:
         # Without device_map, the model is loaded on the CPU, which is very slow
         model_hf = AutoModelForCausalLM.from_pretrained(
-            Path(checkpoint_path) / model_name,
-            torch_dtype=dtype,
-            device_map="auto",
-            trust_remote_code=True,
+            model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
        )
         model_hf.eval()
         print("HF fp16")
@@ -487,9 +447,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
         del model_hf
 
         model_ref = AutoModelForCausalLM.from_pretrained(
-            Path(checkpoint_path) / model_name,
-            device_map="auto",
-            trust_remote_code=True,
+            model_name, device_map="auto", trust_remote_code=True
        )
         model_ref.eval()
         with torch.inference_mode():
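The docstrings in these tests describe the acceptance criterion: the flash-attn fp16 output is accepted if its error against an fp32 HF reference stays within a small multiple of the error of the HF fp16 output against that same reference. A toy sketch of that rule (the factor of 3 is an assumption for illustration, not necessarily the value the tests use):

import torch


def close_enough(out, out_hf, out_ref, factor=3.0):
    # out:     flash-attn output in fp16 (upcast for comparison)
    # out_hf:  HF output in fp16
    # out_ref: HF output in fp32, treated as ground truth
    our_err = (out - out_ref).abs().max().item()
    hf_err = (out_hf - out_ref).abs().max().item()
    return our_err < factor * hf_err


# Toy tensors standing in for the three forward passes:
ref = torch.randn(2, 8, 32)
hf_fp16 = ref + 1e-3 * torch.randn_like(ref)
ours = ref + 1e-3 * torch.randn_like(ref)
print(close_enough(ours, hf_fp16, ref))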
tests/models/test_falcon.py
@@ -146,6 +146,7 @@ def test_falcon_parallel_forward(model_name, world_size):
     logits, _ = all_gather_raw(logits, process_group)
     logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
     del model
+    parallel_state.destroy_model_parallel()
 
     if rank == 0:
         model_hf = AutoModelForCausalLM.from_pretrained(
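Both parallel tests now call `parallel_state.destroy_model_parallel()` right after the sharded model is deleted and before rank 0 loads the HF reference. A rough sketch of that ordering (assumes apex is installed and the script is launched with `torchrun --nproc_per_node=2 ...`; the stubbed-out steps are placeholders, not code from the repo):

import torch
from apex.transformer import parallel_state

if __name__ == "__main__":
    torch.distributed.init_process_group(backend="nccl")
    device = f"cuda:{torch.distributed.get_rank()}"
    torch.cuda.set_device(device)
    world_size = torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()

    # ... build the tensor-parallel model, run the forward pass, all-gather the logits ...

    # Tear down the tensor-parallel state first (the line this commit adds), then let
    # rank 0 load the HF reference model and compare logits.
    parallel_state.destroy_model_parallel()
    if rank == 0:
        pass  # load the HF reference with AutoModelForCausalLM.from_pretrained(...) and compare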