gaoqiong / flash-attention

Commit 2c7d7b73
Authored Dec 22, 2023 by Tri Dao
Parent 68f178aa

Implement norm head for Baichuan2
Showing 3 changed files with 64 additions and 40 deletions:

    flash_attn/models/baichuan.py    +4   -0
    flash_attn/models/gpt.py         +16  -5
    tests/models/test_baichuan.py    +44  -35
flash_attn/models/baichuan.py
@@ -116,6 +116,9 @@ def remap_state_dict_hf_baichuan(state_dict, config):
 def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Config:
     # HACK: the config doesn't say whether it's rotary or alibi.
     # So we have to infer from the hidden size (7B -> rotary, 13B -> alibi).
+    # HACK: the config doesn't say whether it uses norm head.
+    # So we have to infer from the vocab size
+    # (v1, vocab size 64k, no norm head; v2, vocab size 128k, norm head).
     use_rotary = baichuan_config.hidden_size < 5000
     return GPT2Config(
         vocab_size=baichuan_config.vocab_size,
@@ -141,6 +144,7 @@ def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Config:
         use_alibi=not use_rotary,
         use_flash_attn=not use_rotary,  # Alibi code path requires flash_attn
         tie_word_embeddings=False,
+        norm_head=baichuan_config.vocab_size > 70000,
         qkv_proj_bias=False,
         out_proj_bias=False,
         mlp_fc1_bias=False,
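Taken together, the two baichuan.py hunks mean the converter now distinguishes Baichuan v1 from v2 purely from config shapes. A minimal sketch of how the heuristic plays out; the hidden and vocab sizes below are taken from the public Baichuan configs and are assumptions of this sketch, not part of the commit:

# Hypothetical illustration of the inference heuristic in baichuan_config_to_gpt2_config.
checkpoints = {
    "Baichuan-7B":        dict(hidden_size=4096, vocab_size=64000),   # v1, rotary, no norm head
    "Baichuan-13B-Base":  dict(hidden_size=5120, vocab_size=64000),   # v1, alibi,  no norm head
    "Baichuan2-7B-Base":  dict(hidden_size=4096, vocab_size=125696),  # v2, rotary, norm head
    "Baichuan2-13B-Base": dict(hidden_size=5120, vocab_size=125696),  # v2, alibi,  norm head
}

for name, cfg in checkpoints.items():
    use_rotary = cfg["hidden_size"] < 5000  # 7B -> rotary, 13B -> alibi
    norm_head = cfg["vocab_size"] > 70000   # v2 vocab (~125k) -> norm head
    print(f"{name}: use_rotary={use_rotary}, use_alibi={not use_rotary}, norm_head={norm_head}")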
flash_attn/models/gpt.py
@@ -32,7 +32,12 @@ from flash_attn.modules.mlp import (
     ParallelMLP,
 )
 from flash_attn.ops.activations import sqrelu_fwd
-from flash_attn.utils.distributed import all_gather_raw, get_dim_for_local_rank, sync_shared_params
+from flash_attn.utils.distributed import (
+    all_gather,
+    all_gather_raw,
+    get_dim_for_local_rank,
+    sync_shared_params,
+)
 from flash_attn.utils.generation import GenerationMixin
 from flash_attn.utils.pretrained import state_dict_from_pretrained
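The newly imported all_gather is used in the norm-head branch of GPTLMHeadModel.forward further down in this diff: with sequence parallelism the hidden states reaching the head are sharded along the sequence dimension, and because that branch calls F.linear directly instead of going through the lm_head module, it has to gather the shards itself first. A minimal single-process sketch of the shape bookkeeping, where torch.cat stands in for the real all_gather across ranks (names and sizes here are illustrative assumptions):

import torch
import torch.nn.functional as F

world_size, seqlen_per_rank, batch, d, vocab = 2, 4, 1, 8, 16
# Each rank holds a slice of the sequence; torch.cat plays the role of all_gather over ranks.
shards = [torch.randn(seqlen_per_rank * batch, d) for _ in range(world_size)]
hidden_states = torch.cat(shards, dim=0)              # (seqlen * batch, d): full sequence again

lm_head_weight = F.normalize(torch.randn(vocab, d))   # row-normalized head, as in the norm-head path
lm_logits = F.linear(hidden_states, lm_head_weight)   # (seqlen * batch, vocab)
print(lm_logits.shape)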
@@ -355,9 +360,8 @@ class GPTPreTrainedModel(nn.Module):
             state_dict = remap_state_dict_hf_gpt2(state_dict, config)
         elif model_name.startswith("facebook/opt"):
             state_dict = remap_state_dict_hf_opt(state_dict, config)
-        elif (
-            model_name.startswith("EleutherAI/gpt-j-")
-            or model_name.startswith("togethercomputer/GPT-JT-")
+        elif model_name.startswith("EleutherAI/gpt-j-") or model_name.startswith(
+            "togethercomputer/GPT-JT-"
         ):
             state_dict = remap_state_dict_hf_gptj(state_dict, config)
         elif (
@@ -621,6 +625,7 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
                 sequence_parallel=getattr(config, "sequence_parallel", True),
                 **factory_kwargs,
             )
+        self.norm_head = getattr(config, "norm_head", False)
         # Initialize weights and apply final processing
         self.apply(
             partial(
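Reading the flag with getattr and a False default keeps existing configs on the old code path; only configs that explicitly set norm_head=True, as baichuan_config_to_gpt2_config now does for Baichuan2 checkpoints, take the normalized-head path. A small sketch, assuming only that GPT2Config stores extra keyword arguments as attributes (standard transformers behavior):

from transformers import GPT2Config

cfg_v1 = GPT2Config()                # no norm_head attribute at all
cfg_v2 = GPT2Config(norm_head=True)  # extra kwargs are stored on the config

print(getattr(cfg_v1, "norm_head", False))  # False -> standard lm_head matmul
print(getattr(cfg_v2, "norm_head", False))  # True  -> normalized head path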
@@ -662,7 +667,13 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
             hidden_states = hidden_states[:, -num_last_tokens:]
         if self.project_out is not None:
             hidden_states = self.project_out(hidden_states)
-        lm_logits = self.lm_head(hidden_states)
+        if not self.norm_head:
+            lm_logits = self.lm_head(hidden_states)
+        else:
+            lm_head_weight = F.normalize(self.lm_head.weight)
+            if isinstance(self.lm_head, ColumnParallelLinear) and self.lm_head.sequence_parallel:
+                hidden_states = all_gather(hidden_states, self.lm_head.process_group)
+            lm_logits = F.linear(hidden_states, lm_head_weight, bias=self.lm_head.bias)
         # During inference, we want the full logit for sampling
         if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
             lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
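For intuition, the else branch mirrors Baichuan2's NormHead: each row of lm_head.weight is L2-normalized before the projection, so a logit is the cosine similarity between the hidden state and that token's output embedding, scaled by the hidden state's norm. A minimal sketch of that equivalence on plain tensors (no tensor parallelism, bias omitted for simplicity):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden = torch.randn(3, 8)   # (tokens, hidden_size)
weight = torch.randn(16, 8)  # lm_head.weight: (vocab_size, hidden_size)

# Norm-head logits: L2-normalize each vocab row of the weight before the projection,
# matching the else branch above.
logits_norm = F.linear(hidden, F.normalize(weight))

# Same quantity written as a scaled cosine similarity between hidden states and token embeddings.
cos = F.cosine_similarity(hidden.unsqueeze(1), weight.unsqueeze(0), dim=-1)  # (tokens, vocab)
assert torch.allclose(logits_norm, cos * hidden.norm(dim=-1, keepdim=True), atol=1e-5)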
tests/models/test_baichuan.py
@@ -23,7 +23,15 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.generation import update_graph_cache


-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "baichuan-inc/Baichuan-7B",
+        "baichuan-inc/Baichuan-13B-Base",
+        "baichuan-inc/Baichuan2-7B-Base",
+        "baichuan-inc/Baichuan2-13B-Base",
+    ],
+)
 def test_baichuan_state_dict(model_name):
     config = baichuan_config_to_gpt2_config(
         AutoConfig.from_pretrained(model_name, trust_remote_code=True)
@@ -39,7 +47,15 @@ def test_baichuan_state_dict(model_name):
         assert state_dict[k].shape == pretrained_state_dict[k].shape


-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "baichuan-inc/Baichuan-7B",
+        "baichuan-inc/Baichuan-13B-Base",
+        "baichuan-inc/Baichuan2-7B-Base",
+        "baichuan-inc/Baichuan2-13B-Base",
+    ],
+)
 def test_baichuan_optimized(model_name):
     """Check that our implementation of Baichuan (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
@@ -66,9 +82,7 @@ def test_baichuan_optimized(model_name):
     torch.manual_seed(0)
     batch_size = 2
     max_seqlen = 256
-    seqlens = torch.randint(
-        max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device
-    )
+    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
     input_ids = torch.randint(
         0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
     )
@@ -89,7 +103,10 @@ def test_baichuan_optimized(model_name):
     del model_ref

     model_hf = AutoModelForCausalLM.from_pretrained(
-        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
+        model_name,
+        torch_dtype=dtype,
+        device_map={"": device},
+        trust_remote_code=True,
     )
     model_hf.eval()
     with torch.no_grad():
@@ -101,9 +118,7 @@ def test_baichuan_optimized(model_name):
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
     print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
-    assert (out - out_ref).abs().max().item() < 3 * (
-        out_hf - out_ref
-    ).abs().max().item()
+    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

     print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
     print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
@@ -116,7 +131,15 @@ def test_baichuan_optimized(model_name):

 # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "test_baichuan_parallel_forward"
 @pytest.mark.parametrize("world_size", [2])
-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "baichuan-inc/Baichuan-7B",
+        "baichuan-inc/Baichuan-13B-Base",
+        "baichuan-inc/Baichuan2-7B-Base",
+        "baichuan-inc/Baichuan2-13B-Base",
+    ],
+)
 def test_baichuan_parallel_forward(model_name, world_size):
     """Check that our implementation of Baichuan (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
@@ -146,20 +169,14 @@ def test_baichuan_parallel_forward(model_name, world_size):
         state_dict_from_pretrained(model_name), config
     )

-    model = GPTLMHeadModel(
-        config, process_group=process_group, device=device, dtype=dtype
-    )
-    model.load_state_dict(
-        shard_state_dict_tp(pretrained_state_dict, config, world_size, rank)
-    )
+    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
+    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
     model.eval()

     torch.manual_seed(0)
     batch_size = 2
     max_seqlen = 256
-    seqlens = torch.randint(
-        max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device
-    )
+    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
     input_ids = torch.randint(
         0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
     )
@@ -198,9 +215,7 @@ def test_baichuan_parallel_forward(model_name, world_size):
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
     print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
-    assert (out - out_ref).abs().max().item() < 2 * (
-        out_hf - out_ref
-    ).abs().max().item()
+    assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()

     print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
     print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
@@ -211,7 +226,9 @@ def test_baichuan_parallel_forward(model_name, world_size):
     ).abs().max().item()


-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
+@pytest.mark.parametrize(
+    "model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"]
+)
 def test_baichuan_generation(model_name):
     dtype = torch.float16
     device = "cuda"
@@ -258,9 +275,7 @@ def test_baichuan_generation(model_name):
     )
     model_ref.eval()
     with torch.no_grad():
-        logits_ref = (
-            model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
-        )
+        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
     del model_ref

     pretrained_state_dict = remap_state_dict_hf_baichuan(
@@ -370,12 +385,8 @@ def test_baichuan_parallel_generation(model_name, world_size):
         state_dict_from_pretrained(model_name), config
     )

-    model = GPTLMHeadModel(
-        config, process_group=process_group, device=device, dtype=dtype
-    )
-    model.load_state_dict(
-        shard_state_dict_tp(pretrained_state_dict, config, world_size, rank)
-    )
+    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
+    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
     model.eval()

     print("Without CUDA graph")
@@ -425,9 +436,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
             output_scores=True,
         )
         torch.cuda.synchronize()
-        print(
-            f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms"
-        )
+        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
         del model_hf

         model_ref = AutoModelForCausalLM.from_pretrained(
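The parametrizations above pull the corresponding checkpoints from the Hugging Face Hub, so a full run is heavyweight; one way to exercise only the new Baichuan2 cases is to filter by keyword with pytest -k. A hypothetical invocation, assuming the checkpoints are downloadable and fit on the available GPU:

# Equivalent to: pytest -q -s tests/models/test_baichuan.py -k Baichuan2
import pytest

pytest.main(["-q", "-s", "tests/models/test_baichuan.py", "-k", "Baichuan2"])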