gaoqiong / flash-attention / Commits

Commit 88173a1a, authored Jan 17, 2023 by Tri Dao
[FusedDense] Support relu, rename FusedDenseGeluDense -> FusedMLP
parent 780e8eea
Showing 20 changed files with 654 additions and 779 deletions (+654, -779)
csrc/fused_dense_lib/fused_dense.cpp              +41   -45
csrc/fused_dense_lib/fused_dense_cuda.cu          +266  -478
flash_attn/models/bert.py                         +7    -7
flash_attn/models/gpt.py                          +18   -14
flash_attn/models/vit.py                          +8    -8
flash_attn/modules/block.py                       +2    -1
flash_attn/modules/mlp.py                         +2    -2
flash_attn/ops/fused_dense.py                     +111  -64
tests/models/test_bert.py                         +6    -6
tests/models/test_gpt.py                          +1    -1
tests/models/test_gpt_generation.py               +5    -118
tests/models/test_gpt_generation_parallel.py      +131  -0
tests/models/test_gpt_parallel.py                 +7    -3
tests/models/test_vit.py                          +6    -5
tests/modules/test_block_parallel.py              +5    -6
tests/ops/test_fused_dense.py                     +25   -9
tests/ops/test_fused_dense_parallel.py            +8    -9
training/README.md                                +1    -1
training/configs/experiment/owt/gpt2s-flash.yaml  +2    -1
training/configs/experiment/pile/gpt3s-flash.yaml +2    -1
csrc/fused_dense_lib/fused_dense.cpp

@@ -28,19 +28,19 @@
 }
 
 template <typename T>
-int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size,
-                           int out_features, T *d_weight, T *d_bias, void *lt_workspace);
+int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features,
+                           int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias);
 
 template <typename T>
-int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int batch_size,
-                             int out_features, int heuristic, T *output, T *gelu_in,
-                             void *lt_workspace);
+int linear_act_forward_cuda(const T *input, const T *weight, const T *bias,
+                            int64_t in_features, int64_t batch_size, int64_t out_features,
+                            bool is_gelu, int heuristic, T *output, void *pre_act);
 
 template <typename T>
-int bias_gelu_linear_dgrad_bgrad_cuda(T *weight, T *d_output, T *gelu_in, int in_features,
-                                      int batch_size, int out_features, int heuristic,
-                                      T *d_input, T *d_bias, void *lt_workspace);
+int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act,
+                                     int64_t in_features, int64_t batch_size, int64_t out_features,
+                                     bool is_gelu, int heuristic, T *d_input, T *d_bias);
 
 std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output, bool has_d_bias) {
-    int batch_size = input.size(0);
-    int in_features = input.size(1);
-    int out_features = d_output.size(1);
+    int64_t batch_size = input.size(0);
+    int64_t in_features = input.size(1);
+    int64_t out_features = d_output.size(1);
     TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
     TORCH_CHECK(input.dtype() == d_output.dtype());

@@ -66,8 +66,6 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
         d_bias = at::empty({out_features}, opts);
 #endif
     }
-    // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
-    auto lt_workspace = at::empty({1 << 22}, opts);
     DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_wgrad", [&] {
         auto result = linear_bias_wgrad_cuda<scalar_t>(

@@ -77,21 +75,20 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
             batch_size,
             out_features,
             d_weight.data_ptr<scalar_t>(),
-            has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr,
-            (void*)(lt_workspace.data_ptr<scalar_t>()));
+            has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr);
         TORCH_CHECK(result == 0, "linear_bias_wgrad failed.");
     });
     return {d_weight, d_bias};
 }
 
-std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
-                                            c10::optional<at::Tensor> bias_,
-                                            bool save_gelu_in, int heuristic) {
-    int batch_size = input.size(0);
-    int in_features = input.size(1);
-    int out_features = weight.size(0);
+std::vector<at::Tensor> linear_act_forward(at::Tensor input, at::Tensor weight,
+                                           c10::optional<at::Tensor> bias_,
+                                           bool is_gelu, bool save_pre_act, int heuristic) {
+    int64_t batch_size = input.size(0);
+    int64_t in_features = input.size(1);
+    int64_t out_features = weight.size(0);
     TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
     TORCH_CHECK(input.dtype() == weight.dtype());

@@ -116,51 +113,52 @@ std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
     // create output/workspace tensor
     auto opts = input.options();
     auto output = at::empty({batch_size, out_features}, opts);
-    at::Tensor gelu_in;
-    if (save_gelu_in) { gelu_in = at::empty({batch_size, out_features}, opts); }
-    // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
-    auto lt_workspace = at::empty({1 << 22}, opts);
-    DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_gelu_forward", [&] {
-        auto result = linear_gelu_forward_cuda<scalar_t>(
+    at::Tensor pre_act;
+    // If ReLU, cuBlasLT stores a bit-mask (1 bit per element)
+    if (save_pre_act) { pre_act = at::empty({batch_size, is_gelu ? out_features : out_features / 8},
+                                            is_gelu ? opts : opts.dtype(torch::kUInt8)); }
+    DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_act_forward", [&] {
+        auto result = linear_act_forward_cuda<scalar_t>(
             input.data_ptr<scalar_t>(),
             weight.data_ptr<scalar_t>(),
             bias_.has_value() ? bias_.value().data_ptr<scalar_t>() : nullptr,
             in_features,
             batch_size,
             out_features,
+            is_gelu,
             heuristic,
             output.data_ptr<scalar_t>(),
-            save_gelu_in ? gelu_in.data_ptr<scalar_t>() : nullptr,
-            (void*)(lt_workspace.data_ptr<scalar_t>()));
-        TORCH_CHECK(result == 0, "linear_gelu_forward failed.");
+            save_pre_act ? pre_act.data_ptr() : nullptr);
+        TORCH_CHECK(result == 0, "linear_act_forward failed.");
     });
     std::vector<at::Tensor> result = {output};
-    if (save_gelu_in) { result.push_back(gelu_in); };
+    if (save_pre_act) { result.push_back(pre_act); };
     return result;
 }
 
-std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(
-    at::Tensor weight, at::Tensor d_output, at::Tensor gelu_in, int heuristic
+std::vector<at::Tensor> bias_act_linear_dgrad_bgrad(
+    at::Tensor weight, at::Tensor d_output, at::Tensor pre_act, bool is_gelu, int heuristic
 ) {
-    int batch_size = d_output.size(0);
-    int out_features = d_output.size(1);
-    int in_features = weight.size(1);
+    int64_t batch_size = d_output.size(0);
+    int64_t out_features = d_output.size(1);
+    int64_t in_features = weight.size(1);
     TORCH_CHECK(weight.dtype() == torch::kFloat16 || weight.dtype() == torch::kBFloat16);
     TORCH_CHECK(weight.dtype() == d_output.dtype());
-    TORCH_CHECK(weight.dtype() == gelu_in.dtype());
+    TORCH_CHECK(is_gelu ? (pre_act.dtype() == weight.dtype()) : (pre_act.dtype() == torch::kUInt8));
     TORCH_CHECK(weight.is_cuda());
     TORCH_CHECK(d_output.is_cuda());
-    TORCH_CHECK(gelu_in.is_cuda());
+    TORCH_CHECK(pre_act.is_cuda());
     TORCH_CHECK(weight.is_contiguous());
     TORCH_CHECK(d_output.is_contiguous());
-    TORCH_CHECK(gelu_in.is_contiguous());
+    TORCH_CHECK(pre_act.is_contiguous());
     CHECK_SHAPE(weight, out_features, in_features);
     CHECK_SHAPE(d_output, batch_size, out_features);
-    CHECK_SHAPE(gelu_in, batch_size, in_features);
+    // If ReLU, cuBlasLT stores a bit-mask (1 bit per element)
+    CHECK_SHAPE(pre_act, batch_size, is_gelu ? in_features : in_features / 8);
     // Otherwise the kernel will be launched from cuda:0 device
     // Cast to char to avoid compiler warning about narrowing

@@ -170,22 +168,20 @@ std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(
     auto opts = weight.options();
     auto d_bias = at::empty({in_features}, opts);
     auto d_input = at::empty({batch_size, in_features}, opts);
-    // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
-    auto lt_workspace = at::empty({1 << 22}, opts);
-    DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_gelu_linear_dgrad_bgrad", [&] {
-        auto result = bias_gelu_linear_dgrad_bgrad_cuda<scalar_t>(
+    DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_act_linear_dgrad_bgrad", [&] {
+        auto result = bias_act_linear_dgrad_bgrad_cuda<scalar_t>(
             weight.data_ptr<scalar_t>(),
             d_output.data_ptr<scalar_t>(),
-            gelu_in.data_ptr<scalar_t>(),
+            pre_act.data_ptr(),
             in_features,
             batch_size,
             out_features,
+            is_gelu,
             heuristic,
             d_input.data_ptr<scalar_t>(),
-            d_bias.data_ptr<scalar_t>(),
-            (void*)(lt_workspace.data_ptr<scalar_t>()));
-        TORCH_CHECK(result == 0, "bias_gelu_linear_dgrad_bgrad failed.");
+            d_bias.data_ptr<scalar_t>());
+        TORCH_CHECK(result == 0, "bias_act_linear_dgrad_bgrad failed.");
     });
     return {d_input, d_bias};

@@ -193,6 +189,6 @@ std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("linear_bias_wgrad", &linear_bias_wgrad, "linear bias wgrad");
-    m.def("linear_gelu_forward", &linear_gelu_forward, "linear gelu forward");
-    m.def("bias_gelu_linear_dgrad_bgrad", &bias_gelu_linear_dgrad_bgrad, "bias gelu linear dgrad bgrad");
+    m.def("linear_act_forward", &linear_act_forward, "linear gelu/relu forward");
+    m.def("bias_act_linear_dgrad_bgrad", &bias_act_linear_dgrad_bgrad, "bias gelu/relu linear dgrad bgrad");
 }
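For orientation, a minimal sketch (mine, not part of the commit) of exercising the renamed bindings from Python. The import alias mirrors how flash_attn.ops.fused_dense refers to the compiled extension; the exact built module name and the need for a CUDA device with the extension installed are assumptions here.

    # Hedged sketch: calling the renamed C++ binding directly.
    import torch
    import fused_dense_lib as fused_dense_cuda  # assumption: extension built under this name

    batch, d_in, d_hidden = 256, 1024, 4096
    x = torch.randn(batch, d_in, device='cuda', dtype=torch.float16)
    w1 = torch.randn(d_hidden, d_in, device='cuda', dtype=torch.float16)
    b1 = torch.randn(d_hidden, device='cuda', dtype=torch.float16)

    # is_gelu=True: pre_act comes back in the input dtype;
    # is_gelu=False (ReLU): pre_act is a uint8 bit-mask (1 bit per element).
    out, pre_act = fused_dense_cuda.linear_act_forward(x, w1, b1, True, True, 0)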
csrc/fused_dense_lib/fused_dense_cuda.cu
(large diff, +266 -478; collapsed in the original view and not shown here)
flash_attn/models/bert.py

@@ -23,7 +23,7 @@ from transformers.models.bert.modeling_bert import BertForPreTrainingOutput
 from einops import rearrange
 
 from flash_attn.modules.mha import MHA
-from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
+from flash_attn.modules.mlp import Mlp, FusedMLP
 from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import BertEmbeddings
 from flash_attn.bert_padding import unpad_input, pad_input

@@ -61,24 +61,24 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
 def create_mlp_cls(config, layer_idx=None, return_residual=False):
     inner_dim = config.intermediate_size
-    fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False)
-    if fused_dense_gelu_dense:
-        assert config.hidden_act in ['gelu_new', 'gelu_fast'], ('fused_dense_gelu_dense only '
-                                                                'supports approximate gelu')
-    if not fused_dense_gelu_dense:
+    fused_mlp = getattr(config, 'fused_mlp', False)
+    if fused_mlp:
+        assert config.hidden_act in ['gelu_new', 'gelu_fast'], ('fused_mlp only '
+                                                                'supports approximate gelu')
+    if not fused_mlp:
         approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
         mlp_cls = partial(Mlp, hidden_features=inner_dim,
                           activation=partial(F.gelu, approximate=approximate),
                           return_residual=return_residual)
     else:
-        if FusedDenseGeluDense is None:
+        if FusedMLP is None:
             raise ImportError('fused_dense is not installed')
         mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
         # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
         if isinstance(mlp_checkpoint_lvl, Sequence):
             assert layer_idx is not None
             mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
-        mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim,
+        mlp_cls = partial(FusedMLP, hidden_features=inner_dim,
                           checkpoint_lvl=mlp_checkpoint_lvl, return_residual=return_residual)
     return mlp_cls
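For callers, the BERT-side change is just a renamed config flag. A usage sketch mirroring tests/models/test_bert.py (the model name and the from_pretrained signature are taken from that test and are assumptions beyond what this hunk shows):

    # Sketch mirroring tests/models/test_bert.py: enable the renamed flag on a HF config.
    from transformers import BertConfig
    from flash_attn.models.bert import BertForPreTraining

    config = BertConfig.from_pretrained("bert-base-uncased")
    config.hidden_act = "gelu_new"       # fused_mlp only supports approximate gelu
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True              # previously: config.fused_dense_gelu_dense = True
    config.fused_dropout_add_ln = True
    model = BertForPreTraining.from_pretrained("bert-base-uncased", config)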
flash_attn/models/gpt.py

@@ -17,7 +17,7 @@ from transformers import GPT2Config
 from einops import rearrange
 
 from flash_attn.modules.mha import MHA, ParallelMHA
-from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense, ParallelFusedDenseGeluDense
+from flash_attn.modules.mlp import Mlp, FusedMLP, ParallelFusedMLP
 from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
 from flash_attn.utils.distributed import sync_shared_params, all_gather_raw

@@ -77,22 +77,22 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
 def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
     factory_kwargs = {'device': device, 'dtype': dtype}
     inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
-    fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False)
-    if fused_dense_gelu_dense:
-        assert config.activation_function in ['gelu_new', 'gelu_fast'], ('fused_dense_gelu_dense only '
-                                                                         'supports approximate gelu')
+    fused_mlp = getattr(config, 'fused_mlp', False)
+    if fused_mlp:
+        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu']
     fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
     if fused_dense_sqrelu_dense:
         assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '
                                                         'supports approximate activation_function sqrelu')
-    assert not (fused_dense_sqrelu_dense and fused_dense_gelu_dense)
+    assert not (fused_dense_sqrelu_dense and fused_mlp)
     if process_group is not None:
-        assert fused_dense_gelu_dense, 'Tensor Parallel is only implemented for FusedDenseGeluDense'
-    if not fused_dense_gelu_dense and not fused_dense_sqrelu_dense:
-        approximate = 'tanh' if config.activation_function in ['gelu_new', 'gelu_fast'] else 'none'
-        mlp_cls = partial(Mlp, hidden_features=inner_dim,
-                          activation=partial(F.gelu, approximate=approximate), **factory_kwargs)
+        assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP'
+    if not fused_mlp and not fused_dense_sqrelu_dense:
+        if config.activation_function == 'relu':
+            activation = partial(F.relu, inplace=True)
+        else:
+            approximate = ('tanh' if config.activation_function
+                           in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
+            activation = partial(F.gelu, approximate=approximate)
+        mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=activation, **factory_kwargs)
     else:

@@ -101,14 +101,17 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
         if isinstance(mlp_checkpoint_lvl, Sequence):
             assert layer_idx is not None
             mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
-        if fused_dense_gelu_dense:
-            if FusedDenseGeluDense is None:
+        if fused_mlp:
+            if FusedMLP is None:
                 raise ImportError('fused_dense is not installed')
-            mlp_cls = FusedDenseGeluDense if process_group is None else ParallelFusedDenseGeluDense
+            activation = ('gelu_approx' if config.activation_function
+                          in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'relu')
+            mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
             parallel_kwargs = ({'process_group': process_group,
                                 'sequence_parallel': getattr(config, 'sequence_parallel', True)}
                                if process_group is not None else {})
-            mlp_cls = partial(mlp_cls, hidden_features=inner_dim,
-                              checkpoint_lvl=mlp_checkpoint_lvl,
+            mlp_cls = partial(mlp_cls, hidden_features=inner_dim, activation=activation,
+                              checkpoint_lvl=mlp_checkpoint_lvl,
                               **parallel_kwargs, **factory_kwargs)
         elif fused_dense_sqrelu_dense:
             assert FusedDenseSqreluDense is not None

@@ -210,7 +213,8 @@ class GPTModel(GPTPreTrainedModel):
         factory_kwargs = {'device': device, 'dtype': dtype}
         self.process_group = process_group
         self.sequence_parallel = getattr(config, 'sequence_parallel', True)
-        assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'relu', 'sqrelu']
+        assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx',
+                                              'relu', 'sqrelu']
         pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
         vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
                       * pad_vocab_size_multiple)
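The key addition in create_mlp_cls above is the mapping from HF-style activation names to the FusedMLP activation argument. A hypothetical standalone helper (not part of the commit) restating the same rule:

    # Hypothetical helper restating the activation mapping used in create_mlp_cls.
    def resolve_fused_mlp_activation(activation_function: str) -> str:
        assert activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu']
        return ('gelu_approx'
                if activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'relu')

    assert resolve_fused_mlp_activation('gelu_new') == 'gelu_approx'
    assert resolve_fused_mlp_activation('relu') == 'relu'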
flash_attn/models/vit.py

@@ -20,7 +20,7 @@ from timm.models.helpers import named_apply
 from flash_attn.layers.patch_embed import PatchEmbed
 from flash_attn.modules.mha import MHA
-from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
+from flash_attn.modules.mlp import Mlp, FusedMLP
 from flash_attn.modules.block import Block
 
 try:

@@ -37,22 +37,22 @@ def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_
     return mixer_cls
 
 
-def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense):
+def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp):
     inner_dim = int(embed_dim * mlp_ratio)
-    if not fused_dense_gelu_dense:
+    if not fused_mlp:
         mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer())
     else:
-        mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim)
+        mlp_cls = partial(FusedMLP, hidden_features=inner_dim)
     return mlp_cls
 
 
 def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path1,
-                 drop_path2, norm_layer, act_layer, use_flash_attn, fused_bias_fc,
-                 fused_dense_gelu_dense, fused_dropout_add_ln, layer_idx=None, n_layer=None,
+                 drop_path2, norm_layer, act_layer, use_flash_attn, fused_bias_fc, fused_mlp,
+                 fused_dropout_add_ln, layer_idx=None, n_layer=None,
                  last_layer_subset=False):
     mixer_cls = create_mixer_cls(num_heads, qkv_bias, attn_drop_rate, use_flash_attn, fused_bias_fc,
                                  cross_attn=(last_layer_subset and layer_idx == n_layer - 1))
-    mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense)
+    mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp)
     # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
     block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer, prenorm=True,
                   resid_dropout1=drop_rate, resid_dropout2=drop_rate,

@@ -92,7 +92,7 @@ class VisionTransformer(nn.Module):
                  act_layer=None,
                  use_flash_attn=False,
                  fused_bias_fc=False,
-                 fused_dense_gelu_dense=False,
+                 fused_mlp=False,
                  fused_dropout_add_ln=False,
                  ):
         """

@@ -164,7 +164,7 @@ class VisionTransformer(nn.Module):
             embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
             drop_path1=dpr[i - 1] if i > 0 else 0., drop_path2=dpr[i],
             norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn,
-            fused_bias_fc=fused_bias_fc, fused_dense_gelu_dense=fused_dense_gelu_dense,
+            fused_bias_fc=fused_bias_fc, fused_mlp=fused_mlp,
             fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, n_layer=depth,
             last_layer_subset=(global_pool == 'token')
         ) for i in range(depth)])
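As with the other models, downstream ViT code only needs to pass the renamed keyword. A sketch based on tests/models/test_vit.py (device/dtype placement is my own addition):

    # Sketch based on tests/models/test_vit.py: the ViT factory now takes fused_mlp.
    import torch
    from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224

    model = flash_vit_base_patch16_224(use_flash_attn=True, fused_bias_fc=True,
                                       fused_dropout_add_ln=True, fused_mlp=True)
    model = model.to(device='cuda', dtype=torch.float16)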
flash_attn/modules/block.py

@@ -121,7 +121,8 @@ class Block(nn.Module):
                 )
             if mixer_kwargs is None:
                 mixer_kwargs = {}
-            mixer_kwargs['mixer_subset'] = mixer_subset
+            if mixer_subset is not None:
+                mixer_kwargs['mixer_subset'] = mixer_subset
             hidden_states = self.mixer(hidden_states, **mixer_kwargs)
             if mixer_subset is not None:
                 residual = residual[:, mixer_subset]
flash_attn/modules/mlp.py

@@ -5,9 +5,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 try:
-    from flash_attn.ops.fused_dense import FusedDenseGeluDense, ParallelFusedDenseGeluDense
+    from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
 except ImportError:
-    FusedDenseGeluDense, ParallelFusedDenseGeluDense = None, None
+    FusedMLP, ParallelFusedMLP = None, None
 
 
 class Mlp(nn.Module):
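Because the import is guarded, consumers are expected to check for None before taking the fused path, as create_mlp_cls in bert.py above does. A minimal sketch of that check:

    # Minimal sketch of the None check callers perform when the extension is absent.
    from flash_attn.modules.mlp import FusedMLP

    if FusedMLP is None:
        raise ImportError('fused_dense is not installed')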
flash_attn/ops/fused_dense.py

-# Copyright (c) 2022, Tri Dao.
+# Copyright (c) 2023, Tri Dao.
 # Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
 # We make it work with pytorch amp and with bfloat16.
 # The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
 from typing import Optional
+from functools import partial
 
 import torch
 import torch.nn as nn

@@ -19,6 +20,11 @@ from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all
 from flash_attn.utils.distributed import reduce_scatter, all_reduce
 
 
+@torch.jit.script
+def relu_bwd(g, x):
+    return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)
+
+
 class FusedDenseFunc(torch.autograd.Function):
 
     @staticmethod
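The new relu_bwd mirrors the existing gelu_bwd helper: it masks the incoming gradient with the sign of the saved pre-activation. A small sanity-check sketch (mine, not part of the commit) comparing it against autograd on random data:

    # Sanity-check sketch: relu_bwd applies the incoming gradient wherever the saved
    # pre-activation is non-negative, matching autograd's ReLU gradient everywhere
    # except exactly at x == 0 (where autograd uses 0).
    import torch
    import torch.nn.functional as F

    def relu_bwd_ref(g, x):
        return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)

    x = torch.randn(64, 128, requires_grad=True)
    g = torch.randn(64, 128)
    F.relu(x).backward(g)
    mask = x.detach() != 0  # the two conventions only differ at exactly zero
    assert torch.equal(x.grad[mask], relu_bwd_ref(g, x.detach())[mask])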
@@ -185,12 +191,13 @@ class RowParallelLinear(nn.Linear):
         return reduce_fn(out, self.process_group)
 
 
-class FusedDenseGeluDenseFunc(torch.autograd.Function):
+class FusedMLPFunc(torch.autograd.Function):
 
     @staticmethod
     @custom_fwd
-    def forward(ctx, x, weight1, bias1, weight2, bias2, save_pre_act=True, return_residual=False,
-                checkpoint_lvl=0, heuristic=0, process_group=None, sequence_parallel=True):
+    def forward(ctx, x, weight1, bias1, weight2, bias2, activation='gelu_approx', save_pre_act=True,
+                return_residual=False, checkpoint_lvl=0, heuristic=0, process_group=None,
+                sequence_parallel=True):
         """
         If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
         with sequence parallelism: we do an all_gather of x before doing the matmul.

@@ -198,10 +205,11 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
         checkpoint_lvl:
         0: no recomputation in the bwd
-        1: recompute gelu_out in the bwd
-        2: recompute gelu_in and gelu_out in the bwd
+        1: recompute gelu_out / relu_out in the bwd
+        2: recompute pre_act and gelu_out / relu_out in the bwd
         """
         assert -1 <= heuristic <= 4
+        assert activation in ['gelu_approx', 'relu']
         if not save_pre_act:
             checkpoint_lvl = 2
         assert checkpoint_lvl in [0, 1, 2]

@@ -209,6 +217,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
         ctx.process_group = process_group
         ctx.sequence_parallel = sequence_parallel
         ctx.checkpoint_lvl = checkpoint_lvl
+        ctx.activation = activation
         ctx.heuristic = heuristic
         if torch.is_autocast_enabled():

@@ -237,23 +246,27 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
         if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32:
             raise RuntimeError('fused_dense only supports matrix dims <= 2M')
         if heuristic == -1:
-            gelu_in = F.linear(total_x, weight1, bias1)
-            output1 = F.gelu(gelu_in, approximate='tanh')
+            pre_act = F.linear(total_x, weight1, bias1)
+            activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
+                             else F.relu)
+            output1 = activation_fn(pre_act)
             # This is before adding bias1
-            # gelu_in = F.linear(total_x.reshape(batch_dim, n), weight1)
+            # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1)
             # with torch.jit.fuser('fuser2'):
-            #     output1 = bias_gelu(gelu_in, bias1)
+            #     output1 = bias_gelu(pre_act, bias1)
         else:
-            output1, *rest = fused_dense_cuda.linear_gelu_forward(
-                total_x.reshape(batch_dim, n), weight1, bias1, save_pre_act, heuristic
-            )
+            is_gelu = activation == 'gelu_approx'
+            output1, *rest = fused_dense_cuda.linear_act_forward(
+                total_x.reshape(batch_dim, n), weight1, bias1, is_gelu, save_pre_act, heuristic
+            )
             if save_pre_act:
-                gelu_in = rest[0]
+                pre_act = rest[0]
         output2 = F.linear(output1, weight2, bias2)
-        if checkpoint_lvl == 0:
-            ctx.save_for_backward(x, weight1, weight2, gelu_in, output1)
+        if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == 'relu'):
+            # For RELU the pre_act is very small (just a bit-mask) so we just save it
+            ctx.save_for_backward(x, weight1, weight2, pre_act, output1)
         elif checkpoint_lvl == 1:
-            ctx.save_for_backward(x, weight1, weight2, gelu_in)
+            ctx.save_for_backward(x, weight1, weight2, pre_act)
         elif checkpoint_lvl == 2:
             ctx.save_for_backward(x, weight1, weight2, bias1)
         output2 = output2.reshape(*batch_shape, output2.shape[-1])
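The "just a bit-mask" comment above is why checkpoint level 1 collapses into level 0 for ReLU: the saved pre-activation costs almost nothing. A rough back-of-the-envelope sketch (my numbers, purely illustrative) of what gets saved per MLP for the backward pass:

    # Rough memory comparison (illustrative, assumed sizes).
    batch_tokens, hidden = 8 * 2048, 4 * 4096      # assumed batch*seqlen and MLP hidden size
    bytes_fp16 = 2

    gelu_pre_act = batch_tokens * hidden * bytes_fp16   # fp16 pre-activation
    relu_pre_act = batch_tokens * hidden // 8           # uint8 bit-mask, 1 bit per element
    print(gelu_pre_act / 2**20, relu_pre_act / 2**20)   # ~512.0 MiB vs ~32.0 MiB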
@@ -264,6 +277,9 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
     def backward(ctx, grad_output, *args):
         grad_output = grad_output.contiguous()
         checkpoint_lvl = ctx.checkpoint_lvl
+        activation = ctx.activation
+        activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
+                         else F.relu)
         if ctx.return_residual:
             grad_input, = args
             grad_input = grad_input.contiguous()

@@ -277,27 +293,27 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
         if checkpoint_lvl in [0, 1]:
             if process_group is not None and sequence_parallel:
                 total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
-            if checkpoint_lvl == 0:
-                gelu_in, output1 = rest
+            if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == 'relu'):
+                pre_act, output1 = rest
             elif checkpoint_lvl == 1:
-                gelu_in, = rest
-                output1 = F.gelu(gelu_in, approximate='tanh')
+                pre_act, = rest
+                output1 = activation_fn(pre_act)
         elif checkpoint_lvl == 2:
             bias1, = rest
             if process_group is not None and sequence_parallel:
                 total_x, _ = all_gather_raw(x, process_group)
             if ctx.heuristic == -1:
-                gelu_in = F.linear(total_x, weight1, bias1)
-                output1 = F.gelu(gelu_in, approximate='tanh')
+                pre_act = F.linear(total_x, weight1, bias1)
+                output1 = activation_fn(pre_act)
             else:
-                output1, gelu_in = fused_dense_cuda.linear_gelu_forward(
-                    total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1, True, ctx.heuristic
-                )
+                output1, pre_act = fused_dense_cuda.linear_act_forward(
+                    total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1,
+                    activation == 'gelu_approx', True, ctx.heuristic
+                )
 
         grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
         output1 = output1.reshape(batch_dim, output1.shape[-1])
-        gelu_in = gelu_in.reshape(batch_dim, gelu_in.shape[-1])
+        pre_act = pre_act.reshape(batch_dim, pre_act.shape[-1])
         if ctx.needs_input_grad[3]:
             grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(
                 output1, grad_output, ctx.needs_input_grad[4]

@@ -306,24 +322,25 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
             grad_weight2 = None
             grad_bias2 = grad_output if ctx.needs_input_grad[4] else None
         if ctx.heuristic == -1:
-            # grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
+            # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act)
             grad_output1 = F.linear(grad_output, weight2.t())
             with torch.jit.fuser('fuser2'):
-                grad_gelu = gelu_bwd(grad_output1, gelu_in)
+                activation_grad_fn = gelu_bwd if activation == 'gelu_approx' else relu_bwd
+                grad_pre_act = activation_grad_fn(grad_output1, pre_act)
         else:
-            # The cublasLt epilogue has to compute both gelu grad and bias grad, we can't
-            # just compute gelu grad
-            grad_gelu, grad_bias1 = fused_dense_cuda.bias_gelu_linear_dgrad_bgrad(
-                weight2, grad_output, gelu_in, ctx.heuristic
-            )
+            # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't
+            # just compute gelu/relu grad
+            grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad(
+                weight2, grad_output, pre_act, activation == 'gelu_approx', ctx.heuristic
+            )
             if not ctx.needs_input_grad[2]:
                 grad_bias1 = None
         if ctx.needs_input_grad[0]:
             if not ctx.return_residual:
-                grad_input = F.linear(grad_gelu, weight1.t())
+                grad_input = F.linear(grad_pre_act, weight1.t())
             else:
                 grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
-                                         grad_gelu, weight1)
+                                         grad_pre_act, weight1)
             grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
             if process_group is not None:
                 reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
@@ -335,55 +352,60 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
             if process_group is not None and sequence_parallel:
                 handle_x.wait()
             grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
-                total_x.reshape(batch_dim, total_x.shape[-1]), grad_gelu,
+                total_x.reshape(batch_dim, total_x.shape[-1]), grad_pre_act,
                 ctx.needs_input_grad[2]
             )
         else:
             grad_weight1 = None
-            grad_bias1 = grad_gelu if ctx.needs_input_grad[2] else None
+            grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None
     else:
         if ctx.needs_input_grad[1]:
             if process_group is not None and sequence_parallel:
                 handle_x.wait()
-            grad_weight1 = F.linear(grad_gelu.t(),
+            grad_weight1 = F.linear(grad_pre_act.t(),
                                     total_x.reshape(batch_dim, total_x.shape[-1]).t())
         else:
             grad_weight1 = None
     if process_group is not None and ctx.needs_input_grad[0]:
         handle_grad_input.wait()
     return (grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2,
-            None, None, None, None, None, None)
+            None, None, None, None, None, None, None)
 
 
-def fused_dense_gelu_dense_func(
+def fused_mlp_func(
     x: Tensor, weight1: Tensor, weight2: Tensor, bias1: Optional[Tensor] = None,
-    bias2: Optional[Tensor] = None,
+    bias2: Optional[Tensor] = None, activation: str = 'gelu_approx',
     save_pre_act: bool = True, return_residual: bool = False,
     checkpoint_lvl: int = 0, heuristic: int = 0,
     process_group: Optional[ProcessGroup] = None,
     sequence_parallel: bool = True
 ):
+    assert activation in ['gelu_approx', 'relu']
     dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                       or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
+    # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu)
+    dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == 'relu' else 8) == 0)
     if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
-            and (bias2 is None or bias2.is_cuda) and dtype_eligible):
-        return FusedDenseGeluDenseFunc.apply(
-            x, weight1, bias1, weight2, bias2, save_pre_act, return_residual,
+            and (bias2 is None or bias2.is_cuda) and dtype_eligible and dim_eligible):
+        return FusedMLPFunc.apply(
+            x, weight1, bias1, weight2, bias2, activation, save_pre_act, return_residual,
             checkpoint_lvl, heuristic, process_group, sequence_parallel
         )
     else:
         assert process_group is None
-        gelu_in = F.linear(x, weight1, bias1)
-        output1 = F.gelu(gelu_in, approximate='tanh')
+        pre_act = F.linear(x, weight1, bias1)
+        activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
+                         else partial(F.relu, inplace=True))
+        output1 = activation_fn(pre_act)
         output2 = F.linear(output1, weight2, bias2)
         return output2 if not return_residual else (output2, x)
 
 
-class FusedDenseGeluDense(nn.Module):
+class FusedMLP(nn.Module):
 
     def __init__(self, in_features, hidden_features, out_features=None, bias1=True,
-                 bias2=True, return_residual=False, checkpoint_lvl=0, heuristic=0,
-                 device=None, dtype=None):
+                 bias2=True, activation='gelu_approx', return_residual=False, checkpoint_lvl=0,
+                 heuristic='auto', device=None, dtype=None):
         """
         If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
         we do an all_gather of x before doing the matmul, gelu, then matmul.
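A usage sketch of the renamed functional entry point with explicit weights (shapes, device, and the heuristic value are my own choices; this assumes the extension is built and the eligibility checks above pass, otherwise the pure-PyTorch fallback path runs):

    # Hedged sketch: calling fused_mlp_func directly with explicit weights.
    import torch
    from flash_attn.ops.fused_dense import fused_mlp_func

    d, d_hidden = 1024, 4096
    x = torch.randn(4, 512, d, device='cuda', dtype=torch.float16)
    w1 = torch.randn(d_hidden, d, device='cuda', dtype=torch.float16)
    b1 = torch.zeros(d_hidden, device='cuda', dtype=torch.float16)
    w2 = torch.randn(d, d_hidden, device='cuda', dtype=torch.float16)
    b2 = torch.zeros(d, device='cuda', dtype=torch.float16)

    # Inference-style call: nothing saved for backward.
    # (With save_pre_act=True, x's last dim must be divisible by 128 for relu, 8 for gelu.)
    out = fused_mlp_func(x, w1, w2, bias1=b1, bias2=b2, activation='relu',
                         save_pre_act=False, heuristic=0)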
@@ -392,21 +414,24 @@ class FusedDenseGeluDense(nn.Module):
         checkpoint_lvl (increasing lvl means slower but more memory saving):
             0: no recomputation in the bwd
             1: recompute gelu_out in the bwd
-            2: recompute gelu_in and gelu_out in the bwd
+            2: recompute pre_act and gelu_out in the bwd
         heuristic:
             -1: don't fuse gemm + gelu (separate kernel)
             0..4: use this heuristic for the algo section in the fused gemm + gelu
-            For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
-            For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
+            'auto': heuristic will be picked automatically:
+                For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
+                For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
         return_residual: whether to return the input x along with the output. This is for
             performance reason: for post-norm architecture, returning the input allows us
             to fuse the backward of nn.Linear with the residual connection.
         """
         assert checkpoint_lvl in [0, 1, 2]
+        assert activation in ['gelu_approx', 'relu']
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         if out_features is None:
             out_features = in_features
+        self.activation = activation
         self.return_residual = return_residual
         self.checkpoint_lvl = checkpoint_lvl
         self.heuristic = heuristic

@@ -414,11 +439,20 @@ class FusedDenseGeluDense(nn.Module):
         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
 
     def forward(self, x, process_group=None):
-        out = fused_dense_gelu_dense_func(
+        dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
+        if self.heuristic == 'auto':
+            if self.activation == 'gelu_approx':
+                cuda_ver = tuple(map(int, torch.version.cuda.split('.')))
+                heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
+            else:
+                heuristic = 0
+        else:
+            heuristic = self.heuristic
+        out = fused_mlp_func(
             x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
-            save_pre_act=self.training, return_residual=self.return_residual,
-            checkpoint_lvl=self.checkpoint_lvl, heuristic=self.heuristic,
-            process_group=process_group
+            activation=self.activation, save_pre_act=self.training,
+            return_residual=self.return_residual, checkpoint_lvl=self.checkpoint_lvl,
+            heuristic=heuristic, process_group=process_group
         )
         if self.return_residual:
             out, x = out
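The 'auto' branch above boils down to a small decision on CUDA version and dtype. Restated as a hypothetical standalone function for illustration (it assumes a CUDA build of PyTorch, where torch.version.cuda is a version string):

    # Hypothetical restatement of the heuristic='auto' selection added to FusedMLP.forward.
    import torch

    def pick_heuristic(activation: str, dtype: torch.dtype) -> int:
        if activation != 'gelu_approx':
            return 0                                   # relu: fused epilogue heuristic 0
        cuda_ver = tuple(map(int, torch.version.cuda.split('.')))
        if cuda_ver >= (11, 8):
            return 0                                   # CUDA >= 11.8: 0 for fp16 and bf16
        return 1 if dtype == torch.float16 else -1     # CUDA <= 11.7: 1 for fp16, -1 for bf16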
@@ -427,11 +461,12 @@ class FusedDenseGeluDense(nn.Module):
         return out if not self.return_residual else (out, x)
 
 
-class ParallelFusedDenseGeluDense(nn.Module):
+class ParallelFusedMLP(nn.Module):
 
-    def __init__(self, in_features, hidden_features, out_features=None,
+    def __init__(self, in_features, hidden_features, out_features=None, activation='gelu_approx',
                  process_group: ProcessGroup = None, bias1=True, bias2=True,
-                 sequence_parallel=True, checkpoint_lvl=0, heuristic=0, device=None, dtype=None):
+                 sequence_parallel=True, checkpoint_lvl=0, heuristic='auto', device=None, dtype=None):
         """
         process_group is required. We're doing Tensor Parallel with sequence parallelism:
         we do an all_gather of x before doing the matmul, gelu, then matmul.

@@ -440,19 +475,22 @@ class ParallelFusedDenseGeluDense(nn.Module):
         checkpoint_lvl (increasing lvl means slower but more memory saving):
             0: no recomputation in the bwd
             1: recompute gelu_out in the bwd
-            2: recompute gelu_in and gelu_out in the bwd
+            2: recompute pre_act and gelu_out in the bwd
         heuristic:
             -1: don't fuse gemm + gelu (separate kernel)
             0..4: use this heuristic for the algo section in the fused gemm + gelu
-            For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
-            For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
+            'auto': heuristic will be picked automatically:
+                For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
+                For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
         """
         assert checkpoint_lvl in [0, 1, 2]
+        assert activation in ['gelu_approx', 'relu']
         assert process_group is not None
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         if out_features is None:
             out_features = in_features
+        self.activation = activation
         self.process_group = process_group
         self.sequence_parallel = sequence_parallel
         self.checkpoint_lvl = checkpoint_lvl

@@ -463,10 +501,19 @@ class ParallelFusedDenseGeluDense(nn.Module):
                                        bias=bias2, **factory_kwargs)
 
     def forward(self, x):
-        out = fused_dense_gelu_dense_func(
+        dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
+        if self.heuristic == 'auto':
+            if self.activation == 'gelu_approx':
+                cuda_ver = tuple(map(int, torch.version.cuda.split('.')))
+                heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
+            else:
+                heuristic = 0
+        else:
+            heuristic = self.heuristic
+        out = fused_mlp_func(
             x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
-            save_pre_act=self.training, checkpoint_lvl=self.checkpoint_lvl,
-            heuristic=self.heuristic,
+            activation=self.activation, save_pre_act=self.training,
+            checkpoint_lvl=self.checkpoint_lvl, heuristic=heuristic,
             process_group=self.process_group,
             sequence_parallel=self.sequence_parallel
         )
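Putting it together, a module-level usage sketch of the renamed class with the new arguments from this commit (sizes and device are my own; ParallelFusedMLP is analogous but additionally requires a process_group):

    # Hedged usage sketch of FusedMLP after the rename.
    import torch
    from flash_attn.ops.fused_dense import FusedMLP

    mlp = FusedMLP(1024, 4 * 1024, activation='relu',    # 'gelu_approx' (default) or 'relu'
                   checkpoint_lvl=0, heuristic='auto',   # 'auto' picks per CUDA version / dtype
                   device='cuda', dtype=torch.float16)
    x = torch.randn(2, 512, 1024, device='cuda', dtype=torch.float16)
    y = mlp(x)                                           # same last dim as x (out_features defaults to in_features)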
tests/models/test_bert.py

@@ -95,13 +95,13 @@ def test_bert_optimized(model_name):
     """
     dtype = torch.float16
     config = BertConfig.from_pretrained(model_name)
-    # Our implementation of fused_dense_gelu_dense assumes the activation is
+    # Our implementation of fused_mlp assumes the activation is
     # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
-    # If you just want "gelu", disable fused_dense_gelu_dense.
+    # If you just want "gelu", disable fused_mlp.
     config.hidden_act = "gelu_new"
     config.use_flash_attn = True
     config.fused_bias_fc = True
-    config.fused_dense_gelu_dense = True
+    config.fused_mlp = True
     config.fused_dropout_add_ln = True
     model = BertForPreTraining.from_pretrained(model_name, config)

@@ -171,13 +171,13 @@ def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subs
     """
     dtype = torch.float16
     config = BertConfig.from_pretrained(model_name)
-    # Our implementation of fused_dense_gelu_dense assumes the activation is
+    # Our implementation of fused_mlp assumes the activation is
     # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
-    # If you just want "gelu", disable fused_dense_gelu_dense.
+    # If you just want "gelu", disable fused_mlp.
     config.hidden_act = "gelu_new"
     config.use_flash_attn = True
     config.fused_bias_fc = True
-    config.fused_dense_gelu_dense = True
+    config.fused_mlp = True
     config.fused_dropout_add_ln = True
     config.dense_seq_output = True
     config.last_layer_subset = last_layer_subset
tests/models/test_gpt.py

@@ -82,7 +82,7 @@ def test_gpt2_optimized(model_name):
     vocab_size_og = config.vocab_size
     config.use_flash_attn = True
     config.fused_bias_fc = True
-    config.fused_dense_gelu_dense = True
+    config.fused_mlp = True
     config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True
     config.pad_vocab_size_multiple = 8
tests/models/test_gpt_generation.py

@@ -18,7 +18,7 @@ from flash_attn.utils.distributed import all_gather_raw
 @pytest.mark.parametrize('fused_ft_kernel', [False, True])
 # @pytest.mark.parametrize('fused_ft_kernel', [True])
 @pytest.mark.parametrize('optimized', [False, True])
-# @pytest.mark.parametrize('optimized', [True])
+# @pytest.mark.parametrize('optimized', [False])
 @pytest.mark.parametrize('rotary', [False, True])
 # @pytest.mark.parametrize('rotary', [False])
 @pytest.mark.parametrize('model_name', ["gpt2"])

@@ -34,10 +34,11 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
     if rotary:
         config.n_positions = 0
         config.rotary_emb_dim = 64
+    config.residual_in_fp32 = True
     if optimized:
         config.use_flash_attn = True
         config.fused_bias_fc = True
-        config.fused_dense_gelu_dense = True
+        config.fused_mlp = True
         config.fused_dropout_add_ln = True
 
     # if not rotary, we load the weight from HF but ignore the position embeddings.

@@ -78,6 +79,7 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
                          fused_ft_kernel=fused_ft_kernel,
                          return_dict_in_generate=True, output_scores=True, timing=True)
     print(out.sequences)
+    print(tokenizer.batch_decode(out.sequences.tolist()))
     if fused_ft_kernel:
         out_cg = model.generate(input_ids=input_ids, max_length=max_length,
                                 fused_ft_kernel=fused_ft_kernel, cg=True,

@@ -94,122 +96,7 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
     print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
     print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
     print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
 
     assert torch.all(out.sequences == sequences)
     assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
                           rtol=rtol, atol=atol)
     if not rotary:
         assert torch.all(out.sequences == out_ref.sequences)
         assert torch.all(out.sequences == out_hf.sequences)
         assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
-
-
-# Run test with:
-# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_generation.py -k "parallel"
-# @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
-@pytest.mark.parametrize('world_size', [2])
-# @pytest.mark.parametrize('fused_ft_kernel', [False, True])
-@pytest.mark.parametrize('fused_ft_kernel', [True])
-# @pytest.mark.parametrize('rotary', [False, True])
-@pytest.mark.parametrize('rotary', [False])
-@pytest.mark.parametrize('model_name', ["gpt2"])
-def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
-    """Check that our implementation of GPT2 generation matches the HF implementation:
-    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
-    the HF scores in fp32.
-    """
-    dtype = torch.float16
-    rtol, atol = 3e-3, 3e-1
-    config = GPT2Config.from_pretrained(model_name)
-    if rotary:
-        config.n_positions = 0
-        config.rotary_emb_dim = 64
-    config.use_flash_attn = True
-    config.fused_bias_fc = True
-    config.fused_dense_gelu_dense = True
-    config.fused_dropout_add_ln = True
-    config.pad_vocab_size_multiple = 8 * world_size
-    config.sequence_parallel = False  # Need to set this to False for generation
-
-    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
-    if not torch.distributed.is_initialized():
-        torch.distributed.init_process_group(backend='nccl', init_method='env://')
-    device = f'cuda:{torch.distributed.get_rank()}'
-    assert world_size <= torch.distributed.get_world_size()
-    # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
-    # GPU0 and GPU1 and things would hang
-    torch.cuda.set_device(device)
-    from apex.transformer import parallel_state
-    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
-    rank = parallel_state.get_tensor_model_parallel_rank()
-    process_group = parallel_state.get_tensor_model_parallel_group()
-
-    # if not rotary, we load the weight from HF but ignore the position embeddings.
-    # The model would be nonsense but it doesn't matter for the test.
-    model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device,
-                                           dtype=dtype, process_group=process_group,
-                                           world_size=world_size, rank=rank)
-    model.eval()
-
-    if not rotary:
-        model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
-        model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype)
-        model_ref.eval()
-        model_hf.eval()
-
-    torch.manual_seed(0)
-    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-    input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.to(device=device)
-    max_length = 30
-    # input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
-    # max_length = input_ids.shape[1] + 40
-
-    # Slow generation for reference
-    sequences = []
-    scores = []
-    cur_input_ids = input_ids
-    with torch.inference_mode():
-        logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
-        logits = rearrange(logits, '(n b) d -> b (n d)', b=input_ids.shape[0])[..., :config.vocab_size]
-        scores.append(logits)
-        sequences.append(scores[-1].argmax(dim=-1))
-        for _ in range(input_ids.shape[1] + 1, max_length):
-            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
-            logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
-            logits = rearrange(logits, '(n b) d -> b (n d)', b=input_ids.shape[0])[..., :config.vocab_size]
-            scores.append(logits)
-            sequences.append(scores[-1].argmax(dim=-1))
-        sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
-        scores = tuple(scores)
-    print(sequences)
-
-    out = model.generate(input_ids=input_ids, max_length=max_length,
-                         tensor_parallel=world_size, vocab_size=config.vocab_size,
-                         fused_ft_kernel=fused_ft_kernel,
-                         return_dict_in_generate=True, output_scores=True, timing=True)
-    print(out.sequences)
-    if fused_ft_kernel:
-        out_cg = model.generate(input_ids=input_ids, max_length=max_length,
-                                tensor_parallel=world_size, vocab_size=config.vocab_size,
-                                fused_ft_kernel=fused_ft_kernel, cg=True,
-                                return_dict_in_generate=True, output_scores=True, timing=True)
-        print(out_cg.sequences)
-
-    if not rotary:
-        out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
-                                   return_dict_in_generate=True, output_scores=True)
-        out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
-                                     return_dict_in_generate=True, output_scores=True)
-        print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
-        print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
-        print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
-        print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
-        print(tokenizer.batch_decode(out_ref.sequences.tolist()))
-
-    assert torch.all(out.sequences == sequences)
-    assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
...
tests/models/test_gpt_generation_parallel.py (new file, 0 → 100644)

# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_generation_parallel.py -k "parallel"
import os
import re

import torch
import pytest

from einops import rearrange

from transformers import GPT2Config, GPT2Tokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt import remap_state_dict_gpt2
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.distributed import all_gather_raw


# @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize('world_size', [2])
# @pytest.mark.parametrize('fused_ft_kernel', [False, True])
@pytest.mark.parametrize('fused_ft_kernel', [True])
# @pytest.mark.parametrize('rotary', [False, True])
@pytest.mark.parametrize('rotary', [False])
@pytest.mark.parametrize('model_name', ["gpt2"])
def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
    """Check that our implementation of GPT2 generation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
    """
    dtype = torch.float16
    rtol, atol = 3e-3, 3e-1
    config = GPT2Config.from_pretrained(model_name)
    if rotary:
        config.n_positions = 0
        config.rotary_emb_dim = 64
    config.residual_in_fp32 = True
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    config.pad_vocab_size_multiple = 8 * world_size
    config.sequence_parallel = False  # Need to set this to False for generation

    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    device = f'cuda:{torch.distributed.get_rank()}'
    assert world_size <= torch.distributed.get_world_size()
    # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
    # GPU0 and GPU1 and things would hang
    torch.cuda.set_device(device)
    from apex.transformer import parallel_state
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    # if not rotary, we load the weight from HF but ignore the position embeddings.
    # The model would be nonsense but it doesn't matter for the test.
    model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device,
                                           dtype=dtype, process_group=process_group,
                                           world_size=world_size, rank=rank)
    model.eval()

    if not rotary:
        model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
        model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype)
        model_ref.eval()
        model_hf.eval()

    torch.manual_seed(0)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.to(device=device)
    max_length = 30
    # input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
    # max_length = input_ids.shape[1] + 40

    # Slow generation for reference
    sequences = []
    scores = []
    cur_input_ids = input_ids
    with torch.inference_mode():
        logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
        logits = rearrange(logits, '(n b) d -> b (n d)', b=input_ids.shape[0])[..., :config.vocab_size]
        scores.append(logits)
        sequences.append(scores[-1].argmax(dim=-1))
        for _ in range(input_ids.shape[1] + 1, max_length):
            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
            logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
            logits = rearrange(logits, '(n b) d -> b (n d)', b=input_ids.shape[0])[..., :config.vocab_size]
            scores.append(logits)
            sequences.append(scores[-1].argmax(dim=-1))
        sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
        scores = tuple(scores)
    print(sequences)

    out = model.generate(input_ids=input_ids, max_length=max_length,
                         tensor_parallel=world_size, vocab_size=config.vocab_size,
                         fused_ft_kernel=fused_ft_kernel,
                         return_dict_in_generate=True, output_scores=True, timing=True)
    print(out.sequences)
    if fused_ft_kernel:
        out_cg = model.generate(input_ids=input_ids, max_length=max_length,
                                tensor_parallel=world_size, vocab_size=config.vocab_size,
                                fused_ft_kernel=fused_ft_kernel, cg=True,
                                return_dict_in_generate=True, output_scores=True, timing=True)
        print(out_cg.sequences)

    if not rotary:
        out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
                                   return_dict_in_generate=True, output_scores=True)
        out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
                                     return_dict_in_generate=True, output_scores=True)
        print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
        print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
        print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
        print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')

    assert torch.all(out.sequences == sequences)
    assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
                          rtol=rtol, atol=atol)
    if not rotary:
        assert torch.all(out.sequences == out_ref.sequences)
        assert torch.all(out.sequences == out_hf.sequences)
        assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
tests/models/test_gpt_parallel.py

 # Run test with:
 # torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_parallel.py
+import math
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

@@ -59,10 +61,12 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
                         n_positions=seqlen if has_pos_emb else 0,
                         vocab_size=50257, resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0,
                         scale_attn_by_inverse_layer_idx=True, use_flash_attn=True,
-                        fused_dense_gelu_dense=True, fused_bias_fc=True, fused_dropout_add_ln=True,
+                        fused_mlp=True, fused_bias_fc=True, fused_dropout_add_ln=True,
                         residual_in_fp32=True, rotary_emb_fraction=0.0 if has_pos_emb else 0.5,
                         pad_vocab_size_multiple=8 * world_size,
                         sequence_parallel=sequence_parallel)
+    config.vocab_size = math.ceil(config.vocab_size / (8 * world_size)) * (8 * world_size)
     model_pt = GPTLMHeadModel(config, device=device)
 
     def init_layer_norm(module):

@@ -131,9 +135,9 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
                           grad_dict['transformer.embeddings.position_embeddings.weight'],
                           rtol=rtol, atol=atol)
-    assert torch.allclose(model.transformer.ln_0.weight.grad, grad_dict['transformer.ln_0.weight'],
+    assert torch.allclose(model.transformer.ln_f.weight.grad, grad_dict['transformer.ln_f.weight'],
                           rtol=rtol, atol=atol)
-    assert torch.allclose(model.transformer.ln_0.bias.grad, grad_dict['transformer.ln_0.bias'],
+    assert torch.allclose(model.transformer.ln_f.bias.grad, grad_dict['transformer.ln_f.bias'],
                           rtol=rtol, atol=atol)
     for i in range(num_layers):
         assert torch.allclose(
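The added vocab_size line pads the GPT-2 vocabulary up to a multiple of 8 × world_size so the embedding can be split evenly across ranks. A worked example for world_size=2:

    # Worked example of the added padding line for world_size=2.
    import math
    vocab_size, world_size = 50257, 2
    padded = math.ceil(vocab_size / (8 * world_size)) * (8 * world_size)
    assert padded == 50272   # next multiple of 16 at or above 50257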
tests/models/test_vit.py
View file @
88173a1a
...
...
@@ -8,11 +8,11 @@ from timm.models.vision_transformer import vit_base_patch16_224
from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224

-@pytest.mark.parametrize('fused_dense_gelu_dense', [False, True])
-# @pytest.mark.parametrize('fused_dense_gelu_dense', [False])
+@pytest.mark.parametrize('fused_mlp', [False, True])
+# @pytest.mark.parametrize('fused_mlp', [False])
@pytest.mark.parametrize('optimized', [False, True])
# @pytest.mark.parametrize('optimized', [True])
-def test_vit(optimized, fused_dense_gelu_dense):
+def test_vit(optimized, fused_mlp):
"""Check that our implementation of ViT matches the timm's implementation:
the output of our forward pass in fp16 should be around the same as
timm' forward pass in fp16, when compared to timm's forward pass in fp32.
...
...
@@ -23,7 +23,7 @@ def test_vit(optimized, fused_dense_gelu_dense):
kwargs = {}
if optimized:
    kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True)
-kwargs['fused_dense_gelu_dense'] = fused_dense_gelu_dense
+kwargs['fused_mlp'] = fused_mlp
model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype)
model_ref = vit_base_patch16_224(pretrained=True).to(device=device)
...
...
@@ -46,4 +46,5 @@ def test_vit(optimized, fused_dense_gelu_dense):
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}')
print(f'timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}')
-assert (out - out_ref).abs().max().item() < 3 * (out_timm - out_ref).abs().max().item()
+rtol = 2 if not fused_mlp else 4
+assert (out - out_ref).abs().max().item() < rtol * (out_timm - out_ref).abs().max().item()
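For reference, a minimal usage sketch of the renamed flag on the ViT constructor (not part of this commit; it mirrors the test above and assumes a CUDA device with the optional fused kernels built):

import torch
from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224

# fused_mlp replaces the former fused_dense_gelu_dense keyword.
model = flash_vit_base_patch16_224(use_flash_attn=True, fused_bias_fc=True,
                                   fused_dropout_add_ln=True, fused_mlp=True)
model = model.to(device='cuda', dtype=torch.float16)
x = torch.randn(1, 3, 224, 224, device='cuda', dtype=torch.float16)
out = model(x)  # classification logits, same shape as the timm reference output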
tests/modules/test_block_parallel.py
View file @ 88173a1a
...
...
@@ -15,7 +15,7 @@ from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from flash_attn.modules.mha import MHA, ParallelMHA
-from flash_attn.modules.mlp import FusedDenseGeluDense, ParallelFusedDenseGeluDense
+from flash_attn.modules.mlp import FusedMLP, ParallelFusedMLP
from flash_attn.modules.block import Block
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad
...
...
@@ -27,7 +27,7 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False])
-# @pytest.mark.parametrize('sequence_parallel', [False])
+# @pytest.mark.parametrize('sequence_parallel', [True])
@pytest.mark.parametrize('dim', [1024])
def test_block_parallel(dim, sequence_parallel, world_size, dtype):
    head_dim = 64
...
...
@@ -62,8 +62,7 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
mixer_cls_pt = partial(MHA, num_heads=num_heads, rotary_emb_dim=int(head_dim // 2),
                       use_flash_attn=True, device=device, dtype=dtype)
-mlp_cls_pt = partial(FusedDenseGeluDense, hidden_features=4 * dim, device=device, dtype=dtype)
+mlp_cls_pt = partial(FusedMLP, hidden_features=4 * dim, device=device, dtype=dtype)
norm_cls = partial(nn.LayerNorm, device=device, dtype=dtype)
model_pt = Block(dim, mixer_cls_pt, mlp_cls_pt, norm_cls, fused_dropout_add_ln=True)
with torch.no_grad():
...
...
@@ -76,7 +75,7 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
                    process_group=parallel_state.get_tensor_model_parallel_group(),
                    rotary_emb_dim=int(head_dim // 2), use_flash_attn=True,
                    sequence_parallel=sequence_parallel, device=device, dtype=dtype)
-mlp_cls = partial(ParallelFusedDenseGeluDense, hidden_features=4 * dim,
+mlp_cls = partial(ParallelFusedMLP, hidden_features=4 * dim,
                  process_group=parallel_state.get_tensor_model_parallel_group(),
                  sequence_parallel=sequence_parallel, device=device, dtype=dtype)
model = Block(dim, mixer_cls, mlp_cls, norm_cls, fused_dropout_add_ln=True,
...
...
@@ -143,7 +142,7 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
    x.grad,
    x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] if sequence_parallel else x_pt.grad,
-   rtol=rtol, atol=atol / 100  # magnitude of x.grad is quite small
+   rtol=rtol, atol=atol / 10  # magnitude of x.grad is quite small
)
assert torch.allclose(
    residual.grad,
...
...
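As a side note, a minimal sketch of the renamed module on its own, outside the tensor-parallel harness above (hedged; assumes a CUDA device and that the fused_dense extension is built):

import torch
from flash_attn.modules.mlp import FusedMLP  # formerly FusedDenseGeluDense

dim = 1024
mlp = FusedMLP(dim, hidden_features=4 * dim, device='cuda', dtype=torch.float16)
x = torch.randn(2, 512, dim, device='cuda', dtype=torch.float16)
y = mlp(x)  # output keeps the input's (batch, seqlen, dim) shape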
tests/ops/test_fused_dense.py
View file @ 88173a1a
import math
from functools import partial

import torch
import torch.nn.functional as F
...
...
@@ -6,7 +7,7 @@ import pytest
from einops import rearrange

-from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense
+from flash_attn.ops.fused_dense import FusedDense, FusedMLP


@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
...
...
@@ -60,15 +61,25 @@ def test_fused_linear_bias(in_features, out_features, has_bias, return_residual,
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize('heuristic', [0, -1])
# @pytest.mark.parametrize('dtype', [torch.float16])
+@pytest.mark.parametrize('heuristic', ['auto', -1])
+# @pytest.mark.parametrize('heuristic', ['auto'])
@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
# @pytest.mark.parametrize('checkpoint_lvl', [1])
@pytest.mark.parametrize('return_residual', [False, True])
# @pytest.mark.parametrize('return_residual', [False])
@pytest.mark.parametrize('has_bias2', [True, False])
@pytest.mark.parametrize('has_bias1', [True, False])
# @pytest.mark.parametrize('has_bias2', [True])
# @pytest.mark.parametrize('has_bias1', [True])
+@pytest.mark.parametrize('activation', ['gelu_approx', 'relu'])
+# @pytest.mark.parametrize('activation', ['relu'])
@pytest.mark.parametrize('out_features', [1024, 4096])
@pytest.mark.parametrize('in_features', [1024, 4096])
-def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2, return_residual,
-                                checkpoint_lvl, heuristic, dtype):
# @pytest.mark.parametrize('out_features', [4096])
# @pytest.mark.parametrize('in_features', [1024])
+def test_fused_mlp(in_features, out_features, activation, has_bias1, has_bias2, return_residual,
+                   checkpoint_lvl, heuristic, dtype):
    device = 'cuda'
    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
# set seed
...
...
@@ -82,10 +93,10 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2,
                                   dtype=dtype)
    model_pt_fc2 = torch.nn.Linear(out_features, in_features, bias=has_bias2, device=device, dtype=dtype)
-   model = FusedDenseGeluDense(in_features, out_features, in_features, bias1=has_bias1,
-                               bias2=has_bias2, return_residual=return_residual,
-                               checkpoint_lvl=checkpoint_lvl, heuristic=heuristic,
-                               device=device, dtype=dtype)
+   model = FusedMLP(in_features, out_features, in_features, activation=activation,
+                    bias1=has_bias1, bias2=has_bias2, return_residual=return_residual,
+                    checkpoint_lvl=checkpoint_lvl, heuristic=heuristic,
+                    device=device, dtype=dtype)
    with torch.no_grad():
        model.fc1.weight.copy_(model_pt_fc1.weight)
        if has_bias1:
...
...
@@ -93,7 +104,9 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2,
        model.fc2.weight.copy_(model_pt_fc2.weight)
        if has_bias2:
            model.fc2.bias.copy_(model_pt_fc2.bias)
-   out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh'))
+   activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
+                    else partial(F.relu, inplace=True))
+   out_pt = model_pt_fc2(activation_fn(model_pt_fc1(x_pt)))
    if not return_residual:
        out = model(x)
    else:
...
...
@@ -107,6 +120,9 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2,
    g = torch.randn_like(out) / 32
    out_pt.backward(g)
    out.backward(g)
+   # The error for relu is higher still
+   if activation == 'relu':
+       atol = 1e-1 if dtype == torch.bfloat16 else 5e-2
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10)
...
...
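To make the new keyword concrete, a hedged standalone sketch mirroring the constructor call in this test ('relu' is the newly supported activation; 'gelu_approx' is the other value exercised):

import torch
from flash_attn.ops.fused_dense import FusedMLP

device, dtype = 'cuda', torch.float16
in_features, out_features = 1024, 4096
# Same positional layout as the test: (in_features, hidden_features, out_features).
model = FusedMLP(in_features, out_features, in_features,
                 activation='relu', heuristic='auto', device=device, dtype=dtype)
x = torch.randn(8, 512, in_features, device=device, dtype=dtype)
out = model(x)  # fused fc1 -> ReLU -> fc2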
tests/ops/test_fused_dense_parallel.py
View file @ 88173a1a
...
...
@@ -10,8 +10,8 @@ import pytest
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
-from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense
-from flash_attn.ops.fused_dense import ColumnParallelLinear, ParallelFusedDenseGeluDense
+from flash_attn.ops.fused_dense import FusedDense, FusedMLP
+from flash_attn.ops.fused_dense import ColumnParallelLinear, ParallelFusedMLP

is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
...
...
@@ -106,8 +106,7 @@ def test_fused_linear_bias(in_features, out_features, has_bias, sequence_paralle
# @pytest.mark.parametrize('has_bias2', [True])
@pytest.mark.parametrize('out_features', [4096])
@pytest.mark.parametrize('in_features', [1024])
-def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, sequence_parallel, world_size, dtype):
+def test_fused_mlp(in_features, out_features, has_bias2, sequence_parallel, world_size, dtype):
    assert out_features % world_size == 0
    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
...
...
@@ -137,11 +136,11 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, sequence_p
                              dtype=dtype)
    partition_out_features = out_features // world_size
    partition_in_features = in_features // world_size
-   model = ParallelFusedDenseGeluDense(in_features, out_features, in_features,
-                                       process_group=parallel_state.get_tensor_model_parallel_group(),
-                                       bias2=has_bias2 and rank == 0, sequence_parallel=sequence_parallel,
-                                       device=device, dtype=dtype)
+   model = ParallelFusedMLP(in_features, out_features, in_features,
+                            process_group=parallel_state.get_tensor_model_parallel_group(),
+                            bias2=has_bias2 and rank == 0, sequence_parallel=sequence_parallel,
+                            device=device, dtype=dtype)
    with torch.no_grad():
        model.fc1.weight.copy_(
...
...
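For completeness, a hedged sketch of the tensor-parallel variant exercised above; it has to be launched with torchrun, assumes apex.transformer is installed, and the initialize_model_parallel call follows the pattern these tests use:

import torch
from apex.transformer import parallel_state
from flash_attn.ops.fused_dense import ParallelFusedMLP

torch.distributed.init_process_group(backend='nccl')
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(rank)
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)

in_features, out_features = 1024, 4096
mlp = ParallelFusedMLP(in_features, out_features, in_features,
                       process_group=parallel_state.get_tensor_model_parallel_group(),
                       sequence_parallel=False, device='cuda', dtype=torch.float16)
x = torch.randn(2, 512, in_features, device='cuda', dtype=torch.float16)
y = mlp(x)  # fc1/fc2 are sharded across the tensor-parallel group; the output is reduced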
training/README.md
View file @ 88173a1a
...
...
@@ -48,7 +48,7 @@ config = GPT2Config(vocab_size=50257, n_positions=seqlen, n_embd=hidden_dim,
                    n_layer=n_layer, n_head=nheads,
                    scale_attn_by_inverse_layer_idx=True, rotary_emb_fraction=rotary_emb_fraction,
-                   use_flash_attn=True, fused_dense_gelu_dense=True,
+                   use_flash_attn=True, fused_mlp=True,
                    fused_bias_fc=True, fused_dropout_add_ln=True,
                    pad_vocab_size_multiple=8)
model = GPTLMHeadModel(config)
...
...
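Continuing the README snippet, a hedged example of running the configured model (assumes GPTLMHeadModel, config, and seqlen are defined as in the snippet above, plus a CUDA device):

import torch

model = GPTLMHeadModel(config).to(device='cuda', dtype=torch.float16)
input_ids = torch.randint(0, config.vocab_size, (2, seqlen), device='cuda')
logits = model(input_ids).logits  # (batch, seqlen, vocab), vocab padded to a multiple of 8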
training/configs/experiment/owt/gpt2s-flash.yaml
View file @ 88173a1a
...
...
@@ -7,9 +7,10 @@ defaults:
model:
  config:
    # n_positions is already set to ${datamodule.max_length}
    residual_in_fp32: True
    use_flash_attn: True
    fused_bias_fc: True
-   fused_dense_gelu_dense: True
+   fused_mlp: True
    fused_dropout_add_ln: True
    pad_vocab_size_multiple: 8
...
...
training/configs/experiment/pile/gpt3s-flash.yaml
View file @ 88173a1a
...
...
@@ -7,9 +7,10 @@ defaults:
model:
  config:
    # n_positions is already set to ${datamodule.max_length}
    residual_in_fp32: True
    use_flash_attn: True
    fused_dropout_add_ln: True
-   fused_dense_gelu_dense: True
+   fused_mlp: True
    fused_bias_fc: True
    pad_vocab_size_multiple: 8
...
...