gaoqiong / flash-attention / Commits / 96d10f65

Commit 96d10f65, authored Apr 18, 2023 by Tri Dao (parent b630aef5)

Implement LLaMa

Showing 13 changed files with 509 additions and 50 deletions (+509 -50)
flash_attn/models/bert.py        +2    -6
flash_attn/models/gpt.py         +40   -13
flash_attn/models/gptj.py        +1    -3
flash_attn/models/llama.py       +124  -0
flash_attn/models/opt.py         +2    -6
flash_attn/modules/block.py      +26   -7
flash_attn/modules/mlp.py        +7    -6
flash_attn/ops/layer_norm.py     +2    -2
flash_attn/ops/rms_norm.py       +18   -2
flash_attn/ops/triton/mlp.py     +6    -5
tests/models/test_gpt_neox.py    +2    -0
tests/models/test_gptj.py        +2    -0
tests/models/test_llama.py       +277  -0
flash_attn/models/bert.py

@@ -487,18 +487,14 @@ def remap_state_dict(state_dict, config):
             state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.weight'] = torch.cat(
                 [Wq, Wk, Wv], dim=0
             )
-            state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.bias'] = torch.cat(
-                [bq, bk, bv], dim=0
-            )
+            state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.bias'] = torch.cat([bq, bk, bv], dim=0)
         else:
             state_dict[f'bert.encoder.layers.{d}.mixer.Wq.weight'] = Wq
             state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.weight'] = torch.cat(
                 [Wk, Wv], dim=0
             )
             state_dict[f'bert.encoder.layers.{d}.mixer.Wq.bias'] = bq
-            state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.bias'] = torch.cat(
-                [bk, bv], dim=0
-            )
+            state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.bias'] = torch.cat([bk, bv], dim=0)

     def key_mapping_attn(key):
         return re.sub(r'^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)',
                       r'bert.encoder.layers.\1.mixer.out_proj.\2', key)
flash_attn/models/gpt.py

@@ -43,6 +43,16 @@ try:
 except ImportError:
     dropout_add_layer_norm_parallel_residual = None

+try:
+    from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
+except ImportError:
+    RMSNorm, dropout_add_rms_norm = None, None
+
+try:
+    from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
+except ImportError:
+    dropout_add_rms_norm_parallel_residual = None
+
 try:
     from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
 except ImportError:

@@ -90,6 +100,8 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
 def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
     factory_kwargs = {'device': device, 'dtype': dtype}
+    mlp_fc1_bias = getattr(config, 'mlp_fc1_bias', True)
+    mlp_fc2_bias = getattr(config, 'mlp_fc2_bias', True)
     fused_mlp = getattr(config, 'fused_mlp', False)
     if fused_mlp:
         assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']

@@ -108,7 +120,7 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
                           else (F.silu if config.activation_function == 'swiglu' else F.gelu))
             mlp_cls = partial(GatedMlp, hidden_features=config.n_inner, activation=activation,
-                              **factory_kwargs)
+                              bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, **factory_kwargs)
         else:
             if config.activation_function == 'relu':
                 activation = partial(F.relu, inplace=True)

@@ -119,7 +131,7 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
                                in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
                 activation = partial(F.gelu, approximate=approximate)
             mlp_cls = partial(Mlp, hidden_features=config.n_inner, activation=activation,
-                              **factory_kwargs)
+                              bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, **factory_kwargs)
     else:
         mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
         # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer

@@ -137,6 +149,7 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
                                if process_group is not None else {})
             mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation,
                               checkpoint_lvl=mlp_checkpoint_lvl,
+                              bias1=mlp_fc1_bias, bias2=mlp_fc2_bias,
                               **parallel_kwargs, **factory_kwargs)
         elif fused_dense_sqrelu_dense:
             assert FusedDenseSqreluDense is not None

@@ -152,7 +165,9 @@ def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
     sequence_parallel = getattr(config, 'sequence_parallel', True)
     mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
     mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
-    norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_epsilon, **factory_kwargs)
+    use_rms_norm = getattr(config, 'rms_norm', False)
+    norm_cls = partial(nn.LayerNorm if not use_rms_norm else RMSNorm,
+                       eps=config.layer_norm_epsilon, **factory_kwargs)
     # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
     residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
     resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop

@@ -267,6 +282,7 @@ class GPTModel(GPTPreTrainedModel):
         self.residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
         # These 2 options are for OPT-350m
         self.prenorm = getattr(config, 'prenorm', True)
+        use_rms_norm = getattr(config, 'rms_norm', False)
         word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)
         # For GPT-J, GPT-NeoX
         self.parallel_block = getattr(config, 'parallel_block', False)

@@ -300,7 +316,8 @@ class GPTModel(GPTPreTrainedModel):
                 raise ImportError('dropout_layer_norm is not installed')
         if self.prenorm:
             self.drop_f = nn.Dropout(config.resid_pdrop)
-            self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon,
+            norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
+            self.ln_f = norm_cls(config.hidden_size, eps=config.layer_norm_epsilon,
                                  **factory_kwargs)
         if process_group is not None:
             for p in self.ln_f.parameters():

@@ -512,30 +529,39 @@ def combine_state_dicts_tp(state_dicts, config):
     inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
     assert inner_dim % world_size == 0

-    # The word embeddings from Megatron are weird, for each shard only the first
+    # Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim.
     # vocab_size // world_size coordinates are nonzero.
     def combine_word_embeddings(state_dicts, state_dict, key):
-        assert all(s[key].shape[0] == vocab_size for s in state_dicts)
-        state_dict[key] = torch.cat([s[key][:vocab_size // world_size] for s in state_dicts], dim=0)
+        dim = 0 if state_dicts[0][key].shape[0] == vocab_size // world_size else 1
+        state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

     def combine_dim(state_dicts, state_dict, key, dim=-1):
         if key in state_dict:
             state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

     def combine_qkv_headdim(state_dicts, state_dict, key):
         if key in state_dict:
             xs = [rearrange(s[key], '(three d) ... -> three d ...', three=3) for s in state_dicts]
             state_dict[key] = rearrange(torch.cat(xs, dim=1), 'three d ... -> (three d) ...')

+    def combine_gated_mlp(state_dicts, state_dict, key):
+        if key in state_dict:
+            xs = [rearrange(s[key], '(two d) ... -> two d ...', two=2) for s in state_dicts]
+            state_dict[key] = rearrange(torch.cat(xs, dim=1), 'two d ... -> (two d) ...')
+
     state_dict = state_dicts[0].copy()  # don't modify state_dict[0] inplace
     combine_word_embeddings(state_dicts, state_dict, 'transformer.embeddings.word_embeddings.weight')
     if 'lm_head.weight' in state_dict:
         combine_word_embeddings(state_dicts, state_dict, 'lm_head.weight')
     if 'transformer.embeddings.position_embeddings.weight' in state_dict:
         combine_dim(state_dicts, state_dict, 'transformer.embeddings.position_embeddings.weight', -1)
+    mlp_combine_fn = (combine_gated_mlp if config.activation_function in ['glu', 'swiglu', 'geglu']
+                      else partial(combine_dim, dim=0))
     for i in range(config.num_hidden_layers):
         combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
         combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
         combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.out_proj.weight', -1)
-        combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.weight', 0)
+        mlp_combine_fn(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
         combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.bias', 0)
         combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc2.weight', -1)
     return state_dict

@@ -603,7 +629,8 @@ def remap_state_dict_megatron(state_dict, config):
     word_embeddings = state_dict.pop('transformer.embedding.word_embeddings.weight')
     # It's possible that vocab_size is padded to be a multiple of 8, for example.
     pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
-    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
+    vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
+                  * pad_vocab_size_multiple)
     state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
         word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
     )
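The new combine_gated_mlp helper exists because gated-MLP fc1 weights cannot simply be concatenated along dim 0: each tensor-parallel shard stacks its own two halves, so shards must first be split in two, concatenated half-by-half, and re-stacked. A small self-contained illustration of that rearrange pattern (not part of the commit; the shapes here are made up):

import torch
from einops import rearrange

# Two hypothetical tensor-parallel shards of a gated-MLP fc1 weight, each of shape
# (2 * d_shard, in_features): the first d_shard rows are one half, the rest the other half.
d_shard, in_features = 4, 8
shards = [torch.randn(2 * d_shard, in_features) for _ in range(2)]

# Split each shard into its two halves, concatenate matching halves across shards,
# then re-stack -- the same rearrange pattern used by combine_gated_mlp above.
xs = [rearrange(s, '(two d) ... -> two d ...', two=2) for s in shards]
combined = rearrange(torch.cat(xs, dim=1), 'two d ... -> (two d) ...')
assert combined.shape == (2 * d_shard * len(shards), in_features)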
flash_attn/models/gptj.py

@@ -56,9 +56,7 @@ def remap_state_dict_hf_gptj(state_dict, config):
         Wq = state_dict.pop(f'transformer.layers.{l}.attn.q_proj.weight')
         Wk = state_dict.pop(f'transformer.layers.{l}.attn.k_proj.weight')
         Wv = state_dict.pop(f'transformer.layers.{l}.attn.v_proj.weight')
-        state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat(
-            [Wq, Wk, Wv], dim=0
-        )
+        state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
         # We don't store these biases
         state_dict.pop(f'transformer.layers.{l}.attn.bias')
         state_dict.pop(f'transformer.layers.{l}.attn.masked_bias')
flash_attn/models/llama.py (new file, 0 → 100644)

# Copyright (c) 2023, Tri Dao.

import math
import json
import re
from pathlib import Path

from collections import OrderedDict

import torch
import torch.nn.functional as F

from transformers import GPT2Config, LlamaConfig


def remap_state_dict_meta_llama(state_dict, config):
    def key_mapping_layers(key):
        return f'transformer.{key}' if not key.startswith('output.') else key
    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(r'^transformer.tok_embeddings.', 'transformer.embeddings.word_embeddings.',
                      key)
    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
                  * pad_vocab_size_multiple)
    state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, 'tie_word_embeddings'):
        state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
    else:
        output_embeddings = state_dict.pop('output.weight')
        # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
        # differently.
        vocab_size = (math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
                      * pad_vocab_size_multiple)
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict['lm_head.weight'] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r'^transformer.norm.', r'transformer.ln_f.', key)
        key = re.sub(r'^transformer.layers.(\d+).attention_norm.',
                     r'transformer.layers.\1.norm1.', key)
        key = re.sub(r'^transformer.layers.(\d+).ffn_norm.',
                     r'transformer.layers.\1.norm2.', key)
        return key
    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for l in range(config.n_layer):
        w1 = state_dict.pop(f'transformer.layers.{l}.feed_forward.w1.weight')
        w3 = state_dict.pop(f'transformer.layers.{l}.feed_forward.w3.weight')
        # Our ordering is different
        state_dict[f'transformer.layers.{l}.mlp.fc1.weight'] = torch.cat([w3, w1], dim=0)
    def key_mapping_mlp(key):
        return re.sub(r'^transformer.layers.(\d+).feed_forward.w2.',
                      r'transformer.layers.\1.mlp.fc2.', key)
    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f'transformer.layers.{l}.attention.wq.weight')
        Wk = state_dict.pop(f'transformer.layers.{l}.attention.wk.weight')
        Wv = state_dict.pop(f'transformer.layers.{l}.attention.wv.weight')
        state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
        # We don't store these
        state_dict.pop(f'transformer.layers.{l}.attention.inner_attention.rope.freqs', None)
    def key_mapping_attn(key):
        return re.sub(r'^transformer.layers.(\d+).attention.wo.',
                      r'transformer.layers.\1.mixer.out_proj.', key)
    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def config_from_checkpoint(checkpoint_path: str, model_name: str) -> LlamaConfig:
    """Load a LlamaConfig from a checkpoint path."""
    with open(Path(checkpoint_path) / model_name / 'params.json') as f:
        params = json.load(f)
    config = LlamaConfig(hidden_size=params['dim'], intermediate_size=None,
                         num_attention_heads=params['n_heads'],
                         num_hidden_layers=params['n_layers'],
                         rms_norm_eps=params['norm_eps'])
    return config


def state_dicts_from_checkpoint(checkpoint_path: str, model_name: str) -> dict:
    # Need to sort, otherwise we mess up the ordering and the weights are wrong
    return [torch.load(path, map_location='cpu')
            for path in sorted((Path(checkpoint_path) / model_name).glob('consolidated.*.pth'))]


def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
    return GPT2Config(
        vocab_size=llama_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=llama_config.hidden_size,
        n_layer=llama_config.num_hidden_layers,
        n_head=llama_config.num_attention_heads,
        n_inner=llama_config.intermediate_size,
        activation_function='swiglu',  # Hardcode since HF calls it 'silu'
        # Llama doesn't have dropout, idk if it's because they only release the inference code
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=llama_config.rms_norm_eps,
        initializer_range=llama_config.initializer_range,
        bos_token_id=llama_config.bos_token_id,
        eos_token_id=llama_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        pad_token_id=llama_config.pad_token_id,  # Idk if this does anything
        rms_norm=True,
        rotary_emb_fraction=1.0,
        rotary_emb_interleaved=True,
        tie_word_embeddings=False,
        qkv_proj_bias=False,
        out_proj_bias=False,
        mlp_fc1_bias=False,
        mlp_fc2_bias=False,
    )
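These helpers are exercised end to end in tests/models/test_llama.py below. A minimal sketch of the intended conversion flow, assuming the original Meta checkpoint layout (a llama/<model_name>/ directory containing params.json and consolidated.*.pth shards); the paths here are placeholders:

import torch
from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp
from flash_attn.models.llama import (config_from_checkpoint, state_dicts_from_checkpoint,
                                     remap_state_dict_meta_llama, llama_config_to_gpt2_config)

checkpoint_path, model_name = '/path/to/checkpoints/llama', '7B'   # placeholder paths
config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
shards = state_dicts_from_checkpoint(checkpoint_path, model_name)      # one dict per TP shard
remapped = [remap_state_dict_meta_llama(s, config) for s in shards]    # rename to GPT-style keys
state_dict = combine_state_dicts_tp(remapped, config)                  # merge tensor-parallel shards

model = GPTLMHeadModel(config, device='cuda', dtype=torch.float16)
# strict=False because the rotary inv_freq buffers are not stored in the checkpoint.
model.load_state_dict(state_dict, strict=False)
model.eval()
logits = model(torch.randint(0, config.vocab_size, (1, 16), device='cuda')).logits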
flash_attn/models/opt.py

@@ -66,12 +66,8 @@ def remap_state_dict_hf_opt(state_dict, config):
         bq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.bias')
         bk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.bias')
         bv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.bias')
-        state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat(
-            [Wq, Wk, Wv], dim=0
-        )
-        state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = torch.cat(
-            [bq, bk, bv], dim=0
-        )
+        state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
+        state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = torch.cat([bq, bk, bv], dim=0)

     def key_mapping_attn(key):
         return re.sub(r'^transformer.layers.(\d+).self_attn.out_proj.',
                       r'transformer.layers.\1.mixer.out_proj.', key)
flash_attn/modules/block.py

@@ -23,6 +23,16 @@ try:
 except ImportError:
     dropout_add_layer_norm_parallel_residual = None

+try:
+    from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
+except ImportError:
+    RMSNorm, dropout_add_rms_norm = None, None
+
+try:
+    from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
+except ImportError:
+    dropout_add_rms_norm_parallel_residual = None
+
 class Block(nn.Module):

@@ -70,7 +80,9 @@ class Block(nn.Module):
         if self.fused_dropout_add_ln:
-            assert dropout_add_layer_norm is not None, 'dropout_layer_norm is not installed'
-            assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
+            assert dropout_add_rms_norm is not None, 'dropout_layer_norm is not installed'
+            assert (isinstance(self.norm1, (nn.LayerNorm, RMSNorm))
+                    and isinstance(self.dropout1, nn.Dropout))

         # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
         # then the input to each worker in the tensor parallel group will be different.

@@ -104,6 +116,8 @@ class Block(nn.Module):
                 before applying the query projection. Useful for e.g., ViT where we only care
                 about the CLS token in the last layer.
         """
+        fused_add_norm_fn = (dropout_add_rms_norm if isinstance(self.norm1, RMSNorm)
+                             else dropout_add_layer_norm)
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_path1(self.dropout1(hidden_states))

@@ -119,7 +133,7 @@ class Block(nn.Module):
                         hidden_states.shape[:-1], device=hidden_states.device,
                         dtype=hidden_states.dtype)
                     )
-                hidden_states, residual = dropout_add_layer_norm(
+                hidden_states, residual = fused_add_norm_fn(
                     hidden_states, residual, self.norm1.weight, self.norm1.bias,
                     self.dropout1.p if self.training else 0.0, self.norm1.eps,
                     rowscale=rowscale1, prenorm=True, residual_in_fp32=self.residual_in_fp32

@@ -146,7 +160,7 @@ class Block(nn.Module):
                         hidden_states.shape[:-1], device=hidden_states.device,
                         dtype=hidden_states.dtype)
                     )
-                hidden_states, residual = dropout_add_layer_norm(
+                hidden_states, residual = fused_add_norm_fn(
                     hidden_states, residual, self.norm2.weight, self.norm2.bias,
                     self.dropout2.p if self.training else 0.0, self.norm2.eps,
                     rowscale=rowscale2, prenorm=True, residual_in_fp32=self.residual_in_fp32

@@ -170,7 +184,7 @@ class Block(nn.Module):
                     rowscale1 = self.drop_path1(torch.ones(
                         mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype)
                     )
-                hidden_states = dropout_add_layer_norm(
+                hidden_states = fused_add_norm_fn(
                     mixer_out, hidden_states, self.norm1.weight, self.norm1.bias,
                     self.dropout1.p if self.training else 0.0, self.norm1.eps,
                     rowscale=rowscale1, prenorm=False

@@ -189,7 +203,7 @@ class Block(nn.Module):
                     rowscale2 = self.drop_path2(torch.ones(
                         mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype)
                     )
-                hidden_states = dropout_add_layer_norm(
+                hidden_states = fused_add_norm_fn(
                     mlp_out, hidden_states, self.norm2.weight, self.norm2.bias,
                     self.dropout2.p if self.training else 0.0, self.norm2.eps,
                     rowscale=rowscale2, prenorm=False

@@ -234,7 +248,9 @@ class ParallelBlock(nn.Module):
         if self.fused_dropout_add_ln:
-            assert dropout_add_layer_norm_parallel_residual is not None, 'dropout_layer_norm is not installed'
-            assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
+            assert dropout_add_rms_norm_parallel_residual is not None, 'dropout_layer_norm is not installed'
+            assert (isinstance(self.norm1, (nn.LayerNorm, RMSNorm))
+                    and isinstance(self.dropout1, nn.Dropout))

         # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
         # then the input to each worker in the tensor parallel group will be different.

@@ -266,6 +282,9 @@ class ParallelBlock(nn.Module):
             hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
             residual.
         """
+        fused_add_norm_fn = (dropout_add_rms_norm_parallel_residual
+                             if isinstance(self.norm1, RMSNorm)
+                             else dropout_add_layer_norm_parallel_residual)
        if not self.fused_dropout_add_ln:
            dropped1 = self.dropout1(hidden_states1)
            # For the very 1st block, we only want 1 dropout, not two different dropouts

@@ -283,7 +302,7 @@ class ParallelBlock(nn.Module):
            else:
                weight2, bias2 = ((self.norm2.weight, self.norm2.bias)
                                  if not self.tied_norm else (None, None))
-            hidden_states1, hidden_states2, residual = dropout_add_layer_norm_parallel_residual(
+            hidden_states1, hidden_states2, residual = fused_add_norm_fn(
                 hidden_states1, hidden_states2, residual, self.norm1.weight, self.norm1.bias,
                 weight2, bias2, self.dropout1.p if self.training else 0.0, self.norm1.eps,
                 prenorm=True, residual_in_fp32=self.residual_in_fp32
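The only behavioral change in Block is that fused_add_norm_fn now dispatches to dropout_add_rms_norm when norm1 is an RMSNorm; the calling convention is otherwise unchanged. For orientation, an unfused sketch of what the prenorm path of these fused ops computes (an assumption inferred from how the results are used above, not a description of the kernel):

import torch
import torch.nn.functional as F

def dropout_add_norm_ref(x, residual, norm, p, training, residual_in_fp32=False):
    # Unfused sketch of the prenorm path: dropout the incoming branch, add it to the
    # running residual, normalize, and return both the normalized hidden states and
    # the (optionally fp32) pre-norm residual for the next block.
    dropped = F.dropout(x, p=p if training else 0.0)
    residual = dropped if residual is None else residual + dropped
    if residual_in_fp32:
        residual = residual.float()
    return norm(residual.to(dtype=norm.weight.dtype)), residual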
flash_attn/modules/mlp.py

@@ -13,15 +13,15 @@ except ImportError:
 class Mlp(nn.Module):

     def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
-                 return_residual=False, device=None, dtype=None):
+                 bias1=True, bias2=True, return_residual=False, device=None, dtype=None):
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features * 4
         self.return_residual = return_residual
-        self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs)
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
         self.activation = activation
-        self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs)
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

     def forward(self, x):
         y = self.fc1(x)

@@ -33,16 +33,17 @@ class Mlp(nn.Module):
 class GatedMlp(nn.Module):

     def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.sigmoid,
-                 multiple_of=128, return_residual=False, device=None, dtype=None):
+                 bias1=True, bias2=True, multiple_of=256, return_residual=False,
+                 device=None, dtype=None):
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or int(8 * in_features / 3)
         hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
         self.return_residual = return_residual
-        self.fc1 = nn.Linear(in_features, 2 * hidden_features, **factory_kwargs)
+        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
         self.activation = activation
-        self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs)
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias1, **factory_kwargs)

     def forward(self, x):
         y = self.fc1(x)
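GatedMlp's fc1 projects to 2 * hidden_features, and with activation=F.silu (which create_mlp_cls in gpt.py now selects for 'swiglu') this gives the SwiGLU MLP used by LLaMa. A rough sketch of the forward pass; the split ordering of the two halves is an assumption, chosen to be consistent with the LLaMa remap above concatenating [w3, w1] (up-projection first, gate second):

import torch
import torch.nn.functional as F

def gated_mlp_forward(x, fc1, fc2, activation=F.silu):
    # fc1: Linear(in_features, 2 * hidden_features); fc2: Linear(hidden_features, out_features).
    y, gate = fc1(x).chunk(2, dim=-1)   # assumed ordering: value ("up") half first, gate half second
    return fc2(y * activation(gate))    # SwiGLU when activation is SiLU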
flash_attn/ops/layer_norm.py

@@ -351,7 +351,7 @@ class DropoutAddLayerNorm(torch.nn.Module):
         super().__init__()
         self.prenorm = prenorm
         self.p = p
-        self.epsilon = eps
+        self.eps = eps
         self.residual_in_fp32 = residual_in_fp32
         self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
         self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))

@@ -363,5 +363,5 @@ class DropoutAddLayerNorm(torch.nn.Module):
     def forward(self, x0, residual=None):
         return dropout_add_layer_norm(x0, residual, self.weight, self.bias,
-                                      self.p if self.training else 0.0, self.epsilon,
+                                      self.p if self.training else 0.0, self.eps,
                                       prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
flash_attn/ops/rms_norm.py

@@ -51,6 +51,22 @@ def dropout_add_rms_norm_parallel_residual(
     )


+class RMSNorm(torch.nn.Module):
+
+    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        self.eps = eps
+        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
+        self.register_parameter('bias', None)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        init.ones_(self.weight)
+
+    def forward(self, x):
+        return rms_norm(x, self.weight, self.eps)
+
+
 class DropoutAddRMSNorm(torch.nn.Module):
     def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                  device=None, dtype=None):

@@ -58,7 +74,7 @@ class DropoutAddRMSNorm(torch.nn.Module):
         super().__init__()
         self.prenorm = prenorm
         self.p = p
-        self.epsilon = eps
+        self.eps = eps
         self.residual_in_fp32 = residual_in_fp32
         self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
         self.register_parameter('bias', None)

@@ -69,5 +85,5 @@ class DropoutAddRMSNorm(torch.nn.Module):
     def forward(self, x0, residual=None):
         return dropout_add_rms_norm(x0, residual, self.weight, None,
-                                    self.p if self.training else 0.0, self.epsilon,
+                                    self.p if self.training else 0.0, self.eps,
                                     prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
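rms_norm here is the fused kernel from flash_attn.ops. As a point of reference, an unfused version of what the new RMSNorm module computes, assuming the usual definition (scale by the reciprocal root-mean-square of the last dimension, no mean subtraction, no bias):

import torch

def rms_norm_ref(x, weight, eps=1e-5):
    # Reference (unfused) RMSNorm: normalize by the RMS of the last dim, then apply the
    # learned scale. No bias, matching register_parameter('bias', None) above.
    rstd = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x.float() * rstd).to(x.dtype) * weight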
flash_attn/ops/triton/mlp.py

@@ -105,7 +105,7 @@ fused_dense_sqrelu_dense_function = FusedDenseSqreluDenseFunc.apply
 class FusedDenseSqreluDense(nn.Module):

-    def __init__(self, in_features, hidden_features=None, out_features=None, bias=True,
+    def __init__(self, in_features, hidden_features=None, out_features=None, bias1=True, bias2=True,
                  checkpoint_lvl=0, device=None, dtype=None):
         """
         checkpoint_lvl (increasing lvl means slower but more memory saving):

@@ -117,11 +117,12 @@ class FusedDenseSqreluDense(nn.Module):
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         out_features = out_features or in_features
-        hidden_features = hidden_features or in_features
-        assert bias == True, "DenseSqreluDense module without bias is currently not supported"
+        hidden_features = hidden_features or in_features * 4
+        assert bias1 == True, "DenseSqreluDense module without bias is currently not supported"
+        assert bias2 == True, "DenseSqreluDense module without bias is currently not supported"
         self.checkpoint_lvl = checkpoint_lvl
-        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, **factory_kwargs)
-        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

     def forward(self, x):
         assert x.is_cuda
tests/models/test_gpt_neox.py

 # Copyright (c) 2023, Tri Dao.
+import time
+
 import torch
tests/models/test_gptj.py

 # Copyright (c) 2023, Tri Dao.
+import time
+
 import torch
tests/models/test_llama.py (new file, 0 → 100644)

# Copyright (c) 2023, Tri Dao.

# To run the huggingface implementation, we first need to convert the weights:
# https://github.com/huggingface/transformers/pull/21955
# python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR$/llama/7B-hf
# and repeat for 13B, 30B, 65B

import os
import time
from pathlib import Path
current_dir = Path(__file__).parent.absolute()

import torch
import pytest

from transformers import LlamaConfig, LlamaTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp
from flash_attn.models.llama import remap_state_dict_meta_llama, llama_config_to_gpt2_config
from flash_attn.models.llama import config_from_checkpoint, state_dicts_from_checkpoint
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import update_graph_cache


@pytest.mark.parametrize('model_name', ["7B"])
def test_llama_state_dict(model_name):
    checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
                                          current_dir.parent.parent / 'checkpoints')) / 'llama'
    config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
    pretrained_state_dict = remap_state_dict_meta_llama(ckpt_state_dicts[0], config)
    model = GPTLMHeadModel(config, device='meta')  # Without device='meta' init is very slow
    state_dict = model.state_dict()
    rotary_inv_freq_keys = {f'transformer.layers.{l}.mixer.rotary_emb.inv_freq'
                            for l in range(config.n_layer)}
    assert state_dict.keys() == pretrained_state_dict.keys() | rotary_inv_freq_keys
    for k in state_dict.keys() - rotary_inv_freq_keys:
        assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize('model_name', ["7B", "13B"])
def test_llama_optimized(model_name):
    """Check that our implementation of LLaMa (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
                                          current_dir.parent.parent / 'checkpoints')) / 'llama'

    dtype = torch.float16
    device = 'cuda'
    config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
    pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(pretrained_state_dict, strict=False)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device=device)
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    del model

    # Without device_map, the model is loaded on the CPU, which is very slow
    # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
    model_ref = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
                                                 device_map='auto')
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
        logits_ref = model_ref(input_ids).logits.to(device=device)
    del model_ref

    model_hf = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
                                                torch_dtype=dtype, device_map={"": device})
    model_hf.eval()
    out_hf = model_hf.model(input_ids).last_hidden_state
    logits_hf = model_hf(input_ids).logits
    del model_hf

    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
    print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
    assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()


# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel"
@pytest.mark.skip(reason="Tensor Parallel is not implemented for GatedMLP yet")
@pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('model_name', ["13B"])
def test_llama_parallel(model_name, world_size):
    """Check that our implementation of LLaMa (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    from apex.transformer import parallel_state

    checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
                                          current_dir.parent.parent / 'checkpoints')) / 'llama'

    dtype = torch.float16
    device = 'cuda'
    config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    device = f'cuda:{torch.distributed.get_rank()}'
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
    pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)

    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank),
                          strict=False)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device=device)
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    del model

    # Without device_map, the model is loaded on the CPU, which is very slow
    model_ref = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
                                                 device_map='auto')
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
        logits_ref = model_ref(input_ids).logits.to(device=device)
    del model_ref

    model_hf = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
                                                torch_dtype=dtype, device_map="auto")
    model_hf.eval()
    out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
    logits_hf = model_hf(input_ids).logits.to(device=device)
    del model_hf

    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
    print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
    assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()


@pytest.mark.parametrize('model_name', ["7B"])
def test_llama_generation(model_name):
    checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
                                          current_dir.parent.parent / 'checkpoints')) / 'llama'

    dtype = torch.float16
    device = 'cuda'
    config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    tokenizer = LlamaTokenizer.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf')
    eos_token_id = tokenizer.eos_token_id

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
                              device=device)

    model_hf = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
                                                torch_dtype=dtype, device_map={"": device})
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
                               return_dict_in_generate=True, output_scores=True)
    torch.cuda.synchronize()
    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
    del model_hf

    model_ref = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
                                                 device_map={"": device})
    model_ref.eval()
    with torch.no_grad():
        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1]
    del model_ref

    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
    pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(pretrained_state_dict, strict=False)
    model.eval()

    print('Without CUDA graph')
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(input_ids=input_ids, max_length=max_length,
                         eos_token_id=eos_token_id, fused_ft_kernel=True,
                         return_dict_in_generate=True, output_scores=True, timing=True,
                         teacher_outputs=out_hf.sequences)
    torch.cuda.synchronize()
    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print('With CUDA graph')
    torch.cuda.synchronize()
    start = time.time()
    out_cg = model.generate(input_ids=input_ids, max_length=max_length,
                            fused_ft_kernel=True, cg=True,
                            return_dict_in_generate=True, output_scores=True, timing=True,
                            teacher_outputs=out_hf.sequences)
    torch.cuda.synchronize()
    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')

    with torch.no_grad():
        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1):-1]
    logits_hf = torch.stack(out_hf.scores, dim=1)
    logits = torch.stack(out.scores, dim=1)
    logits_cg = torch.stack(out_cg.scores, dim=1)
    del model

    hf_error = (logits_hf - logits_ref).abs().max().item()
    # For some reason logits_parallel is off by quite a bit more than 2x
    assert (logits_parallel - logits_ref).abs().max().item() < 8 * hf_error
    print(f'HF fp16 logits max diff: {hf_error}')
    print(f'Logits max diff: {(logits - logits_parallel).abs().max().item()}')
    assert (logits - logits_parallel).abs().max().item() < 2 * hf_error
    print(f'Logits CG max diff: {(logits_cg - logits_parallel).abs().max().item()}')
    assert torch.equal(logits_cg, logits)