gaoqiong / flash-attention · Commits
"vscode:/vscode.git/clone" did not exist on "edfbb24a1b5c07afabd6ed6dd47db6df01968e45"
Commit 88173a1a, authored Jan 17, 2023 by Tri Dao
[FusedDense] Support relu, rename FusedDenseGeluDense -> FusedMLP
Parent: 780e8eea
Showing 20 changed files with 654 additions and 779 deletions.
csrc/fused_dense_lib/fused_dense.cpp (+41, −45)
csrc/fused_dense_lib/fused_dense_cuda.cu (+266, −478)
flash_attn/models/bert.py (+7, −7)
flash_attn/models/gpt.py (+18, −14)
flash_attn/models/vit.py (+8, −8)
flash_attn/modules/block.py (+2, −1)
flash_attn/modules/mlp.py (+2, −2)
flash_attn/ops/fused_dense.py (+111, −64)
tests/models/test_bert.py (+6, −6)
tests/models/test_gpt.py (+1, −1)
tests/models/test_gpt_generation.py (+5, −118)
tests/models/test_gpt_generation_parallel.py (+131, −0)
tests/models/test_gpt_parallel.py (+7, −3)
tests/models/test_vit.py (+6, −5)
tests/modules/test_block_parallel.py (+5, −6)
tests/ops/test_fused_dense.py (+25, −9)
tests/ops/test_fused_dense_parallel.py (+8, −9)
training/README.md (+1, −1)
training/configs/experiment/owt/gpt2s-flash.yaml (+2, −1)
training/configs/experiment/pile/gpt3s-flash.yaml (+2, −1)
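The core change in this commit is an API rename plus ReLU support: FusedDenseGeluDense becomes FusedMLP, ParallelFusedDenseGeluDense becomes ParallelFusedMLP, and the model-config flag fused_dense_gelu_dense becomes fused_mlp. A minimal migration sketch in Python (the GPT2Config import is illustrative; the flash_attn import path follows the mlp.py and model diffs below, and the fused_dense CUDA extension must be installed):

from transformers import GPT2Config

# New names exported by flash_attn.ops.fused_dense after this commit
# (the old FusedDenseGeluDense / ParallelFusedDenseGeluDense names are gone).
from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP

config = GPT2Config()
config.fused_mlp = True  # previously: config.fused_dense_gelu_dense = True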
csrc/fused_dense_lib/fused_dense.cpp
@@ -28,19 +28,19 @@
 }

 template <typename T>
-int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace);
+int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias);

 template <typename T>
-int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int batch_size, int out_features, int heuristic, T *output, T *gelu_in, void *lt_workspace);
+int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act);

 template <typename T>
-int bias_gelu_linear_dgrad_bgrad_cuda(T *weight, T *d_output, T *gelu_in, int in_features, int batch_size, int out_features, int heuristic, T *d_input, T *d_bias, void *lt_workspace);
+int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias);

 std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output, bool has_d_bias) {
-    int batch_size = input.size(0);
-    int in_features = input.size(1);
-    int out_features = d_output.size(1);
+    int64_t batch_size = input.size(0);
+    int64_t in_features = input.size(1);
+    int64_t out_features = d_output.size(1);
     TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
     TORCH_CHECK(input.dtype() == d_output.dtype());

@@ -66,8 +66,6 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
         d_bias = at::empty({out_features}, opts);
 #endif
     }
-    // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
-    auto lt_workspace = at::empty({1 << 22}, opts);
     DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_wgrad", [&] {
         auto result = linear_bias_wgrad_cuda<scalar_t>(

@@ -77,21 +75,20 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
             batch_size,
             out_features,
             d_weight.data_ptr<scalar_t>(),
-            has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr,
-            (void*) (lt_workspace.data_ptr<scalar_t>()));
+            has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr);
         TORCH_CHECK(result == 0, "linear_bias_wgrad failed.");
     });
     return {d_weight, d_bias};
 }

-std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
-                                            c10::optional<at::Tensor> bias_,
-                                            bool save_gelu_in, int heuristic) {
-    int batch_size = input.size(0);
-    int in_features = input.size(1);
-    int out_features = weight.size(0);
+std::vector<at::Tensor> linear_act_forward(at::Tensor input, at::Tensor weight,
+                                           c10::optional<at::Tensor> bias_,
+                                           bool is_gelu, bool save_pre_act, int heuristic) {
+    int64_t batch_size = input.size(0);
+    int64_t in_features = input.size(1);
+    int64_t out_features = weight.size(0);
     TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
     TORCH_CHECK(input.dtype() == weight.dtype());

@@ -116,51 +113,52 @@ std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
     // create output/workspace tensor
     auto opts = input.options();
     auto output = at::empty({batch_size, out_features}, opts);
-    at::Tensor gelu_in;
-    if (save_gelu_in) { gelu_in = at::empty({batch_size, out_features}, opts); }
-    // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
-    auto lt_workspace = at::empty({1 << 22}, opts);
-    DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_gelu_forward", [&] {
-        auto result = linear_gelu_forward_cuda<scalar_t>(
+    at::Tensor pre_act;
+    // If ReLU, cuBlasLT stores a bit-mask (1 bit per element)
+    if (save_pre_act) { pre_act = at::empty({batch_size, is_gelu ? out_features : out_features / 8},
+                                            is_gelu ? opts : opts.dtype(torch::kUInt8)); }
+    DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_act_forward", [&] {
+        auto result = linear_act_forward_cuda<scalar_t>(
             input.data_ptr<scalar_t>(),
             weight.data_ptr<scalar_t>(),
             bias_.has_value() ? bias_.value().data_ptr<scalar_t>() : nullptr,
             in_features,
             batch_size,
             out_features,
+            is_gelu,
             heuristic,
             output.data_ptr<scalar_t>(),
-            save_gelu_in ? gelu_in.data_ptr<scalar_t>() : nullptr,
-            (void*) (lt_workspace.data_ptr<scalar_t>()));
-        TORCH_CHECK(result == 0, "linear_gelu_forward failed.");
+            save_pre_act ? pre_act.data_ptr() : nullptr);
+        TORCH_CHECK(result == 0, "linear_act_forward failed.");
     });
     std::vector<at::Tensor> result = {output};
-    if (save_gelu_in) { result.push_back(gelu_in); };
+    if (save_pre_act) { result.push_back(pre_act); };
     return result;
 }

-std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(
-    at::Tensor weight, at::Tensor d_output, at::Tensor gelu_in, int heuristic
+std::vector<at::Tensor> bias_act_linear_dgrad_bgrad(
+    at::Tensor weight, at::Tensor d_output, at::Tensor pre_act, bool is_gelu, int heuristic
 ) {
-    int batch_size = d_output.size(0);
-    int out_features = d_output.size(1);
-    int in_features = weight.size(1);
+    int64_t batch_size = d_output.size(0);
+    int64_t out_features = d_output.size(1);
+    int64_t in_features = weight.size(1);
     TORCH_CHECK(weight.dtype() == torch::kFloat16 || weight.dtype() == torch::kBFloat16);
     TORCH_CHECK(weight.dtype() == d_output.dtype());
-    TORCH_CHECK(weight.dtype() == gelu_in.dtype());
+    TORCH_CHECK(is_gelu ? (pre_act.dtype() == weight.dtype()) : (pre_act.dtype() == torch::kUInt8));
     TORCH_CHECK(weight.is_cuda());
     TORCH_CHECK(d_output.is_cuda());
-    TORCH_CHECK(gelu_in.is_cuda());
+    TORCH_CHECK(pre_act.is_cuda());
     TORCH_CHECK(weight.is_contiguous());
     TORCH_CHECK(d_output.is_contiguous());
-    TORCH_CHECK(gelu_in.is_contiguous());
+    TORCH_CHECK(pre_act.is_contiguous());
     CHECK_SHAPE(weight, out_features, in_features);
     CHECK_SHAPE(d_output, batch_size, out_features);
-    CHECK_SHAPE(gelu_in, batch_size, in_features);
+    // If ReLU, cuBlasLT stores a bit-mask (1 bit per element)
+    CHECK_SHAPE(pre_act, batch_size, is_gelu ? in_features : in_features / 8);
     // Otherwise the kernel will be launched from cuda:0 device
     // Cast to char to avoid compiler warning about narrowing

@@ -170,22 +168,20 @@ std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(
     auto opts = weight.options();
     auto d_bias = at::empty({in_features}, opts);
     auto d_input = at::empty({batch_size, in_features}, opts);
-    // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
-    auto lt_workspace = at::empty({1 << 22}, opts);
-    DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_gelu_linear_dgrad_bgrad", [&] {
-        auto result = bias_gelu_linear_dgrad_bgrad_cuda<scalar_t>(
+    DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_act_linear_dgrad_bgrad", [&] {
+        auto result = bias_act_linear_dgrad_bgrad_cuda<scalar_t>(
             weight.data_ptr<scalar_t>(),
             d_output.data_ptr<scalar_t>(),
-            gelu_in.data_ptr<scalar_t>(),
+            pre_act.data_ptr(),
             in_features,
             batch_size,
             out_features,
+            is_gelu,
             heuristic,
             d_input.data_ptr<scalar_t>(),
-            d_bias.data_ptr<scalar_t>(),
-            (void*) (lt_workspace.data_ptr<scalar_t>()));
-        TORCH_CHECK(result == 0, "bias_gelu_linear_dgrad_bgrad failed.");
+            d_bias.data_ptr<scalar_t>());
+        TORCH_CHECK(result == 0, "bias_act_linear_dgrad_bgrad failed.");
     });
     return {d_input, d_bias};

@@ -193,6 +189,6 @@ std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("linear_bias_wgrad", &linear_bias_wgrad, "linear bias wgrad");
-    m.def("linear_gelu_forward", &linear_gelu_forward, "linear gelu forward");
-    m.def("bias_gelu_linear_dgrad_bgrad", &bias_gelu_linear_dgrad_bgrad, "bias gelu linear dgrad bgrad");
+    m.def("linear_act_forward", &linear_act_forward, "linear gelu/relu forward");
+    m.def("bias_act_linear_dgrad_bgrad", &bias_act_linear_dgrad_bgrad, "bias gelu/relu linear dgrad bgrad");
 }
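For orientation, the Python call sites for these renamed bindings appear later in this commit (flash_attn/ops/fused_dense.py). The following is a hedged, condensed sketch of how the forward entry point is invoked there; the extension module name fused_dense_lib and the tensor sizes are assumptions for illustration, and the argument order follows the signatures above:

import torch
import fused_dense_lib as fused_dense_cuda  # assumed name of the extension built from csrc/fused_dense_lib

batch_dim, in_features, out_features = 512, 1024, 4096
x = torch.randn(batch_dim, in_features, device='cuda', dtype=torch.float16)
w1 = torch.randn(out_features, in_features, device='cuda', dtype=torch.float16)
b1 = torch.randn(out_features, device='cuda', dtype=torch.float16)

is_gelu, save_pre_act, heuristic = False, True, 0  # ReLU path
output1, *rest = fused_dense_cuda.linear_act_forward(x, w1, b1, is_gelu, save_pre_act, heuristic)
# With is_gelu=False the saved pre-activation is a uint8 bit-mask of shape
# (batch_dim, out_features // 8); with is_gelu=True it has the same shape/dtype as output1.
pre_act = rest[0]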
csrc/fused_dense_lib/fused_dense_cuda.cu
(diff collapsed in the original view; +266, −478, not shown here)
flash_attn/models/bert.py
@@ -23,7 +23,7 @@ from transformers.models.bert.modeling_bert import BertForPreTrainingOutput
 from einops import rearrange

 from flash_attn.modules.mha import MHA
-from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
+from flash_attn.modules.mlp import Mlp, FusedMLP
 from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import BertEmbeddings
 from flash_attn.bert_padding import unpad_input, pad_input

@@ -61,24 +61,24 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
 def create_mlp_cls(config, layer_idx=None, return_residual=False):
     inner_dim = config.intermediate_size
-    fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False)
-    if fused_dense_gelu_dense:
-        assert config.hidden_act in ['gelu_new', 'gelu_fast'], ('fused_dense_gelu_dense only '
+    fused_mlp = getattr(config, 'fused_mlp', False)
+    if fused_mlp:
+        assert config.hidden_act in ['gelu_new', 'gelu_fast'], ('fused_mlp only '
                                                                 'supports approximate gelu')
-    if not fused_dense_gelu_dense:
+    if not fused_mlp:
         approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
         mlp_cls = partial(Mlp, hidden_features=inner_dim,
                           activation=partial(F.gelu, approximate=approximate),
                           return_residual=return_residual)
     else:
-        if FusedDenseGeluDense is None:
+        if FusedMLP is None:
             raise ImportError('fused_dense is not installed')
         mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
         # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
         if isinstance(mlp_checkpoint_lvl, Sequence):
             assert layer_idx is not None
             mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
-        mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim,
+        mlp_cls = partial(FusedMLP, hidden_features=inner_dim,
                           checkpoint_lvl=mlp_checkpoint_lvl, return_residual=return_residual)
     return mlp_cls
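The BERT tests further down in this commit enable exactly this path; a condensed sketch of the new flag (model_name is illustrative, BertForPreTraining is the flash_attn.models.bert class used in tests/models/test_bert.py, and fused_mlp still requires an approximate-gelu hidden_act here, per the assert above):

from transformers import BertConfig
from flash_attn.models.bert import BertForPreTraining

model_name = "bert-base-uncased"  # illustrative
config = BertConfig.from_pretrained(model_name)
config.hidden_act = "gelu_new"       # fused_mlp only supports approximate gelu in BERT
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True              # previously: config.fused_dense_gelu_dense
config.fused_dropout_add_ln = True
model = BertForPreTraining.from_pretrained(model_name, config)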
flash_attn/models/gpt.py
@@ -17,7 +17,7 @@ from transformers import GPT2Config
 from einops import rearrange

 from flash_attn.modules.mha import MHA, ParallelMHA
-from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense, ParallelFusedDenseGeluDense
+from flash_attn.modules.mlp import Mlp, FusedMLP, ParallelFusedMLP
 from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
 from flash_attn.utils.distributed import sync_shared_params, all_gather_raw

@@ -77,22 +77,22 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
 def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
     factory_kwargs = {'device': device, 'dtype': dtype}
     inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
-    fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False)
-    if fused_dense_gelu_dense:
-        assert config.activation_function in ['gelu_new', 'gelu_fast'], ('fused_dense_gelu_dense only '
-                                                                         'supports approximate gelu')
+    fused_mlp = getattr(config, 'fused_mlp', False)
+    if fused_mlp:
+        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu']
     fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
     if fused_dense_sqrelu_dense:
         assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '
                                                         'supports approximate activation_function sqrelu')
-    assert not (fused_dense_sqrelu_dense and fused_dense_gelu_dense)
+    assert not (fused_dense_sqrelu_dense and fused_mlp)
     if process_group is not None:
-        assert fused_dense_gelu_dense, 'Tensor Parallel is only implemented for FusedDenseGeluDense'
-    if not fused_dense_gelu_dense and not fused_dense_sqrelu_dense:
+        assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP'
+    if not fused_mlp and not fused_dense_sqrelu_dense:
         if config.activation_function == 'relu':
             activation = partial(F.relu, inplace=True)
         else:
-            approximate = 'tanh' if config.activation_function in ['gelu_new', 'gelu_fast'] else 'none'
+            approximate = ('tanh' if config.activation_function
+                           in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
             activation = partial(F.gelu, approximate=approximate)
         mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=activation, **factory_kwargs)
     else:

@@ -101,14 +101,17 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
         if isinstance(mlp_checkpoint_lvl, Sequence):
             assert layer_idx is not None
             mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
-        if fused_dense_gelu_dense:
-            if FusedDenseGeluDense is None:
+        if fused_mlp:
+            if FusedMLP is None:
                 raise ImportError('fused_dense is not installed')
-            mlp_cls = FusedDenseGeluDense if process_group is None else ParallelFusedDenseGeluDense
+            activation = ('gelu_approx' if config.activation_function
+                          in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'relu')
+            mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
             parallel_kwargs = ({'process_group': process_group,
                                 'sequence_parallel': getattr(config, 'sequence_parallel', True)}
                                if process_group is not None else {})
-            mlp_cls = partial(mlp_cls, hidden_features=inner_dim, checkpoint_lvl=mlp_checkpoint_lvl,
+            mlp_cls = partial(mlp_cls, hidden_features=inner_dim, activation=activation,
+                              checkpoint_lvl=mlp_checkpoint_lvl,
                               **parallel_kwargs, **factory_kwargs)
         elif fused_dense_sqrelu_dense:
             assert FusedDenseSqreluDense is not None

@@ -210,7 +213,8 @@ class GPTModel(GPTPreTrainedModel):
         factory_kwargs = {'device': device, 'dtype': dtype}
         self.process_group = process_group
         self.sequence_parallel = getattr(config, 'sequence_parallel', True)
-        assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'relu', 'sqrelu']
+        assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx',
+                                              'relu', 'sqrelu']
         pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
         vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
                       * pad_vocab_size_multiple)
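Since FusedMLP now handles ReLU as well, the GPT path above accepts activation_function='relu' together with fused_mlp. A hedged sketch (GPTModel is the class named in the hunk headers above; the device/dtype keywords follow its factory_kwargs, and the fused options require the corresponding CUDA extensions):

import torch
from transformers import GPT2Config
from flash_attn.models.gpt import GPTModel

config = GPT2Config()
config.activation_function = 'relu'  # newly allowed together with fused_mlp in this commit
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True              # previously: config.fused_dense_gelu_dense
model = GPTModel(config, device='cuda', dtype=torch.float16)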
flash_attn/models/vit.py
@@ -20,7 +20,7 @@ from timm.models.helpers import named_apply
 from flash_attn.layers.patch_embed import PatchEmbed
 from flash_attn.modules.mha import MHA
-from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
+from flash_attn.modules.mlp import Mlp, FusedMLP
 from flash_attn.modules.block import Block

 try:

@@ -37,22 +37,22 @@ def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_
     return mixer_cls


-def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense):
+def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp):
     inner_dim = int(embed_dim * mlp_ratio)
-    if not fused_dense_gelu_dense:
+    if not fused_mlp:
         mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer())
     else:
-        mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim)
+        mlp_cls = partial(FusedMLP, hidden_features=inner_dim)
     return mlp_cls


 def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
                  drop_path1, drop_path2, norm_layer, act_layer, use_flash_attn, fused_bias_fc,
-                 fused_dense_gelu_dense, fused_dropout_add_ln, layer_idx=None, n_layer=None,
+                 fused_mlp, fused_dropout_add_ln, layer_idx=None, n_layer=None,
                  last_layer_subset=False):
     mixer_cls = create_mixer_cls(num_heads, qkv_bias, attn_drop_rate, use_flash_attn, fused_bias_fc,
                                  cross_attn=(last_layer_subset and layer_idx == n_layer - 1))
-    mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense)
+    mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp)
     # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
     block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer,
                   prenorm=True, resid_dropout1=drop_rate, resid_dropout2=drop_rate,

@@ -92,7 +92,7 @@ class VisionTransformer(nn.Module):
             act_layer=None,
             use_flash_attn=False,
             fused_bias_fc=False,
-            fused_dense_gelu_dense=False,
+            fused_mlp=False,
             fused_dropout_add_ln=False,
     ):
         """

@@ -164,7 +164,7 @@ class VisionTransformer(nn.Module):
                 embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
                 drop_path1=dpr[i - 1] if i > 0 else 0., drop_path2=dpr[i],
                 norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn,
-                fused_bias_fc=fused_bias_fc, fused_dense_gelu_dense=fused_dense_gelu_dense,
+                fused_bias_fc=fused_bias_fc, fused_mlp=fused_mlp,
                 fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, n_layer=depth,
                 last_layer_subset=(global_pool == 'token')
             ) for i in range(depth)])
flash_attn/modules/block.py
@@ -121,7 +121,8 @@ class Block(nn.Module):
                 )
             if mixer_kwargs is None:
                 mixer_kwargs = {}
-            mixer_kwargs['mixer_subset'] = mixer_subset
+            if mixer_subset is not None:
+                mixer_kwargs['mixer_subset'] = mixer_subset
             hidden_states = self.mixer(hidden_states, **mixer_kwargs)
             if mixer_subset is not None:
                 residual = residual[:, mixer_subset]
flash_attn/modules/mlp.py
@@ -5,9 +5,9 @@ import torch.nn as nn
 import torch.nn.functional as F

 try:
-    from flash_attn.ops.fused_dense import FusedDenseGeluDense, ParallelFusedDenseGeluDense
+    from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
 except ImportError:
-    FusedDenseGeluDense, ParallelFusedDenseGeluDense = None, None
+    FusedMLP, ParallelFusedMLP = None, None


 class Mlp(nn.Module):
flash_attn/ops/fused_dense.py
-# Copyright (c) 2022, Tri Dao.
+# Copyright (c) 2023, Tri Dao.
 # Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
 # We make it work with pytorch amp and with bfloat16.
 # The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
 from typing import Optional
+from functools import partial

 import torch
 import torch.nn as nn

@@ -19,6 +20,11 @@ from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all
 from flash_attn.utils.distributed import reduce_scatter, all_reduce


+@torch.jit.script
+def relu_bwd(g, x):
+    return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)
+
+
 class FusedDenseFunc(torch.autograd.Function):

     @staticmethod

@@ -185,12 +191,13 @@ class RowParallelLinear(nn.Linear):
         return reduce_fn(out, self.process_group)


-class FusedDenseGeluDenseFunc(torch.autograd.Function):
+class FusedMLPFunc(torch.autograd.Function):

     @staticmethod
     @custom_fwd
-    def forward(ctx, x, weight1, bias1, weight2, bias2, save_pre_act=True, return_residual=False,
-                checkpoint_lvl=0, heuristic=0, process_group=None, sequence_parallel=True):
+    def forward(ctx, x, weight1, bias1, weight2, bias2, activation='gelu_approx', save_pre_act=True,
+                return_residual=False, checkpoint_lvl=0, heuristic=0, process_group=None,
+                sequence_parallel=True):
         """
         If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
         with sequence parallelism: we do an all_gather of x before doing the matmul.

@@ -198,10 +205,11 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
         checkpoint_lvl:
         0: no recomputation in the bwd
-        1: recompute gelu_out in the bwd
-        2: recompute gelu_in and gelu_out in the bwd
+        1: recompute gelu_out / relu_out in the bwd
+        2: recompute pre_act and gelu_out / relu_out in the bwd
         """
         assert -1 <= heuristic <= 4
+        assert activation in ['gelu_approx', 'relu']
         if not save_pre_act:
             checkpoint_lvl = 2
         assert checkpoint_lvl in [0, 1, 2]

@@ -209,6 +217,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
         ctx.process_group = process_group
         ctx.sequence_parallel = sequence_parallel
         ctx.checkpoint_lvl = checkpoint_lvl
+        ctx.activation = activation
         ctx.heuristic = heuristic
         if torch.is_autocast_enabled():

@@ -237,23 +246,27 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
         if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32:
             raise RuntimeError('fused_dense only supports matrix dims <= 2M')
         if heuristic == -1:
-            gelu_in = F.linear(total_x, weight1, bias1)
-            output1 = F.gelu(gelu_in, approximate='tanh')
+            pre_act = F.linear(total_x, weight1, bias1)
+            activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
+                             else F.relu)
+            output1 = activation_fn(pre_act)
             # This is before adding bias1
-            # gelu_in = F.linear(total_x.reshape(batch_dim, n), weight1)
+            # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1)
             # with torch.jit.fuser('fuser2'):
-            #     output1 = bias_gelu(gelu_in, bias1)
+            #     output1 = bias_gelu(pre_act, bias1)
         else:
-            output1, *rest = fused_dense_cuda.linear_gelu_forward(
-                total_x.reshape(batch_dim, n), weight1, bias1, save_pre_act, heuristic
+            is_gelu = activation == 'gelu_approx'
+            output1, *rest = fused_dense_cuda.linear_act_forward(
+                total_x.reshape(batch_dim, n), weight1, bias1, is_gelu, save_pre_act, heuristic
             )
             if save_pre_act:
-                gelu_in = rest[0]
+                pre_act = rest[0]
         output2 = F.linear(output1, weight2, bias2)
-        if checkpoint_lvl == 0:
-            ctx.save_for_backward(x, weight1, weight2, gelu_in, output1)
+        if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == 'relu'):
+            # For RELU the pre_act is very small (just a bit-mask) so we just save it
+            ctx.save_for_backward(x, weight1, weight2, pre_act, output1)
         elif checkpoint_lvl == 1:
-            ctx.save_for_backward(x, weight1, weight2, gelu_in)
+            ctx.save_for_backward(x, weight1, weight2, pre_act)
         elif checkpoint_lvl == 2:
             ctx.save_for_backward(x, weight1, weight2, bias1)
         output2 = output2.reshape(*batch_shape, output2.shape[-1])

@@ -264,6 +277,9 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
     def backward(ctx, grad_output, *args):
         grad_output = grad_output.contiguous()
         checkpoint_lvl = ctx.checkpoint_lvl
+        activation = ctx.activation
+        activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
+                         else F.relu)
         if ctx.return_residual:
             grad_input, = args
             grad_input = grad_input.contiguous()

@@ -277,27 +293,27 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
         if checkpoint_lvl in [0, 1]:
             if process_group is not None and sequence_parallel:
                 total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
-            if checkpoint_lvl == 0:
-                gelu_in, output1 = rest
+            if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == 'relu'):
+                pre_act, output1 = rest
             elif checkpoint_lvl == 1:
-                gelu_in, = rest
-                output1 = F.gelu(gelu_in, approximate='tanh')
+                pre_act, = rest
+                output1 = activation_fn(pre_act)
         elif checkpoint_lvl == 2:
             bias1, = rest
             if process_group is not None and sequence_parallel:
                 total_x, _ = all_gather_raw(x, process_group)
             if ctx.heuristic == -1:
-                gelu_in = F.linear(total_x, weight1, bias1)
-                output1 = F.gelu(gelu_in, approximate='tanh')
+                pre_act = F.linear(total_x, weight1, bias1)
+                output1 = activation_fn(pre_act)
             else:
-                output1, gelu_in = fused_dense_cuda.linear_gelu_forward(
-                    total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1, True,
-                    ctx.heuristic
+                output1, pre_act = fused_dense_cuda.linear_act_forward(
+                    total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1,
+                    activation == 'gelu_approx', True, ctx.heuristic
                 )

         grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
         output1 = output1.reshape(batch_dim, output1.shape[-1])
-        gelu_in = gelu_in.reshape(batch_dim, gelu_in.shape[-1])
+        pre_act = pre_act.reshape(batch_dim, pre_act.shape[-1])
         if ctx.needs_input_grad[3]:
             grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(
                 output1, grad_output, ctx.needs_input_grad[4]

@@ -306,24 +322,25 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
             grad_weight2 = None
             grad_bias2 = grad_output if ctx.needs_input_grad[4] else None
         if ctx.heuristic == -1:
-            # grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
+            # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act)
             grad_output1 = F.linear(grad_output, weight2.t())
             with torch.jit.fuser('fuser2'):
-                grad_gelu = gelu_bwd(grad_output1, gelu_in)
+                activation_grad_fn = gelu_bwd if activation == 'gelu_approx' else relu_bwd
+                grad_pre_act = activation_grad_fn(grad_output1, pre_act)
         else:
-            # The cublasLt epilogue has to compute both gelu grad and bias grad, we can't
-            # just compute gelu grad
-            grad_gelu, grad_bias1 = fused_dense_cuda.bias_gelu_linear_dgrad_bgrad(
-                weight2, grad_output, gelu_in, ctx.heuristic
+            # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't
+            # just compute gelu/relu grad
+            grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad(
+                weight2, grad_output, pre_act, activation == 'gelu_approx', ctx.heuristic
             )
             if not ctx.needs_input_grad[2]:
                 grad_bias1 = None
         if ctx.needs_input_grad[0]:
             if not ctx.return_residual:
-                grad_input = F.linear(grad_gelu, weight1.t())
+                grad_input = F.linear(grad_pre_act, weight1.t())
             else:
                 grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
-                                         grad_gelu, weight1)
+                                         grad_pre_act, weight1)
             grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
             if process_group is not None:
                 reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw

@@ -335,55 +352,60 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
             if process_group is not None and sequence_parallel:
                 handle_x.wait()
             grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
-                total_x.reshape(batch_dim, total_x.shape[-1]), grad_gelu,
+                total_x.reshape(batch_dim, total_x.shape[-1]), grad_pre_act,
                 ctx.needs_input_grad[2]
             )
         else:
             grad_weight1 = None
-            grad_bias1 = grad_gelu if ctx.needs_input_grad[2] else None
+            grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None
     else:
         if ctx.needs_input_grad[1]:
             if process_group is not None and sequence_parallel:
                 handle_x.wait()
-            grad_weight1 = F.linear(grad_gelu.t(),
+            grad_weight1 = F.linear(grad_pre_act.t(),
                                     total_x.reshape(batch_dim, total_x.shape[-1]).t())
         else:
             grad_weight1 = None
     if process_group is not None and ctx.needs_input_grad[0]:
         handle_grad_input.wait()
     return (grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2,
-            None, None, None, None, None, None)
+            None, None, None, None, None, None, None)


-def fused_dense_gelu_dense_func(
+def fused_mlp_func(
     x: Tensor, weight1: Tensor, weight2: Tensor, bias1: Optional[Tensor] = None,
     bias2: Optional[Tensor] = None,
+    activation: str = 'gelu_approx',
     save_pre_act: bool = True, return_residual: bool = False,
     checkpoint_lvl: int = 0, heuristic: int = 0,
     process_group: Optional[ProcessGroup] = None,
     sequence_parallel: bool = True
 ):
+    assert activation in ['gelu_approx', 'relu']
     dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                       or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
+    # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu)
+    dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == 'relu' else 8) == 0)
     if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
-            and (bias2 is None or bias2.is_cuda) and dtype_eligible):
-        return FusedDenseGeluDenseFunc.apply(
-            x, weight1, bias1, weight2, bias2, save_pre_act, return_residual,
+            and (bias2 is None or bias2.is_cuda) and dtype_eligible and dim_eligible):
+        return FusedMLPFunc.apply(
+            x, weight1, bias1, weight2, bias2, activation, save_pre_act, return_residual,
             checkpoint_lvl, heuristic, process_group, sequence_parallel
         )
     else:
         assert process_group is None
-        gelu_in = F.linear(x, weight1, bias1)
-        output1 = F.gelu(gelu_in, approximate='tanh')
+        pre_act = F.linear(x, weight1, bias1)
+        activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
+                         else partial(F.relu, inplace=True))
+        output1 = activation_fn(pre_act)
         output2 = F.linear(output1, weight2, bias2)
         return output2 if not return_residual else (output2, x)


-class FusedDenseGeluDense(nn.Module):
+class FusedMLP(nn.Module):

     def __init__(self, in_features, hidden_features, out_features=None, bias1=True,
-                 bias2=True, return_residual=False, checkpoint_lvl=0, heuristic=0,
-                 device=None, dtype=None):
+                 bias2=True, activation='gelu_approx', return_residual=False,
+                 checkpoint_lvl=0, heuristic='auto', device=None, dtype=None):
         """
         If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
         we do an all_gather of x before doing the matmul, gelu, then matmul.

@@ -392,21 +414,24 @@ class FusedDenseGeluDense(nn.Module):
         checkpoint_lvl (increasing lvl means slower but more memory saving):
             0: no recomputation in the bwd
             1: recompute gelu_out in the bwd
-            2: recompute gelu_in and gelu_out in the bwd
+            2: recompute pre_act and gelu_out in the bwd
         heuristic:
             -1: don't fuse gemm + gelu (separate kernel)
             0..4: use this heuristic for the algo section in the fused gemm + gelu
-            For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
-            For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
+            'auto': heuristic will be picked automatically:
+                For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
+                For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
         return_residual: whether to return the input x along with the output. This is for
             performance reason: for post-norm architecture, returning the input allows us
             to fuse the backward of nn.Linear with the residual connection.
         """
         assert checkpoint_lvl in [0, 1, 2]
+        assert activation in ['gelu_approx', 'relu']
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         if out_features is None:
             out_features = in_features
+        self.activation = activation
         self.return_residual = return_residual
         self.checkpoint_lvl = checkpoint_lvl
         self.heuristic = heuristic

@@ -414,11 +439,20 @@ class FusedDenseGeluDense(nn.Module):
         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

     def forward(self, x, process_group=None):
-        out = fused_dense_gelu_dense_func(
+        dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
+        if self.heuristic == 'auto':
+            if self.activation == 'gelu_approx':
+                cuda_ver = tuple(map(int, torch.version.cuda.split('.')))
+                heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
+            else:
+                heuristic = 0
+        else:
+            heuristic = self.heuristic
+        out = fused_mlp_func(
             x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
-            save_pre_act=self.training, return_residual=self.return_residual,
-            checkpoint_lvl=self.checkpoint_lvl, heuristic=self.heuristic,
-            process_group=process_group
+            activation=self.activation, save_pre_act=self.training,
+            return_residual=self.return_residual, checkpoint_lvl=self.checkpoint_lvl,
+            heuristic=heuristic, process_group=process_group
         )
         if self.return_residual:
             out, x = out

@@ -427,11 +461,12 @@ class FusedDenseGeluDense(nn.Module):
         return out if not self.return_residual else (out, x)


-class ParallelFusedDenseGeluDense(nn.Module):
+class ParallelFusedMLP(nn.Module):

     def __init__(self, in_features, hidden_features, out_features=None,
+                 activation='gelu_approx',
                  process_group: ProcessGroup = None, bias1=True, bias2=True,
-                 sequence_parallel=True, checkpoint_lvl=0, heuristic=0, device=None, dtype=None):
+                 sequence_parallel=True, checkpoint_lvl=0, heuristic='auto', device=None, dtype=None):
         """
         process_group is required. We're doing Tensor Parallel with sequence parallelism:
         we do an all_gather of x before doing the matmul, gelu, then matmul.

@@ -440,19 +475,22 @@ class ParallelFusedDenseGeluDense(nn.Module):
         checkpoint_lvl (increasing lvl means slower but more memory saving):
             0: no recomputation in the bwd
             1: recompute gelu_out in the bwd
-            2: recompute gelu_in and gelu_out in the bwd
+            2: recompute pre_act and gelu_out in the bwd
         heuristic:
             -1: don't fuse gemm + gelu (separate kernel)
             0..4: use this heuristic for the algo section in the fused gemm + gelu
-            For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
-            For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
+            'auto': heuristic will be picked automatically:
+                For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
+                For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
         """
         assert checkpoint_lvl in [0, 1, 2]
+        assert activation in ['gelu_approx', 'relu']
         assert process_group is not None
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         if out_features is None:
             out_features = in_features
+        self.activation = activation
         self.process_group = process_group
         self.sequence_parallel = sequence_parallel
         self.checkpoint_lvl = checkpoint_lvl

@@ -463,10 +501,19 @@ class ParallelFusedDenseGeluDense(nn.Module):
                                           bias=bias2, **factory_kwargs)

     def forward(self, x):
-        out = fused_dense_gelu_dense_func(
+        dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
+        if self.heuristic == 'auto':
+            if self.activation == 'gelu_approx':
+                cuda_ver = tuple(map(int, torch.version.cuda.split('.')))
+                heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
+            else:
+                heuristic = 0
+        else:
+            heuristic = self.heuristic
+        out = fused_mlp_func(
             x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
-            save_pre_act=self.training, checkpoint_lvl=self.checkpoint_lvl,
-            heuristic=self.heuristic,
-            process_group=self.process_group,
+            activation=self.activation, save_pre_act=self.training,
+            checkpoint_lvl=self.checkpoint_lvl, heuristic=heuristic,
+            process_group=self.process_group,
             sequence_parallel=self.sequence_parallel
         )
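Putting the renamed module-level API together: a minimal usage sketch of FusedMLP with the new activation and heuristic='auto' arguments (argument names follow the __init__ signature in the diff above; running it requires a CUDA device, fp16/bf16 inputs, the fused_dense extension, and, for activation='relu' with the pre-activation saved, an input feature dimension divisible by 128):

import torch
from flash_attn.ops.fused_dense import FusedMLP

mlp = FusedMLP(
    in_features=1024,
    hidden_features=4096,
    activation='relu',     # new in this commit; default is 'gelu_approx'
    checkpoint_lvl=0,
    heuristic='auto',      # new default: picks the cuBLASLt heuristic from CUDA version and dtype
).to(device='cuda', dtype=torch.float16)

x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)
y = mlp(x)                 # shape (8, 512, 1024)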
tests/models/test_bert.py
@@ -95,13 +95,13 @@ def test_bert_optimized(model_name):
     """
     dtype = torch.float16
     config = BertConfig.from_pretrained(model_name)
-    # Our implementation of fused_dense_gelu_dense assumes the activation is
+    # Our implementation of fused_mlp assumes the activation is
     # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
-    # If you just want "gelu", disable fused_dense_gelu_dense.
+    # If you just want "gelu", disable fused_mlp.
     config.hidden_act = "gelu_new"
     config.use_flash_attn = True
     config.fused_bias_fc = True
-    config.fused_dense_gelu_dense = True
+    config.fused_mlp = True
     config.fused_dropout_add_ln = True
     model = BertForPreTraining.from_pretrained(model_name, config)

@@ -171,13 +171,13 @@ def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subs
     """
     dtype = torch.float16
     config = BertConfig.from_pretrained(model_name)
-    # Our implementation of fused_dense_gelu_dense assumes the activation is
+    # Our implementation of fused_mlp assumes the activation is
     # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
-    # If you just want "gelu", disable fused_dense_gelu_dense.
+    # If you just want "gelu", disable fused_mlp.
     config.hidden_act = "gelu_new"
     config.use_flash_attn = True
     config.fused_bias_fc = True
-    config.fused_dense_gelu_dense = True
+    config.fused_mlp = True
     config.fused_dropout_add_ln = True
     config.dense_seq_output = True
     config.last_layer_subset = last_layer_subset
tests/models/test_gpt.py
@@ -82,7 +82,7 @@ def test_gpt2_optimized(model_name):
     vocab_size_og = config.vocab_size
     config.use_flash_attn = True
     config.fused_bias_fc = True
-    config.fused_dense_gelu_dense = True
+    config.fused_mlp = True
     config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True
     config.pad_vocab_size_multiple = 8
tests/models/test_gpt_generation.py
View file @ 88173a1a
...
@@ -18,7 +18,7 @@ from flash_attn.utils.distributed import all_gather_raw
 @pytest.mark.parametrize('fused_ft_kernel', [False, True])
 # @pytest.mark.parametrize('fused_ft_kernel', [True])
 @pytest.mark.parametrize('optimized', [False, True])
-# @pytest.mark.parametrize('optimized', [True])
+# @pytest.mark.parametrize('optimized', [False])
 @pytest.mark.parametrize('rotary', [False, True])
 # @pytest.mark.parametrize('rotary', [False])
 @pytest.mark.parametrize('model_name', ["gpt2"])
...
@@ -34,10 +34,11 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
     if rotary:
         config.n_positions = 0
         config.rotary_emb_dim = 64
+    config.residual_in_fp32 = True
     if optimized:
         config.use_flash_attn = True
         config.fused_bias_fc = True
-        config.fused_dense_gelu_dense = True
+        config.fused_mlp = True
         config.fused_dropout_add_ln = True
     # if not rotary, we load the weight from HF but ignore the position embeddings.
...
@@ -78,6 +79,7 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
                          fused_ft_kernel=fused_ft_kernel,
                          return_dict_in_generate=True, output_scores=True, timing=True)
     print(out.sequences)
+    print(tokenizer.batch_decode(out.sequences.tolist()))
     if fused_ft_kernel:
         out_cg = model.generate(input_ids=input_ids, max_length=max_length,
                                 fused_ft_kernel=fused_ft_kernel, cg=True,
...
@@ -94,122 +96,7 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
         print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
         print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
         print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
+        print(tokenizer.batch_decode(out_ref.sequences.tolist()))
-    assert torch.all(out.sequences == sequences)
-    assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
-                          rtol=rtol, atol=atol)
-    if not rotary:
-        assert torch.all(out.sequences == out_ref.sequences)
-        assert torch.all(out.sequences == out_hf.sequences)
-        assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
 [the remaining removed lines are the entire test_tensor_parallel function, decorators through
  its final assert; it moves, essentially verbatim apart from using fused_mlp instead of
  fused_dense_gelu_dense and adding residual_in_fp32 = True, to the new file
  tests/models/test_gpt_generation_parallel.py shown below]
     assert torch.all(out.sequences == sequences)
     assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
...
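A pattern worth noting in these generation tests: correctness in fp16 is judged relative to an fp32 HF reference, with HF's own fp16 run setting the error scale (the `< 3 *` assert above). A condensed sketch of that comparison, using hypothetical tensors in place of the stacked scores:

# Sketch of the tolerance pattern used above: our fp16 scores may differ from the fp32
# reference by at most a small multiple of HF's own fp16-vs-fp32 error.
# 'ours', 'hf_fp16', 'hf_fp32' are hypothetical stand-ins for torch.stack(*.scores, 1).
import torch

hf_fp32 = torch.randn(1, 23, 50257)
hf_fp16 = hf_fp32 + 1e-3 * torch.randn_like(hf_fp32)   # stand-in for fp16 rounding error
ours = hf_fp32 + 1e-3 * torch.randn_like(hf_fp32)      # stand-in for our fp16 implementation

our_err = (ours - hf_fp32).abs().max().item()
hf_err = (hf_fp16 - hf_fp32).abs().max().item()
assert our_err < 3 * hf_err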
tests/models/test_gpt_generation_parallel.py  0 → 100644
View file @ 88173a1a
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_generation_parallel.py -k "parallel"
import os
import re

import torch
import pytest

from einops import rearrange

from transformers import GPT2Config, GPT2Tokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt import remap_state_dict_gpt2
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.distributed import all_gather_raw


# @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize('world_size', [2])
# @pytest.mark.parametrize('fused_ft_kernel', [False, True])
@pytest.mark.parametrize('fused_ft_kernel', [True])
# @pytest.mark.parametrize('rotary', [False, True])
@pytest.mark.parametrize('rotary', [False])
@pytest.mark.parametrize('model_name', ["gpt2"])
def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
    """Check that our implementation of GPT2 generation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
    """
    dtype = torch.float16
    rtol, atol = 3e-3, 3e-1
    config = GPT2Config.from_pretrained(model_name)
    if rotary:
        config.n_positions = 0
        config.rotary_emb_dim = 64
    config.residual_in_fp32 = True
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    config.pad_vocab_size_multiple = 8 * world_size
    config.sequence_parallel = False  # Need to set this to False for generation

    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    device = f'cuda:{torch.distributed.get_rank()}'
    assert world_size <= torch.distributed.get_world_size()
    # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
    # GPU0 and GPU1 and things would hang
    torch.cuda.set_device(device)
    from apex.transformer import parallel_state
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    # if not rotary, we load the weight from HF but ignore the position embeddings.
    # The model would be nonsense but it doesn't matter for the test.
    model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device,
                                           dtype=dtype, process_group=process_group,
                                           world_size=world_size, rank=rank)
    model.eval()

    if not rotary:
        model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
        model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype)
        model_ref.eval()
        model_hf.eval()

    torch.manual_seed(0)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    input_ids = tokenizer("Hello, my dog is cute and ",
                          return_tensors="pt").input_ids.to(device=device)
    max_length = 30
    # input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
    # max_length = input_ids.shape[1] + 40

    # Slow generation for reference
    sequences = []
    scores = []
    cur_input_ids = input_ids
    with torch.inference_mode():
        logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
        logits = rearrange(logits, '(n b) d -> b (n d)',
                           b=input_ids.shape[0])[..., :config.vocab_size]
        scores.append(logits)
        sequences.append(scores[-1].argmax(dim=-1))
        for _ in range(input_ids.shape[1] + 1, max_length):
            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
            logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
            logits = rearrange(logits, '(n b) d -> b (n d)',
                               b=input_ids.shape[0])[..., :config.vocab_size]
            scores.append(logits)
            sequences.append(scores[-1].argmax(dim=-1))
    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
    scores = tuple(scores)
    print(sequences)

    out = model.generate(input_ids=input_ids, max_length=max_length,
                         tensor_parallel=world_size, vocab_size=config.vocab_size,
                         fused_ft_kernel=fused_ft_kernel,
                         return_dict_in_generate=True, output_scores=True, timing=True)
    print(out.sequences)
    if fused_ft_kernel:
        out_cg = model.generate(input_ids=input_ids, max_length=max_length,
                                tensor_parallel=world_size, vocab_size=config.vocab_size,
                                fused_ft_kernel=fused_ft_kernel, cg=True,
                                return_dict_in_generate=True, output_scores=True, timing=True)
        print(out_cg.sequences)

    if not rotary:
        out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
                                   return_dict_in_generate=True, output_scores=True)
        out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
                                     return_dict_in_generate=True, output_scores=True)

        print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
        print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
        print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
        print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')

    assert torch.all(out.sequences == sequences)
    assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
                          rtol=rtol, atol=atol)
    if not rotary:
        assert torch.all(out.sequences == out_ref.sequences)
        assert torch.all(out.sequences == out_hf.sequences)
        assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
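The reference loop in this new test gathers vocab-sharded logits from all tensor-parallel ranks and reassembles them before taking the argmax. As a standalone illustration (with a made-up shard count, not tied to this test), the '(n b) d -> b (n d)' rearrange works like this:

# Standalone illustration of the '(n b) d -> b (n d)' reassembly used above:
# each of n ranks holds a (b, d) shard of the vocab dimension; after an all-gather
# they are concatenated along the batch axis and rearranged back into (b, n*d).
import torch
from einops import rearrange

n, b, d = 2, 1, 4                                               # hypothetical: 2 ranks, batch 1, 4 vocab entries per shard
shards = [torch.arange(d) + r * d for r in range(n)]            # what each rank holds, shape (d,)
gathered = torch.cat([s.expand(b, d) for s in shards], dim=0)   # all-gather concatenates on dim 0 -> (n*b, d)
full = rearrange(gathered, '(n b) d -> b (n d)', b=b)           # back to (b, n*d) in vocab order
print(full)  # tensor([[0, 1, 2, 3, 4, 5, 6, 7]])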
tests/models/test_gpt_parallel.py
View file @ 88173a1a
 # Run test with:
 # torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_parallel.py
+import math
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
...
@@ -59,10 +61,12 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
                        n_positions=seqlen if has_pos_emb else 0,
                        vocab_size=50257, resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0,
                        scale_attn_by_inverse_layer_idx=True, use_flash_attn=True,
-                       fused_dense_gelu_dense=True, fused_bias_fc=True, fused_dropout_add_ln=True,
+                       fused_mlp=True, fused_bias_fc=True, fused_dropout_add_ln=True,
+                       residual_in_fp32=True,
                        rotary_emb_fraction=0.0 if has_pos_emb else 0.5,
                        pad_vocab_size_multiple=8 * world_size,
                        sequence_parallel=sequence_parallel)
+    config.vocab_size = math.ceil(config.vocab_size / (8 * world_size)) * (8 * world_size)
     model_pt = GPTLMHeadModel(config, device=device)

     def init_layer_norm(module):
...
@@ -131,9 +135,9 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
         grad_dict['transformer.embeddings.position_embeddings.weight'],
         rtol=rtol, atol=atol
     )
-    assert torch.allclose(model.transformer.ln_0.weight.grad, grad_dict['transformer.ln_0.weight'],
+    assert torch.allclose(model.transformer.ln_f.weight.grad, grad_dict['transformer.ln_f.weight'],
                           rtol=rtol, atol=atol)
-    assert torch.allclose(model.transformer.ln_0.bias.grad, grad_dict['transformer.ln_0.bias'],
+    assert torch.allclose(model.transformer.ln_f.bias.grad, grad_dict['transformer.ln_f.bias'],
                           rtol=rtol, atol=atol)
     for i in range(num_layers):
         assert torch.allclose(
...
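The test pads the vocabulary so that it divides evenly across 8 * world_size shards. A quick worked example of the math.ceil rounding added above (values are illustrative):

# Illustration of rounding the GPT-2 vocab up to a multiple of 8 * world_size,
# as done in the test above (50257 -> 50272 for world_size=2, i.e. multiple 16).
import math

vocab_size, world_size = 50257, 2
multiple = 8 * world_size
padded = math.ceil(vocab_size / multiple) * multiple
print(padded)  # 50272
assert padded % multiple == 0 and padded >= vocab_size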
tests/models/test_vit.py
View file @ 88173a1a
...
@@ -8,11 +8,11 @@ from timm.models.vision_transformer import vit_base_patch16_224
 from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224

-@pytest.mark.parametrize('fused_dense_gelu_dense', [False, True])
+@pytest.mark.parametrize('fused_mlp', [False, True])
-# @pytest.mark.parametrize('fused_dense_gelu_dense', [False])
+# @pytest.mark.parametrize('fused_mlp', [False])
 @pytest.mark.parametrize('optimized', [False, True])
 # @pytest.mark.parametrize('optimized', [True])
-def test_vit(optimized, fused_dense_gelu_dense):
+def test_vit(optimized, fused_mlp):
     """Check that our implementation of ViT matches the timm's implementation:
     the output of our forward pass in fp16 should be around the same as
     timm' forward pass in fp16, when compared to timm's forward pass in fp32.
...
@@ -23,7 +23,7 @@ def test_vit(optimized, fused_dense_gelu_dense):
     kwargs = {}
     if optimized:
         kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True)
-    kwargs['fused_dense_gelu_dense'] = fused_dense_gelu_dense
+    kwargs['fused_mlp'] = fused_mlp
     model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype)
     model_ref = vit_base_patch16_224(pretrained=True).to(device=device)
...
@@ -46,4 +46,5 @@ def test_vit(optimized, fused_dense_gelu_dense):
     print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
     print(f'timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}')
     print(f'timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}')
-    assert (out - out_ref).abs().max().item() < 3 * (out_timm - out_ref).abs().max().item()
+    rtol = 2 if not fused_mlp else 4
+    assert (out - out_ref).abs().max().item() < rtol * (out_timm - out_ref).abs().max().item()
tests/modules/test_block_parallel.py
View file @ 88173a1a
...
@@ -15,7 +15,7 @@ from apex.transformer import parallel_state
 from apex.transformer import tensor_parallel

 from flash_attn.modules.mha import MHA, ParallelMHA
-from flash_attn.modules.mlp import FusedDenseGeluDense, ParallelFusedDenseGeluDense
+from flash_attn.modules.mlp import FusedMLP, ParallelFusedMLP
 from flash_attn.modules.block import Block
 from flash_attn.utils.distributed import allreduce_sequence_parallel_grad
...
@@ -27,7 +27,7 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 # @pytest.mark.parametrize('world_size', [2])
 @pytest.mark.parametrize('sequence_parallel', [True, False])
-# @pytest.mark.parametrize('sequence_parallel', [False])
+# @pytest.mark.parametrize('sequence_parallel', [True])
 @pytest.mark.parametrize('dim', [1024])
 def test_block_parallel(dim, sequence_parallel, world_size, dtype):
     head_dim = 64
...
@@ -62,8 +62,7 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
     mixer_cls_pt = partial(MHA, num_heads=num_heads, rotary_emb_dim=int(head_dim // 2),
                            use_flash_attn=True, device=device, dtype=dtype)
-    mlp_cls_pt = partial(FusedDenseGeluDense, hidden_features=4 * dim,
-                         device=device, dtype=dtype)
+    mlp_cls_pt = partial(FusedMLP, hidden_features=4 * dim,
+                         device=device, dtype=dtype)
     norm_cls = partial(nn.LayerNorm, device=device, dtype=dtype)
     model_pt = Block(dim, mixer_cls_pt, mlp_cls_pt, norm_cls, fused_dropout_add_ln=True)
     with torch.no_grad():
...
@@ -76,7 +75,7 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
                         process_group=parallel_state.get_tensor_model_parallel_group(),
                         rotary_emb_dim=int(head_dim // 2), use_flash_attn=True,
                         sequence_parallel=sequence_parallel, device=device, dtype=dtype)
-    mlp_cls = partial(ParallelFusedDenseGeluDense, hidden_features=4 * dim,
+    mlp_cls = partial(ParallelFusedMLP, hidden_features=4 * dim,
                       process_group=parallel_state.get_tensor_model_parallel_group(),
                       sequence_parallel=sequence_parallel, device=device, dtype=dtype)
     model = Block(dim, mixer_cls, mlp_cls, norm_cls, fused_dropout_add_ln=True,
...
@@ -143,7 +142,7 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
         x.grad,
         x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
         if sequence_parallel else x_pt.grad,
-        rtol=rtol, atol=atol / 100  # magnitude of x.grad is quite small
+        rtol=rtol, atol=atol / 10  # magnitude of x.grad is quite small
     )
     assert torch.allclose(
         residual.grad,
...
tests/ops/test_fused_dense.py
View file @ 88173a1a
 import math
+from functools import partial
+
 import torch
 import torch.nn.functional as F
...
@@ -6,7 +7,7 @@ import pytest
 from einops import rearrange

-from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense
+from flash_attn.ops.fused_dense import FusedDense, FusedMLP


 @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
...
@@ -60,15 +61,25 @@ def test_fused_linear_bias(in_features, out_features, has_bias, return_residual,
 @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
+# @pytest.mark.parametrize('dtype', [torch.float16])
-@pytest.mark.parametrize('heuristic', [0, -1])
+@pytest.mark.parametrize('heuristic', ['auto', -1])
+# @pytest.mark.parametrize('heuristic', ['auto'])
 @pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
+# @pytest.mark.parametrize('checkpoint_lvl', [1])
 @pytest.mark.parametrize('return_residual', [False, True])
+# @pytest.mark.parametrize('return_residual', [False])
 @pytest.mark.parametrize('has_bias2', [True, False])
 @pytest.mark.parametrize('has_bias1', [True, False])
+# @pytest.mark.parametrize('has_bias2', [True])
+# @pytest.mark.parametrize('has_bias1', [True])
+@pytest.mark.parametrize('activation', ['gelu_approx', 'relu'])
+# @pytest.mark.parametrize('activation', ['relu'])
 @pytest.mark.parametrize('out_features', [1024, 4096])
+# @pytest.mark.parametrize('out_features', [4096])
 @pytest.mark.parametrize('in_features', [1024, 4096])
+# @pytest.mark.parametrize('in_features', [1024])
-def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2, return_residual,
-                                checkpoint_lvl, heuristic, dtype):
+def test_fused_mlp(in_features, out_features, activation, has_bias1, has_bias2, return_residual,
+                   checkpoint_lvl, heuristic, dtype):
     device = 'cuda'
     rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
     # set seed
...
@@ -82,10 +93,10 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2,
                                     dtype=dtype)
     model_pt_fc2 = torch.nn.Linear(out_features, in_features, bias=has_bias2, device=device,
                                    dtype=dtype)
-    model = FusedDenseGeluDense(in_features, out_features, in_features, bias1=has_bias1,
-                                bias2=has_bias2, return_residual=return_residual,
+    model = FusedMLP(in_features, out_features, in_features, activation=activation,
+                     bias1=has_bias1, bias2=has_bias2, return_residual=return_residual,
                      checkpoint_lvl=checkpoint_lvl, heuristic=heuristic,
                      device=device, dtype=dtype)
     with torch.no_grad():
         model.fc1.weight.copy_(model_pt_fc1.weight)
         if has_bias1:
...
@@ -93,7 +104,9 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2,
         model.fc2.weight.copy_(model_pt_fc2.weight)
         if has_bias2:
             model.fc2.bias.copy_(model_pt_fc2.bias)
-    out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh'))
+    activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
+                     else partial(F.relu, inplace=True))
+    out_pt = model_pt_fc2(activation_fn(model_pt_fc1(x_pt)))
     if not return_residual:
         out = model(x)
     else:
...
@@ -107,6 +120,9 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2,
     g = torch.randn_like(out) / 32
     out_pt.backward(g)
     out.backward(g)
+    # The error for relu is higher still
+    if activation == 'relu':
+        atol = 1e-1 if dtype == torch.bfloat16 else 5e-2
     assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
     # The error for d_weight and d_bias is quite a bit higher
     assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10)
...
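For orientation, here is a minimal usage sketch of the renamed module, mirroring the constructor arguments this test exercises; the values are illustrative and the test itself is the authoritative reference, not this sketch.

# Minimal FusedMLP usage sketch, mirroring the arguments exercised by the test above.
# Requires a CUDA device with the fused_dense_lib extension built; values are illustrative.
import torch
from flash_attn.ops.fused_dense import FusedMLP

mlp = FusedMLP(1024, 4096, 1024, activation='relu',
               checkpoint_lvl=0, heuristic='auto',
               device='cuda', dtype=torch.float16)
x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)
y = mlp(x)  # same shape as x: fc2(relu(fc1(x)))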
tests/ops/test_fused_dense_parallel.py
View file @ 88173a1a
...
@@ -10,8 +10,8 @@ import pytest
 from apex.transformer import parallel_state
 from apex.transformer import tensor_parallel

-from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense
-from flash_attn.ops.fused_dense import ColumnParallelLinear, ParallelFusedDenseGeluDense
+from flash_attn.ops.fused_dense import FusedDense, FusedMLP
+from flash_attn.ops.fused_dense import ColumnParallelLinear, ParallelFusedMLP

 is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
...
@@ -106,8 +106,7 @@ def test_fused_linear_bias(in_features, out_features, has_bias, sequence_paralle
 # @pytest.mark.parametrize('has_bias2', [True])
 @pytest.mark.parametrize('out_features', [4096])
 @pytest.mark.parametrize('in_features', [1024])
-def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, sequence_parallel,
-                                world_size, dtype):
+def test_fused_mlp(in_features, out_features, has_bias2, sequence_parallel,
+                   world_size, dtype):
     assert out_features % world_size == 0
     rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
     if not torch.distributed.is_initialized():
...
@@ -137,11 +136,11 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, sequence_p
                                    dtype=dtype)
     partition_out_features = out_features // world_size
     partition_in_features = in_features // world_size
-    model = ParallelFusedDenseGeluDense(in_features, out_features, in_features,
+    model = ParallelFusedMLP(in_features, out_features, in_features,
                              process_group=parallel_state.get_tensor_model_parallel_group(),
                              bias2=has_bias2 and rank == 0,
                              sequence_parallel=sequence_parallel,
                              device=device, dtype=dtype)
     with torch.no_grad():
         model.fc1.weight.copy_(
...
training/README.md
View file @ 88173a1a
...
@@ -48,7 +48,7 @@ config = GPT2Config(vocab_size=50257, n_positions=seqlen, n_embd=hidden_dim,
                     n_layer=n_layer, n_head=nheads,
                     scale_attn_by_inverse_layer_idx=True,
                     rotary_emb_fraction=rotary_emb_fraction,
-                    use_flash_attn=True, fused_dense_gelu_dense=True,
+                    use_flash_attn=True, fused_mlp=True,
                     fused_bias_fc=True, fused_dropout_add_ln=True,
                     pad_vocab_size_multiple=8)
 model = GPTLMHeadModel(config)
...
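The README snippet stops at constructing the model. As a hedged follow-up sketch (the .logits attribute and shapes follow the tests above; device, dtype, and sizes here are illustrative, not prescribed by the README), a forward pass looks like:

# Follow-up sketch to the README config above: one forward pass of the optimized GPT model.
# Assumes a CUDA device with the fused kernels built; 'config' is the GPT2Config from the snippet.
import torch
from flash_attn.models.gpt import GPTLMHeadModel

model = GPTLMHeadModel(config).to(device='cuda', dtype=torch.float16)
input_ids = torch.randint(0, config.vocab_size, (2, 512), device='cuda')
logits = model(input_ids).logits  # (batch, seqlen, padded vocab size)
print(logits.shape)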
training/configs/experiment/owt/gpt2s-flash.yaml
View file @ 88173a1a
...
@@ -7,9 +7,10 @@ defaults:
 model:
   config:
     # n_positions is already set to ${datamodule.max_length}
+    residual_in_fp32: True
     use_flash_attn: True
     fused_bias_fc: True
-    fused_dense_gelu_dense: True
+    fused_mlp: True
     fused_dropout_add_ln: True
     pad_vocab_size_multiple: 8
...
training/configs/experiment/pile/gpt3s-flash.yaml
View file @ 88173a1a
...
@@ -7,9 +7,10 @@ defaults:
 model:
   config:
     # n_positions is already set to ${datamodule.max_length}
+    residual_in_fp32: True
     use_flash_attn: True
     fused_dropout_add_ln: True
-    fused_dense_gelu_dense: True
+    fused_mlp: True
     fused_bias_fc: True
     pad_vocab_size_multiple: 8
...