gaoqiong/flash-attention, commit c3b21966

Add Alibi to MHA, test with Baichuan-13B

Authored Dec 21, 2023 by Tri Dao
Parent: 701b51bf

Showing 5 changed files with 84 additions and 41 deletions:
    flash_attn/models/baichuan.py    +7   -20
    flash_attn/models/gpt.py         +2   -0
    flash_attn/modules/mha.py        +69  -9
    flash_attn/modules/mlp.py        +2   -2
    tests/models/test_baichuan.py    +4   -10
flash_attn/models/baichuan.py

@@ -109,29 +109,14 @@ def remap_state_dict_hf_baichuan(state_dict, config):
     state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
 
     for l in range(config.n_layer):
         # pop rotary_emb.inv_freq from state dict
-        state_dict.pop(f"transformer.layers.{l}.self_attn.rotary_emb.inv_freq")
+        state_dict.pop(f"transformer.layers.{l}.self_attn.rotary_emb.inv_freq", None)
     return state_dict
 
 
-def config_from_checkpoint(checkpoint_path: str, model_name: str) -> PretrainedConfig:
-    """Load a BaiChuanConfig from a checkpoint path."""
-    config = AutoConfig.from_pretrained(Path(checkpoint_path) / model_name, trust_remote_code=True)
-    return config
-
-
-def state_dicts_from_checkpoint(checkpoint_path: str, model_name: str) -> dict:
-    # Need to sort, otherwise we mess up the ordering and the weights are wrong
-    return [
-        torch.load(path, map_location="cpu")
-        for path in sorted((Path(checkpoint_path) / model_name).glob("pytorch_model*.bin"))
-    ]
-
-
 def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Config:
+    # HACK: the config doesn't have say whether it's rotary or alibi.
+    # So we have to infer from the hidden size (7B -> rotary, 13B -> alibi).
+    use_rotary = baichuan_config.hidden_size < 5000
     return GPT2Config(
         vocab_size=baichuan_config.vocab_size,
         n_positions=0,  # No absolute position embedding
@@ -151,8 +136,10 @@ def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Con
         # These are new arguments not in the original GPT2Config
         pad_token_id=baichuan_config.pad_token_id,  # Idk if this does anything
         rms_norm=True,
-        rotary_emb_fraction=1.0,
+        rotary_emb_fraction=1.0 if use_rotary else 0.0,
         rotary_emb_interleaved=False,
+        use_alibi=not use_rotary,
+        use_flash_attn=not use_rotary,  # Alibi code path requires flash_attn
         tie_word_embeddings=False,
         qkv_proj_bias=False,
         out_proj_bias=False,
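With this change the 13B checkpoint is routed through the ALiBi path automatically. A minimal usage sketch of the conversion (not part of the commit; the GPTLMHeadModel call, GPU, and public HF checkpoint names are assumptions taken from the test file further below):

    # Hypothetical usage sketch; model-loading details are assumptions, not from this commit.
    import torch
    from transformers import AutoConfig

    from flash_attn.models.baichuan import baichuan_config_to_gpt2_config
    from flash_attn.models.gpt import GPTLMHeadModel

    hf_config = AutoConfig.from_pretrained("baichuan-inc/Baichuan-13B-Base", trust_remote_code=True)
    config = baichuan_config_to_gpt2_config(hf_config)
    # Baichuan-7B (hidden_size 4096) is inferred as rotary; Baichuan-13B (hidden_size 5120) as ALiBi.
    assert config.use_alibi and config.use_flash_attn

    model = GPTLMHeadModel(config, device="cuda", dtype=torch.float16)  # randomly initialized weights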
flash_attn/models/gpt.py

@@ -85,6 +85,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
     rotary_emb_base = getattr(config, "rotary_emb_base", 10000.0)
     rotary_emb_scale_base = getattr(config, "rotary_emb_scale_base", None)
     rotary_emb_interleaved = getattr(config, "rotary_emb_interleaved", False)
+    use_alibi = getattr(config, "use_alibi", False)
     use_flash_attn = getattr(config, "use_flash_attn", False)
     fused_bias_fc = getattr(config, "fused_bias_fc", False)
     if not fused_bias_fc:
@@ -116,6 +117,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
         rotary_emb_base=rotary_emb_base,
         rotary_emb_scale_base=rotary_emb_scale_base,
         rotary_emb_interleaved=rotary_emb_interleaved,
+        use_alibi=use_alibi,
         use_flash_attn=use_flash_attn,
         **serial_kwargs,
         **parallel_kwargs,
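Reading the flag with getattr and a False default keeps configs written before this commit working unchanged; only configs that explicitly set use_alibi (such as the Baichuan-13B conversion above) switch the mixer over. A small illustration of that default behavior (hypothetical config values, not from the commit):

    # Hypothetical illustration of the getattr default; values are made up.
    from transformers import GPT2Config

    cfg = GPT2Config(n_embd=4096, n_head=32, n_layer=2)
    print(getattr(cfg, "use_alibi", False))  # False: older configs are unaffected

    cfg.use_alibi = True        # PretrainedConfig accepts extra attributes
    cfg.use_flash_attn = True   # the ALiBi path exists only in the flash-attn kernels
    print(getattr(cfg, "use_alibi", False))  # True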
flash_attn/modules/mha.py

@@ -33,6 +33,23 @@ except ImportError:
     RotaryEmbedding = None
 
 
+# From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
+def get_alibi_slopes(nheads):
+    def get_slopes_power_of_2(nheads):
+        start = 2 ** (-(2 ** -(math.log2(nheads) - 3)))
+        ratio = start
+        return [start * ratio**i for i in range(nheads)]
+
+    if math.log2(nheads).is_integer():
+        return get_slopes_power_of_2(nheads)
+    else:
+        closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
+        return (
+            get_slopes_power_of_2(closest_power_of_2)
+            + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]
+        )
+
+
 class FlashSelfAttention(nn.Module):
     """Implement the scaled dot product attention with softmax.
 
     Arguments
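The slopes follow the standard ALiBi schedule: for a power-of-2 head count each successive head's slope halves, and other head counts splice in every other slope from the next power of 2 so that every head still gets a distinct slope. A quick sanity check of the function added above (it only requires flash-attn to be importable):

    # Sanity check of the slope schedule introduced in this hunk.
    from flash_attn.modules.mha import get_alibi_slopes

    print(get_alibi_slopes(8))
    # [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]

    print(len(get_alibi_slopes(40)))  # 40: Baichuan-13B's 40 heads each get their own slope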
@@ -44,13 +61,14 @@ class FlashSelfAttention(nn.Module):
             (default: 0.0)
     """
 
-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None):
         super().__init__()
         assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
         self.causal = causal
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
+        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
 
     def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
         """Implements the multihead softmax attention.
@@ -84,6 +102,7 @@ class FlashSelfAttention(nn.Module):
                 self.drop.p if self.training else 0.0,
                 softmax_scale=self.softmax_scale,
                 causal=causal,
+                alibi_slopes=self.alibi_slopes,
             )
         else:
             return flash_attn_qkvpacked_func(
@@ -91,6 +110,7 @@ class FlashSelfAttention(nn.Module):
                 self.drop.p if self.training else 0.0,
                 softmax_scale=self.softmax_scale,
                 causal=causal,
+                alibi_slopes=self.alibi_slopes,
             )
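The slopes are stored as a non-persistent buffer and simply forwarded to the FlashAttention kernels. A hypothetical smoke test of this plumbing (assumes a GPU and a flash-attn build whose kernels accept the alibi_slopes argument):

    # Hypothetical smoke test; shapes and head count are example values.
    import torch
    from flash_attn.modules.mha import FlashSelfAttention, get_alibi_slopes

    nheads, headdim = 8, 64
    attn = FlashSelfAttention(
        causal=True,
        alibi_slopes=torch.tensor(get_alibi_slopes(nheads), device="cuda"),
    )
    qkv = torch.randn(2, 128, 3, nheads, headdim, device="cuda", dtype=torch.float16)
    out = attn(qkv)  # (batch, seqlen, nheads, headdim), with the per-head ALiBi bias applied
    print(out.shape)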
@@ -105,13 +125,14 @@ class FlashCrossAttention(nn.Module):
             (default: 0.0)
     """
 
-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None):
         super().__init__()
         assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
         self.causal = causal
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
+        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
 
     def forward(
         self,
@@ -158,6 +179,7 @@ class FlashCrossAttention(nn.Module):
                 self.drop.p if self.training else 0.0,
                 softmax_scale=self.softmax_scale,
                 causal=causal,
+                alibi_slopes=self.alibi_slopes,
             )
         else:
             batch_size, seqlen_q = q.shape[0], q.shape[1]
@@ -169,6 +191,7 @@ class FlashCrossAttention(nn.Module):
                 self.drop.p if self.training else 0.0,
                 causal=causal,
                 softmax_scale=self.softmax_scale,
+                alibi_slopes=self.alibi_slopes,
             )
@@ -315,8 +338,8 @@ def _update_kv_cache(kv, inference_params, layer_idx):
     batch_end = batch_start + kv.shape[0]
     sequence_start = inference_params.seqlen_offset
     sequence_end = sequence_start + kv.shape[1]
-    assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
-    assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
+    assert batch_end <= kv_cache.shape[0]
+    assert sequence_end <= kv_cache.shape[1]
     assert kv_cache is not None
     kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
     return kv_cache[batch_start:batch_end, :sequence_end, ...]
@@ -342,6 +365,7 @@ class MHA(nn.Module):
         rotary_emb_base=10000.0,
         rotary_emb_scale_base=None,
         rotary_emb_interleaved=False,
+        use_alibi=False,
         fused_bias_fc=False,
         use_flash_attn=False,
         return_residual=False,
@@ -366,6 +390,11 @@ class MHA(nn.Module):
         self.use_flash_attn = use_flash_attn
         self.return_residual = return_residual
         self.checkpointing = checkpointing
+        if use_alibi:
+            assert use_flash_attn, "ALiBi code path requires flash_attn"
+            alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
+        else:
+            alibi_slopes = None
 
         self.num_heads = num_heads
         self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
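At the MHA level, use_alibi is just a constructor flag: it builds the slope tensor once and asserts that the flash-attn path is enabled, since the non-flash SelfAttention/CrossAttention classes do not take slopes. A hypothetical construction sketch (GPU and flash-attn required; dimensions are example values):

    # Hypothetical MHA construction with ALiBi enabled; not taken from the commit's tests.
    import torch
    from flash_attn.modules.mha import MHA

    mha = MHA(
        embed_dim=1024,
        num_heads=16,
        causal=True,
        use_alibi=True,       # builds alibi_slopes via get_alibi_slopes(num_heads)
        use_flash_attn=True,  # asserted: the ALiBi path only exists in the flash-attn kernels
        device="cuda",
        dtype=torch.float16,
    )
    x = torch.randn(2, 256, 1024, device="cuda", dtype=torch.float16)
    out = mha(x)  # same shape as x
    print(out.shape)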
@@ -395,8 +424,16 @@ class MHA(nn.Module):
             LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
         )
         wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
-        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
-        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
+        inner_attn_cls = (
+            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
+            if use_flash_attn
+            else SelfAttention
+        )
+        inner_cross_attn_cls = (
+            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
+            if use_flash_attn
+            else CrossAttention
+        )
         if not self.cross_attn:
             self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
         else:
@@ -413,7 +450,9 @@ class MHA(nn.Module):
             )
             self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
         self.inner_attn = inner_attn_cls(
-            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
+            causal=causal,
+            softmax_scale=softmax_scale,
+            attention_dropout=dropout,
         )
         self.inner_cross_attn = inner_cross_attn_cls(
             causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
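functools.partial is used so the ALiBi slopes are bound at the point where the attention class is chosen, while the later construction calls stay identical for the flash and non-flash paths. A toy illustration of that pattern (class name and values are made up, not from the repo):

    # Toy illustration of binding a constructor argument with functools.partial.
    from functools import partial

    class Attn:
        def __init__(self, causal=False, attention_dropout=0.0, alibi_slopes=None):
            self.causal = causal
            self.attention_dropout = attention_dropout
            self.alibi_slopes = alibi_slopes

    attn_cls = partial(Attn, alibi_slopes=[0.5, 0.25])   # slopes bound once, up front
    attn = attn_cls(causal=True, attention_dropout=0.1)  # call site unchanged
    print(attn.alibi_slopes)                             # [0.5, 0.25]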
@@ -672,6 +711,7 @@ class ParallelMHA(nn.Module):
         rotary_emb_base=10000.0,
         rotary_emb_scale_base=None,
         rotary_emb_interleaved=False,
+        use_alibi=False,
         use_flash_attn=False,
         checkpointing=False,
         sequence_parallel=True,
@@ -707,6 +747,18 @@ class ParallelMHA(nn.Module):
         self.head_dim = self.embed_dim // num_heads
         qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
 
+        if use_alibi:
+            assert use_flash_attn, "ALiBi code path requires flash_attn"
+            num_heads_local = math.ceil(self.num_heads / self.world_size)
+            alibi_slopes = torch.tensor(
+                get_alibi_slopes(num_heads)[
+                    self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local
+                ],
+                device=device,
+            )
+        else:
+            alibi_slopes = None
+
         if self.rotary_emb_dim > 0:
             assert RotaryEmbedding is not None, "rotary_emb is not installed"
             self.rotary_emb = RotaryEmbedding(
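In the tensor-parallel case each rank computes the full slope list but keeps only the slice belonging to its local heads, matching how the heads themselves are split across ranks. A small illustration of the slicing (world size and head count are example values, not from the commit):

    # Illustration of the per-rank slope slicing used above.
    import math
    from flash_attn.modules.mha import get_alibi_slopes

    num_heads, world_size = 40, 2                        # e.g. 40 heads split over 2 GPUs
    slopes = get_alibi_slopes(num_heads)                 # identical full list on every rank
    num_heads_local = math.ceil(num_heads / world_size)  # 20 heads per rank

    for local_rank in range(world_size):
        local = slopes[local_rank * num_heads_local : (local_rank + 1) * num_heads_local]
        print(local_rank, len(local))                    # each rank keeps only its own heads' slopes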
@@ -728,8 +780,16 @@ class ParallelMHA(nn.Module):
             multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
             **factory_kwargs,
         )
-        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
-        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
+        inner_attn_cls = (
+            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
+            if use_flash_attn
+            else SelfAttention
+        )
+        inner_cross_attn_cls = (
+            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
+            if use_flash_attn
+            else CrossAttention
+        )
         self.inner_attn = inner_attn_cls(
             causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
         )
flash_attn/modules/mlp.py

@@ -105,7 +105,7 @@ class GatedMlp(nn.Module):
         activation=F.sigmoid,
         bias1=True,
         bias2=True,
-        multiple_of=256,
+        multiple_of=128,
         return_residual=False,
         device=None,
         dtype=None,
@@ -148,7 +148,7 @@ class ParallelGatedMlp(nn.Module):
         activation=F.sigmoid,
         bias1=True,
         bias2=True,
-        multiple_of=256,
+        multiple_of=128,
         sequence_parallel=True,
         device=None,
         dtype=None,
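A plausible reason for lowering the default: GatedMlp rounds its hidden width up to a multiple of multiple_of, and Baichuan-13B's intermediate size of 13696 is a multiple of 128 but not of 256. The sketch below mirrors that rounding; treat the exact formula and the Baichuan figure as assumptions rather than facts stated in the commit.

    # Rounding sketch; formula and Baichuan-13B intermediate size (13696) are assumptions.
    def round_hidden(hidden_features, multiple_of):
        return (hidden_features + multiple_of - 1) // multiple_of * multiple_of

    print(round_hidden(13696, 128))  # 13696: the width is preserved
    print(round_hidden(13696, 256))  # 13824: a 256 default would silently widen the MLP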
tests/models/test_baichuan.py

@@ -2,8 +2,6 @@ import os
 import time
 from pathlib import Path
 
-current_dir = Path(__file__).parent.absolute()
-
 import torch
 import pytest
@@ -20,16 +18,12 @@ from flash_attn.models.baichuan import (
     remap_state_dict_hf_baichuan,
     baichuan_config_to_gpt2_config,
 )
-from flash_attn.models.baichuan import (
-    config_from_checkpoint,
-    state_dicts_from_checkpoint,
-)
 from flash_attn.utils.distributed import all_gather_raw
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.generation import update_graph_cache
 
 
-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
 def test_baichuan_state_dict(model_name):
     config = baichuan_config_to_gpt2_config(
         AutoConfig.from_pretrained(model_name, trust_remote_code=True)
@@ -45,7 +39,7 @@ def test_baichuan_state_dict(model_name):
         assert state_dict[k].shape == pretrained_state_dict[k].shape
 
 
-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
 def test_baichuan_optimized(model_name):
     """Check that our implementation of Baichuan (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
@@ -122,7 +116,7 @@ def test_baichuan_optimized(model_name):
 # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "test_baichuan_parallel_forward"
 @pytest.mark.parametrize("world_size", [2])
-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
 def test_baichuan_parallel_forward(model_name, world_size):
     """Check that our implementation of Baichuan (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
@@ -217,7 +211,7 @@ def test_baichuan_parallel_forward(model_name, world_size):
     ).abs().max().item()
 
 
-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
 def test_baichuan_generation(model_name):
     dtype = torch.float16
     device = "cuda"
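The existing tests are simply re-parametrized over the 13B checkpoint, which exercises the new ALiBi path end to end. To run only the newly added 13B cases of the single-GPU tests, one possible local invocation is sketched below (hypothetical; it needs a GPU, the flash-attn build, and access to the HF checkpoints):

    # Hypothetical way to run just the Baichuan-13B parametrizations of the single-GPU tests.
    import pytest

    pytest.main(["-q", "-s", "tests/models/test_baichuan.py", "-k", "13B and not parallel"])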