gaoqiong / flash-attention / Commits / e68ebbe8

Commit e68ebbe8, authored Dec 22, 2022 by Tri Dao

Simplify FusedDense

parent 1bc6e5b0

Showing 9 changed files with 318 additions and 1063 deletions (+318 -1063)
csrc/fused_dense_lib/fused_dense.cpp       +72   -243
csrc/fused_dense_lib/fused_dense_cuda.cu   +18   -431
flash_attn/layers/patch_embed.py           +4    -4
flash_attn/models/bert.py                  +10   -8
flash_attn/models/gpt.py                   +2    -0
flash_attn/modules/mha.py                  +7    -6
flash_attn/modules/mlp.py                  +2    -44
flash_attn/ops/fused_dense.py              +146  -233
tests/ops/test_fused_dense.py              +57   -94
csrc/fused_dense_lib/fused_dense.cpp  (View file @ e68ebbe8)

...
@@ -6,6 +6,8 @@
#include <stdio.h>

#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...) \
...
@@ -24,14 +26,6 @@
    AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }

#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

template <typename T>
int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);

template <typename T>
int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input, bool residual, void *lt_workspace);

template <typename T>
int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace);
...
@@ -39,103 +33,34 @@ template <typename T>
int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int batch_size, int out_features, int heuristic, T *output, T *gelu_in, void *lt_workspace);

template <typename T>
int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, bool residual, void *lt_workspace);
at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor bias) {
    auto batch_size = input.size(0);
    auto in_features = input.size(1);
    int out_features = weight.size(0);
    //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
    // create output/workspace tensor
    auto out = at::empty({batch_size, out_features}, at::dtype(input.dtype()).device(input.device()));
    //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
    // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
    auto lt_workspace = at::empty({1 << 22}, at::dtype(input.dtype()).device(input.device()));
    DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_forward", [&] {
        scalar_t* w_ptr = weight.data_ptr<scalar_t>();
        auto result = linear_bias_forward_cuda<scalar_t>(
            input, w_ptr, bias, in_features, batch_size, out_features, out,
            //out.data_ptr<scalar_t>(),
            // reserved_space.data_ptr<scalar_t>(),
            (void*) (lt_workspace.data_ptr<scalar_t>()));
        TORCH_CHECK(result == 0, "linear_bias_forward failed.")
    });
    return {out};
}
std::vector<at::Tensor> linear_bias_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output) {
    auto batch_size = input.size(0);
    auto in_features = input.size(1);
    int out_features = weight.size(0);
int bias_gelu_linear_dgrad_bgrad_cuda(T *weight, T *d_output, T *gelu_in, int in_features, int batch_size, int out_features, int heuristic, T *d_input, T *d_bias, void *lt_workspace);
    //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
    // create output/workspace tensor
    auto opts = input.options();
    auto d_weight = at::empty({out_features, in_features}, opts);
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
    auto d_bias = d_output.view({-1, out_features}).sum(0, false);
#else
    auto d_bias = at::empty({out_features}, opts);
#endif
    auto d_input = at::empty({batch_size, in_features}, opts);
    //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
    // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
    auto lt_workspace = at::empty({1 << 22}, opts);
    DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
        scalar_t* w_ptr = weight.data_ptr<scalar_t>();
        auto result = linear_bias_backward_cuda<scalar_t>(
            input.data_ptr<scalar_t>(), w_ptr, d_output.data_ptr<scalar_t>(),
            in_features, batch_size, out_features,
            d_weight.data_ptr<scalar_t>(), d_bias.data_ptr<scalar_t>(), d_input.data_ptr<scalar_t>(),
            // reserved_space.data_ptr<scalar_t>(),
            /*residual=*/false,
            (void*) (lt_workspace.data_ptr<scalar_t>()));
        TORCH_CHECK(result == 0, "linear_bias_backward failed.")
    });
    return {d_input, d_weight, d_bias};
}
std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output) {
    auto batch_size = input.size(0);
    auto in_features = input.size(1);
std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output, bool has_d_bias) {
    int batch_size = input.size(0);
    int in_features = input.size(1);
    int out_features = d_output.size(1);
    //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
    TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
    TORCH_CHECK(input.dtype() == d_output.dtype());
    TORCH_CHECK(input.is_cuda());
    TORCH_CHECK(d_output.is_cuda());
    TORCH_CHECK(input.is_contiguous());
    TORCH_CHECK(d_output.is_contiguous());
    CHECK_SHAPE(input, batch_size, in_features);
    CHECK_SHAPE(d_output, batch_size, out_features);
    // create output/workspace tensor
    auto opts = input.options();
    auto d_weight = at::empty({out_features, in_features}, opts);
    at::Tensor d_bias;
    if (has_d_bias) {
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
    auto d_bias = d_output.view({-1, out_features}).sum(0, false);
        d_bias = d_output.view({-1, out_features}).sum(0, false);
#else
    auto d_bias = at::empty({out_features}, opts);
        d_bias = at::empty({out_features}, opts);
#endif
        //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
    }
    // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
    auto lt_workspace = at::empty({1 << 22}, opts);
...
@@ -147,93 +72,59 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output)
            batch_size, out_features,
            d_weight.data_ptr<scalar_t>(),
            d_bias.data_ptr<scalar_t>(),
            // reserved_space.data_ptr<scalar_t>(),
            has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr,
            (void*) (lt_workspace.data_ptr<scalar_t>()));
        TORCH_CHECK(result == 0, "linear_bias_wgrad failed.")
        TORCH_CHECK(result == 0, "linear_bias_wgrad failed.");
    });
    return {d_weight, d_bias};
}
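
For orientation, the following is a rough PyTorch-level sketch (not part of the commit) of what linear_bias_wgrad computes: the weight gradient of a bias-adding linear layer, with the bias gradient made optional by the new has_d_bias flag. The helper name linear_bias_wgrad_reference is ours, for illustration only.

import torch

def linear_bias_wgrad_reference(input: torch.Tensor, d_output: torch.Tensor, has_d_bias: bool):
    # input: (batch, in_features), d_output: (batch, out_features)
    d_weight = d_output.t() @ input                        # (out_features, in_features)
    d_bias = d_output.sum(dim=0) if has_d_bias else None   # (out_features,) or skipped
    return d_weight, d_bias
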
std::vector<at::Tensor> linear_bias_residual_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output, at::Tensor d_input) {
    auto batch_size = input.size(0);
    auto in_features = input.size(1);
    int out_features = weight.size(0);
    //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
    // create output/workspace tensor
    auto opts = input.options();
    auto d_weight = at::empty({out_features, in_features}, opts);
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
    auto d_bias = d_output.view({-1, out_features}).sum(0, false);
#else
    auto d_bias = at::empty({out_features}, opts);
#endif
    CHECK_SHAPE(d_input, batch_size, in_features);
    //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
    // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
    auto lt_workspace = at::empty({1 << 22}, opts);
    DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
        scalar_t* w_ptr = weight.data_ptr<scalar_t>();
        auto result = linear_bias_backward_cuda<scalar_t>(
            input.data_ptr<scalar_t>(), w_ptr, d_output.data_ptr<scalar_t>(),
            in_features, batch_size, out_features,
            d_weight.data_ptr<scalar_t>(), d_bias.data_ptr<scalar_t>(), d_input.data_ptr<scalar_t>(),
            // reserved_space.data_ptr<scalar_t>(),
            /*residual=*/true,
            (void*) (lt_workspace.data_ptr<scalar_t>()));
        TORCH_CHECK(result == 0, "linear_bias_residual_backward failed.")
    });
    return {d_input, d_weight, d_bias};
}
std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight, at::Tensor bias,
std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight, c10::optional<at::Tensor> bias_,
                                            bool save_gelu_in, int heuristic) {
    auto batch_size = input.size(0);
    auto in_features = input.size(1);
    int batch_size = input.size(0);
    int in_features = input.size(1);
    int out_features = weight.size(0);
    //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
    TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
    TORCH_CHECK(input.dtype() == weight.dtype());
    TORCH_CHECK(input.is_cuda());
    TORCH_CHECK(weight.is_cuda());
    TORCH_CHECK(input.is_contiguous());
    TORCH_CHECK(weight.is_contiguous());
    CHECK_SHAPE(input, batch_size, in_features);
    CHECK_SHAPE(weight, out_features, in_features);
    if (bias_.has_value()) {
        auto bias = bias_.value();
        TORCH_CHECK(bias.dtype() == input.dtype());
        TORCH_CHECK(bias.is_cuda());
        TORCH_CHECK(bias.is_contiguous());
        CHECK_SHAPE(bias, out_features);
    }
    // create output/workspace tensor
    auto opts = input.options();
    auto output = at::empty({batch_size, out_features}, opts);
    at::Tensor gelu_in;
    if (save_gelu_in) { gelu_in = at::empty({batch_size, out_features}, opts); }
    //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
    // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
    auto lt_workspace = at::empty({1 << 22}, opts);
    DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_gelu_forward", [&] {
        scalar_t* w_ptr = weight.data_ptr<scalar_t>();
        scalar_t* b_ptr = bias.data_ptr<scalar_t>();
        auto result = linear_gelu_forward_cuda<scalar_t>(
            input.data_ptr<scalar_t>(),
            w_ptr, b_ptr,
            weight.data_ptr<scalar_t>(),
            bias_.has_value() ? bias_.value().data_ptr<scalar_t>() : nullptr,
            in_features, batch_size, out_features, heuristic,
            output.data_ptr<scalar_t>(),
            save_gelu_in ? gelu_in.data_ptr<scalar_t>() : nullptr,
            // reserved_space.data_ptr<scalar_t>(),
            (void*) (lt_workspace.data_ptr<scalar_t>()));
        TORCH_CHECK(result == 0, "linear_gelu_forward failed.")
        TORCH_CHECK(result == 0, "linear_gelu_forward failed.");
    });
    std::vector<at::Tensor> result = {output};
...
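
As a hedged reference for linear_gelu_forward above: outside the extension, the same computation is just a linear layer followed by tanh-approximate GELU, with the pre-activation optionally kept for the backward (this mirrors the heuristic == -1 fallback used later in flash_attn/ops/fused_dense.py). The helper name is ours.

import torch
import torch.nn.functional as F

def linear_gelu_forward_reference(x, weight, bias=None, save_gelu_in=True):
    # x: (batch, in_features), weight: (out_features, in_features), bias: (out_features,) or None
    gelu_in = F.linear(x, weight, bias)              # pre-GELU activation
    output = F.gelu(gelu_in, approximate='tanh')
    return (output, gelu_in) if save_gelu_in else (output,)
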
@@ -241,116 +132,54 @@ std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
    return result;
}

std::vector<at::Tensor> linear_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2, int heuristic) {
std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(at::Tensor weight, at::Tensor d_output, at::Tensor gelu_in, int heuristic) {
    auto batch_size = input.size(0);
    auto in_features = input.size(1);
    int hidden_features = weight1.size(0);
    int out_features = weight2.size(0);
    //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
    int batch_size = d_output.size(0);
    int out_features = d_output.size(1);
    int in_features = weight.size(1);
    TORCH_CHECK(weight.dtype() == torch::kFloat16 || weight.dtype() == torch::kBFloat16);
    TORCH_CHECK(weight.dtype() == d_output.dtype());
    TORCH_CHECK(weight.dtype() == gelu_in.dtype());
    TORCH_CHECK(weight.is_cuda());
    TORCH_CHECK(d_output.is_cuda());
    TORCH_CHECK(gelu_in.is_cuda());
    TORCH_CHECK(weight.is_contiguous());
    TORCH_CHECK(d_output.is_contiguous());
    TORCH_CHECK(gelu_in.is_contiguous());
    CHECK_SHAPE(weight, out_features, in_features);
    CHECK_SHAPE(d_output, batch_size, out_features);
    CHECK_SHAPE(gelu_in, batch_size, in_features);
    // create output/workspace tensor
    auto opts = input.options();
    auto d_weight1 = at::empty({hidden_features, in_features}, opts);
    auto d_weight2 = at::empty({out_features, hidden_features}, opts);
    auto d_bias1 = at::empty({hidden_features}, opts);
    auto d_bias2 = at::empty({out_features}, opts);
    auto opts = weight.options();
    auto d_bias = at::empty({in_features}, opts);
    auto d_input = at::empty({batch_size, in_features}, opts);
    auto d_output1 = at::empty({batch_size, hidden_features}, opts);
    //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
    // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
    auto lt_workspace = at::empty({1 << 22}, opts);
    DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
        //scalar_t* w_ptr = weight.data_ptr<scalar_t>();
        //scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
        auto result = linear_gelu_linear_backward_cuda<scalar_t>(
            input.data_ptr<scalar_t>(),
            gelu_in.data_ptr<scalar_t>(),
            output1.data_ptr<scalar_t>(),
            weight1.data_ptr<scalar_t>(),
            weight2.data_ptr<scalar_t>(),
            d_output1.data_ptr<scalar_t>(),
            d_output2.data_ptr<scalar_t>(),
            in_features, batch_size, hidden_features, out_features, heuristic,
            d_weight1.data_ptr<scalar_t>(),
            d_weight2.data_ptr<scalar_t>(),
            d_bias1.data_ptr<scalar_t>(),
            d_bias2.data_ptr<scalar_t>(),
            d_input.data_ptr<scalar_t>(),
            // reserved_space.data_ptr<scalar_t>(),
            /*residual=*/false,
            (void*) (lt_workspace.data_ptr<scalar_t>()));
        TORCH_CHECK(result == 0, "linear_gelu_linear_backward failed.")
    });
    return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
}
std::vector<at::Tensor> linear_residual_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2, at::Tensor d_input, int heuristic) {
    auto batch_size = input.size(0);
    auto in_features = input.size(1);
    int hidden_features = weight1.size(0);
    int out_features = weight2.size(0);
    //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
    // create output/workspace tensor
    auto opts = input.options();
    auto d_weight1 = at::empty({hidden_features, in_features}, opts);
    auto d_weight2 = at::empty({out_features, hidden_features}, opts);
    auto d_bias1 = at::empty({hidden_features}, opts);
    auto d_bias2 = at::empty({out_features}, opts);
    CHECK_SHAPE(d_input, batch_size, in_features);
    auto d_output1 = at::empty({batch_size, hidden_features}, opts);
    //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
    // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
    auto lt_workspace = at::empty({1 << 22}, opts);
    DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
        //scalar_t* w_ptr = weight.data_ptr<scalar_t>();
        //scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
        auto result = linear_gelu_linear_backward_cuda<scalar_t>(
            input.data_ptr<scalar_t>(),
    DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_gelu_linear_dgrad_bgrad", [&] {
        auto result = bias_gelu_linear_dgrad_bgrad_cuda<scalar_t>(
            weight.data_ptr<scalar_t>(),
            d_output.data_ptr<scalar_t>(),
            gelu_in.data_ptr<scalar_t>(),
            output1.data_ptr<scalar_t>(),
            weight1.data_ptr<scalar_t>(),
            weight2.data_ptr<scalar_t>(),
            d_output1.data_ptr<scalar_t>(),
            d_output2.data_ptr<scalar_t>(),
            in_features, batch_size, hidden_features, out_features, heuristic,
            d_weight1.data_ptr<scalar_t>(),
            d_weight2.data_ptr<scalar_t>(),
            d_bias1.data_ptr<scalar_t>(),
            d_bias2.data_ptr<scalar_t>(),
            d_input.data_ptr<scalar_t>(),
            // reserved_space.data_ptr<scalar_t>(),
            /*residual=*/true,
            d_bias.data_ptr<scalar_t>(),
            (void*) (lt_workspace.data_ptr<scalar_t>()));
        TORCH_CHECK(result == 0, "linear_residual_gelu_linear_backward failed.")
        TORCH_CHECK(result == 0, "bias_gelu_linear_dgrad_bgrad failed.");
    });
    return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
    return {d_input, d_bias};
}
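
A hedged reading of what the new bias_gelu_linear_dgrad_bgrad returns, sketched with autograd rather than the fused kernels: given the second layer's weight, the incoming gradient, and the saved pre-GELU activation, it produces the gradient with respect to that pre-activation plus its sum over the batch (the first layer's bias gradient). The helper name is ours and this is a reference sketch, not the extension itself.

import torch
import torch.nn.functional as F

def bias_gelu_linear_dgrad_bgrad_reference(weight, d_output, gelu_in):
    # weight: (out_features, in_features), d_output: (batch, out_features), gelu_in: (batch, in_features)
    gelu_in = gelu_in.detach().requires_grad_()
    out = F.linear(F.gelu(gelu_in, approximate='tanh'), weight)   # second GEMM of the MLP
    d_input, = torch.autograd.grad(out, gelu_in, grad_outputs=d_output)
    d_bias = d_input.sum(dim=0)                                   # bias gradient of the first linear
    return d_input, d_bias
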
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("linear_bias_forward", &linear_bias_forward, "linear bias forward");
  m.def("linear_bias_backward", &linear_bias_backward, "linear bias backward");
  m.def("linear_bias_wgrad", &linear_bias_wgrad, "linear bias wgrad");
  m.def("linear_bias_residual_backward", &linear_bias_residual_backward, "linear bias residual backward");
  m.def("linear_gelu_forward", &linear_gelu_forward, "linear gelu forward");
  m.def("linear_gelu_linear_backward", &linear_gelu_linear_backward, "linear gelu linear backward");
  m.def("linear_residual_gelu_linear_backward", &linear_residual_gelu_linear_backward, "linear residual gelu linear backward");
  m.def("bias_gelu_linear_dgrad_bgrad", &bias_gelu_linear_dgrad_bgrad, "bias gelu linear dgrad bgrad");
}
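
A minimal sketch of driving the bound functions from Python after this commit, assuming the extension builds under the fused_dense_lib name used by flash_attn.ops.fused_dense; shapes and dtypes follow the TORCH_CHECKs above (2D contiguous fp16/bf16 CUDA tensors).

import torch
import fused_dense_lib as fused_dense_cuda  # assumed extension module name

x = torch.randn(8, 16, device='cuda', dtype=torch.float16)
w = torch.randn(32, 16, device='cuda', dtype=torch.float16)
b = torch.randn(32, device='cuda', dtype=torch.float16)

out, *rest = fused_dense_cuda.linear_gelu_forward(x, w, b, True, 0)  # save_gelu_in=True, heuristic=0
gelu_in = rest[0]                                                    # saved pre-GELU activation
d_out = torch.randn_like(out)
d_w, d_b = fused_dense_cuda.linear_bias_wgrad(x, d_out, True)        # has_d_bias=True
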
csrc/fused_dense_lib/fused_dense_cuda.cu  (View file @ e68ebbe8)

...
@@ -94,226 +94,6 @@ cublasStatus_t gemm_bias(
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
int gemm_bias_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float *alpha, /* host pointer */
    at::Half* A, int lda,
    at::Half* B, int ldb,
    const float *beta, /* host pointer */
    at::Half* C, int ldc,
    void *workspace, size_t workspaceSize,
    cudaStream_t stream,
    bool use_bias,
    const void* bias) {
  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};
  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  // for details about defaults; here we just set the transforms for
  // A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  if (use_bias) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
    if (status != CUBLAS_STATUS_SUCCESS) { goto CLEANUP; }
    epilogue = CUBLASLT_EPILOGUE_BIAS;
  }
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) { goto CLEANUP; }
  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(&Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  // Create preference handle; In general, extra attributes can be
  // used here to disable tensor ops or to make sure algo selected
  // will work with badly aligned A, B, C. However, for simplicity
  // here we assume A,B,C are always well aligned (e.g., directly
  // come from cudaMalloc)
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  // We just need the best available heuristic to try and run matmul.
  // There is no guarantee that this will work. For example, if A is
  // badly aligned, you can request more (e.g. 32) algos and try to
  // run them one by one until something works.
  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }
  status = cublasLtMatmul(ltHandle, &operationDesc, alpha, A, &Adesc, B, &Bdesc, beta, C, &Cdesc, C, &Cdesc,
                          //&heuristicResult.algo,
                          NULL, workspace, workspaceSize, stream);
CLEANUP:
  // Descriptors are no longer needed as all GPU work was already
  // enqueued.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
int gemm_bias_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float *alpha, /* host pointer */
    at::BFloat16* A, int lda,
    at::BFloat16* B, int ldb,
    const float *beta, /* host pointer */
    at::BFloat16* C, int ldc,
    void *workspace, size_t workspaceSize,
    cudaStream_t stream,
    bool use_bias,
    const void* bias) {
  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};
  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  // for details about defaults; here we just set the transforms for
  // A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  if (use_bias) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
    if (status != CUBLAS_STATUS_SUCCESS) { goto CLEANUP; }
    epilogue = CUBLASLT_EPILOGUE_BIAS;
  }
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) { goto CLEANUP; }
  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(&Adesc, CUDA_R_16BF, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_16BF, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16BF, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  // Create preference handle; In general, extra attributes can be
  // used here to disable tensor ops or to make sure algo selected
  // will work with badly aligned A, B, C. However, for simplicity
  // here we assume A,B,C are always well aligned (e.g., directly
  // come from cudaMalloc)
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  // We just need the best available heuristic to try and run matmul.
  // There is no guarantee that this will work. For example, if A is
  // badly aligned, you can request more (e.g. 32) algos and try to
  // run them one by one until something works.
  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }
  status = cublasLtMatmul(ltHandle, &operationDesc, alpha, A, &Adesc, B, &Bdesc, beta, C, &Cdesc, C, &Cdesc,
                          //&heuristicResult.algo,
                          NULL, workspace, workspaceSize, stream);
CLEANUP:
  // Descriptors are no longer needed as all GPU work was already
  // enqueued.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
int gemm_bias_gelu_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa,
...
@@ -332,7 +112,6 @@ int gemm_bias_gelu_lt(
    void *workspace, size_t workspaceSize,
    cudaStream_t stream,
    bool use_bias,
    int heuristic,
    const void *gelu_in,
    const void* bias) {
...
@@ -363,12 +142,14 @@ int gemm_bias_gelu_lt(
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
  }
  if (use_bias) {
  if (bias != nullptr) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
    if (status != CUBLAS_STATUS_SUCCESS) {
      goto CLEANUP;
    }
    epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_BIAS;
  } else {
    epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU;
  }
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
...
@@ -453,7 +234,6 @@ int gemm_bias_gelu_lt(
    void *workspace, size_t workspaceSize,
    cudaStream_t stream,
    bool use_bias,
    int heuristic,
    const void *gelu_in,
    const void* bias) {
...
@@ -484,12 +264,14 @@ int gemm_bias_gelu_lt(
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
  }
  if (use_bias) {
  if (bias != nullptr) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
    if (status != CUBLAS_STATUS_SUCCESS) {
      goto CLEANUP;
    }
    epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_BIAS;
  } else {
    epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU;
  }
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
...
@@ -574,7 +356,6 @@ int gemm_bgradb_lt(
    void *workspace, size_t workspaceSize,
    cudaStream_t stream,
    bool use_bias,
    const void* bgrad) {
  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
...
@@ -596,7 +377,7 @@ int gemm_bgradb_lt(
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  if (use_bias) {
  if (bgrad != nullptr) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
    if (status != CUBLAS_STATUS_SUCCESS) {
      goto CLEANUP;
...
@@ -684,7 +465,6 @@ int gemm_bgradb_lt(
    void *workspace, size_t workspaceSize,
    cudaStream_t stream,
    bool use_bias,
    const void* bgrad) {
  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
...
@@ -706,7 +486,7 @@ int gemm_bgradb_lt(
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  if (use_bias) {
  if (bgrad != nullptr) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
    if (status != CUBLAS_STATUS_SUCCESS) {
      goto CLEANUP;
...
@@ -1008,132 +788,6 @@ CLEANUP:
#endif

template <typename T>
int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace) {
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
    // Get the stream from cublas handle to reuse for biasReLU kernel.
    cudaStream_t stream;
    cublasGetStream(handle, &stream);
    const float alpha = 1.0;
    const float beta_zero = 0.0;
    const float beta_one = 1.0;
    int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
    status = gemm_bias_lt(
        (cublasLtHandle_t)handle,
        CUBLAS_OP_T, CUBLAS_OP_N,
        out_features, batch_size, in_features,
        &alpha, /* host pointer */
        weight, in_features,
        input.data_ptr<T>(), in_features,
        &beta_zero, /* host pointer */
        output.data_ptr<T>(), out_features,
        lt_workspace, 1 << 22, stream, true,
        static_cast<const void*>(bias.data_ptr<T>()));
#endif
    if (status != 0){
        output.copy_(bias);
        status = gemm_bias(
            handle,
            CUBLAS_OP_T, CUBLAS_OP_N,
            out_features, batch_size, in_features,
            &alpha,
            weight, in_features,
            input.data_ptr<T>(), in_features,
            &beta_one,
            output.data_ptr<T>(), out_features);
    }
    return status;
}

template <typename T>
int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input, bool residual, void *lt_workspace) {
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
    // Get the stream from cublas handle to reuse for biasReLU kernel.
    cudaStream_t stream;
    cublasGetStream(handle, &stream);
    const float alpha = 1.0;
    const float beta_zero = 0.0;
    const float beta = residual ? 1.0 : 0.0;
    int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
    status = gemm_bgradb_lt(
        (cublasLtHandle_t)handle,
        CUBLAS_OP_N, CUBLAS_OP_T,
        in_features, out_features, batch_size,
        &alpha, /* host pointer */
        input, in_features,
        d_output, out_features,
        &beta_zero, /* host pointer */
        d_weight, in_features,
        lt_workspace, 1 << 22, stream, true,
        static_cast<const void*>(d_bias));
#endif
    if (status != 0){
        status = gemm_bias(
            handle,
            CUBLAS_OP_N, CUBLAS_OP_T,
            in_features, out_features, batch_size,
            &alpha,
            input, in_features,
            d_output, out_features,
            &beta_zero,
            d_weight, in_features);
    }
    status = gemm_bias(
        handle,
        CUBLAS_OP_N, CUBLAS_OP_N,
        in_features, batch_size, out_features,
        &alpha,
        weight, in_features,
        d_output, out_features,
        &beta,
        d_input, in_features);
    return status;
}

template <typename T>
int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace) {
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
...
@@ -1162,13 +816,10 @@ int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_siz
        lt_workspace, 1 << 22, stream,
        true,
        static_cast<const void*>(d_bias));
#endif
    if (status != 0){
        status = gemm_bias(
            handle, CUBLAS_OP_N,
...
@@ -1217,7 +868,6 @@ int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int
        lt_workspace, 1 << 22, stream,
        true,
        heuristic,
        static_cast<const void*>(gelu_in),
        static_cast<const void*>(bias));
...
@@ -1228,109 +878,46 @@ int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int
}

template <typename T>
int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, bool residual, void *lt_workspace) {
int bias_gelu_linear_dgrad_bgrad_cuda(T *weight, T *d_output, T *gelu_in, int in_features, int batch_size, int out_features, int heuristic, T *d_input, T *d_bias, void *lt_workspace) {
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
    // Get the stream from cublas handle to reuse for biasReLU kernel.
    cudaStream_t stream;
    cublasGetStream(handle, &stream);
    const float alpha = 1.0;
    const float beta_zero = 0.0;
    const float beta = residual ? 1.0 : 0.0;
    int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
    //wgrad for first gemm
    status = gemm_bgradb_lt(
        (cublasLtHandle_t)handle,
        CUBLAS_OP_N, CUBLAS_OP_T,
        hidden_features, out_features, batch_size,
        &alpha, /* host pointer */
        output1, hidden_features,
        d_output2, out_features,
        &beta_zero, /* host pointer */
        d_weight2, hidden_features,
        lt_workspace, 1 << 22, stream, true,
        static_cast<const void*>(d_bias2));
    //dgrad for second GEMM
    status = gemm_dgelu_bgradb_lt(
        (cublasLtHandle_t)handle,
        CUBLAS_OP_N, CUBLAS_OP_N,
        hidden_features,
        in_features,
        batch_size, out_features,
        &alpha, /* host pointer */
        weight2, hidden_features, d_output2,
        weight, in_features, d_output,
        out_features,
        &beta_zero, /* host pointer */
        d_output1, hidden_features,
        d_input, in_features,
        lt_workspace, 1 << 22, stream, heuristic,
        static_cast<const void*>(gelu_in),
        static_cast<const void*>(d_bias1));
    //wgrad for the first GEMM
    status = gemm_bias(handle, CUBLAS_OP_N, CUBLAS_OP_T, in_features, hidden_features, batch_size, &alpha, input, in_features, d_output1, hidden_features, &beta_zero, d_weight1, in_features);
    //dgrad for the first GEMM
    status = gemm_bias(handle, CUBLAS_OP_N, CUBLAS_OP_N, in_features, batch_size, hidden_features, &alpha, weight1, in_features, d_output1, hidden_features, &beta, d_input, in_features);
        static_cast<const void*>(d_bias));
#endif
    return status;
}
template int linear_bias_forward_cuda<at::Half>(at::Tensor input, at::Half *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
template int linear_bias_forward_cuda<at::BFloat16>(at::Tensor input, at::BFloat16 *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);

template int linear_bias_backward_cuda<at::Half>(at::Half *input, at::Half *weight, at::Half *d_output, int in_features, int batch_size, int out_features, at::Half *d_weight, at::Half *d_bias, at::Half *d_input, bool residual, void *lt_workspace);
template int linear_bias_backward_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *weight, at::BFloat16 *d_output, int in_features, int batch_size, int out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias, at::BFloat16 *d_input, bool residual, void *lt_workspace);

template int linear_bias_wgrad_cuda<at::Half>(at::Half *input, at::Half *d_output, int in_features, int batch_size, int out_features, at::Half *d_weight, at::Half *d_bias, void *lt_workspace);
template int linear_bias_wgrad_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, int in_features, int batch_size, int out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias, void *lt_workspace);

template int linear_gelu_forward_cuda<at::Half>(at::Half *input, at::Half *weight, at::Half *bias, int in_features, int batch_size, int out_features, int heuristic, at::Half *output, at::Half *gelu_in, void *lt_workspace);
template int linear_gelu_forward_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *weight, at::BFloat16 *bias, int in_features, int batch_size, int out_features, int heuristic, at::BFloat16 *output, at::BFloat16 *gelu_in, void *lt_workspace);

template int linear_gelu_linear_backward_cuda<at::Half>(at::Half *input, at::Half *gelu_in, at::Half *output1, at::Half *weight1, at::Half *weight2, at::Half *d_output1, at::Half *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, at::Half *d_weight1, at::Half *d_weight2, at::Half *d_bias1, at::Half *d_bias2, at::Half *d_input, bool residual, void *lt_workspace);
template int linear_gelu_linear_backward_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *gelu_in, at::BFloat16 *output1, at::BFloat16 *weight1, at::BFloat16 *weight2, at::BFloat16 *d_output1, at::BFloat16 *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, at::BFloat16 *d_weight1, at::BFloat16 *d_weight2, at::BFloat16 *d_bias1, at::BFloat16 *d_bias2, at::BFloat16 *d_input, bool residual, void *lt_workspace);
template int bias_gelu_linear_dgrad_bgrad_cuda<at::Half>(at::Half *weight, at::Half *d_output, at::Half *gelu_in, int in_features, int batch_size, int out_features, int heuristic, at::Half *d_input, at::Half *d_bias, void *lt_workspace);
template int bias_gelu_linear_dgrad_bgrad_cuda<at::BFloat16>(at::BFloat16 *weight, at::BFloat16 *d_output, at::BFloat16 *gelu_in, int in_features, int batch_size, int out_features, int heuristic, at::BFloat16 *d_input, at::BFloat16 *d_bias, void *lt_workspace);
\ No newline at end of file
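
The removed linear_bias_forward_cuda above falls back to a plain GEMM when the cuBLASLt bias-epilogue path is unavailable: the bias is first broadcast into the output and the GEMM then accumulates with beta = 1. A hedged Python rendering of that fallback (the helper name is ours):

import torch

def linear_bias_forward_fallback(x, weight, bias):
    # mirrors: output.copy_(bias); gemm_bias(..., &beta_one, output, ...)
    out = bias.expand(x.shape[0], -1).contiguous()   # output.copy_(bias)
    out = torch.addmm(out, x, weight.t())            # accumulate the GEMM into the bias-filled output
    return out
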
flash_attn/layers/patch_embed.py  (View file @ e68ebbe8)

...
@@ -10,9 +10,9 @@ from torch.nn.modules.utils import _pair
from einops import rearrange

try:
    from flash_attn.ops.fused_dense import FusedDenseTD
    from flash_attn.ops.fused_dense import FusedDense
except ImportError:
    FusedDenseTD = None
    FusedDense = None


class PatchEmbed(nn.Module):
...
@@ -37,10 +37,10 @@ class PatchEmbed(nn.Module):
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        if fused_bias_fc and FusedDenseTD is None:
        if fused_bias_fc and FusedDense is None:
            raise ImportError('fused_dense is not installed')

        linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDenseTD
        linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense
        self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
...
flash_attn/models/bert.py  (View file @ e68ebbe8)

...
@@ -30,9 +30,9 @@ from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.bert_padding import index_first_axis, index_first_axis_residual

try:
    from flash_attn.ops.fused_dense import FusedDenseTD
    from flash_attn.ops.fused_dense import FusedDense
except ImportError:
    FusedDenseTD = None
    FusedDense = None

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm, layer_norm
...
@@ -70,6 +70,8 @@ def create_mlp_cls(config, layer_idx=None, return_residual=False):
                          activation=partial(F.gelu, approximate=approximate),
                          return_residual=return_residual)
    else:
        if FusedDenseGeluDense is None:
            raise ImportError('fused_dense is not installed')
        mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
        if isinstance(mlp_checkpoint_lvl, Sequence):
...
@@ -168,9 +170,9 @@ class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        fused_bias_fc = getattr(config, 'fused_bias_fc', False)
        if fused_bias_fc and FusedDenseTD is None:
        if fused_bias_fc and FusedDense is None:
            raise ImportError('fused_dense is not installed')
        linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        self.dense = linear_cls(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()
...
@@ -188,12 +190,12 @@ class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        fused_bias_fc = getattr(config, 'fused_bias_fc', False)
        if fused_bias_fc and FusedDenseTD is None:
        if fused_bias_fc and FusedDense is None:
            raise ImportError('fused_dense is not installed')
        self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
        if self.fused_dropout_add_ln and layer_norm is None:
            raise ImportError('dropout_add_layer_norm is not installed')
        linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        self.dense = linear_cls(config.hidden_size, config.hidden_size)
        approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
        self.transform_act_fn = nn.GELU(approximate=approximate)
...
@@ -215,9 +217,9 @@ class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        fused_bias_fc = getattr(config, 'fused_bias_fc', False)
        if fused_bias_fc and FusedDenseTD is None:
        if fused_bias_fc and FusedDense is None:
            raise ImportError('fused_dense is not installed')
        linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        self.transform = BertPredictionHeadTransform(config)
...
flash_attn/models/gpt.py  (View file @ e68ebbe8)

...
@@ -61,6 +61,8 @@ def create_mlp_cls(config, layer_idx=None):
                assert layer_idx is not None
                mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
        if fused_dense_gelu_dense:
            if FusedDenseGeluDense is None:
                raise ImportError('fused_dense is not installed')
            mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim,
                              checkpoint_lvl=mlp_checkpoint_lvl)
        elif fused_dense_sqrelu_dense:
...
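
The hunk above allows mlp_checkpoint_lvl to be either a single level or a per-layer sequence. A small self-contained sketch of that resolution logic (the function name is ours, for illustration only):

from collections.abc import Sequence

def resolve_checkpoint_lvl(mlp_checkpoint_lvl, layer_idx=None):
    # per-layer list/tuple -> pick this layer's level; plain int -> use it everywhere
    if isinstance(mlp_checkpoint_lvl, Sequence):
        assert layer_idx is not None
        return mlp_checkpoint_lvl[layer_idx]
    return mlp_checkpoint_lvl

assert resolve_checkpoint_lvl([0, 1, 2], layer_idx=1) == 1
assert resolve_checkpoint_lvl(2) == 2
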
flash_attn/modules/mha.py  (View file @ e68ebbe8)

...
@@ -21,9 +21,9 @@ except ImportError:
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None

try:
    from flash_attn.ops.fused_dense import FusedDenseTD, FusedDenseResidual
    from flash_attn.ops.fused_dense import FusedDense
except ImportError:
    FusedDenseTD, FusedDenseResidual = None, None
    FusedDense = None

try:
    from flash_attn.layers.rotary import RotaryEmbedding
...
@@ -270,7 +270,7 @@ class CrossAttention(nn.Module):
class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDenseResidual.
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
...
@@ -311,10 +311,11 @@ class MHA(nn.Module):
            assert RotaryEmbedding is not None, 'rotary_emb is not installed'
            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base)
        if fused_bias_fc and FusedDenseTD is None:
        if fused_bias_fc and FusedDense is None:
            raise ImportError('fused_dense is not installed')
        linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
        linear_resid_cls = LinearResidual if not fused_bias_fc else FusedDenseResidual
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (LinearResidual if not fused_bias_fc
                            else partial(FusedDense, return_residual=True))
        if not self.cross_attn:
            if not self.return_residual:
                self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
...
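
The body of LinearResidual.forward is collapsed in this diff; as a hedged re-implementation, the pattern this hunk standardizes on (and which partial(FusedDense, return_residual=True) reproduces) is simply to hand the input back next to the output:

import torch
import torch.nn as nn

class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well (sketch)."""
    def forward(self, input: torch.Tensor):
        return super().forward(input), input

wqkv = LinearResidual(64, 3 * 64)
x = torch.randn(2, 16, 64)
qkv, residual = wqkv(x)
assert residual is x   # the residual branch sees the unmodified input
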
flash_attn/modules/mlp.py  (View file @ e68ebbe8)

...
@@ -5,11 +5,9 @@ import torch.nn as nn
import torch.nn.functional as F

try:
    from flash_attn.ops.fused_dense import fused_dense_gelu_dense_function_td
    from flash_attn.ops.fused_dense import fused_dense_res_gelu_dense_function_td
    from flash_attn.ops.fused_dense import FusedDenseGeluDense
except ImportError:
    fused_dense_gelu_dense_function_td = None
    fused_dense_res_gelu_dense_function_td = None
    FusedDenseGeluDense = None


class Mlp(nn.Module):
...
@@ -30,43 +28,3 @@ class Mlp(nn.Module):
        y = self.activation(y)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)


class FusedDenseGeluDense(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None, bias=True,
                 checkpoint_lvl=0, heuristic=0, return_residual=False, device=None, dtype=None):
        """
        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute gelu_in and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
            For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
            For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        assert checkpoint_lvl in [0, 1, 2]
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        assert bias == True, "DenseGeluDense module without bias is currently not supported"
        assert (fused_dense_gelu_dense_function_td is not None
                and fused_dense_res_gelu_dense_function_td is not None), 'fused_dense_lib is not installed'
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)

    def forward(self, x):
        assert x.is_cuda
        fn = (fused_dense_gelu_dense_function_td if not self.return_residual
              else fused_dense_res_gelu_dense_function_td)
        return fn(x, self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias,
                  self.checkpoint_lvl, self.heuristic)
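
A hedged usage sketch of the module whose docstring appears above; after this commit the class lives in flash_attn.ops.fused_dense rather than here, and the constructor arguments below follow that docstring (heuristic=0 being the suggested setting for CUDA >= 11.8). Treat the exact signature as an assumption.

import torch
from flash_attn.ops.fused_dense import FusedDenseGeluDense

mlp = FusedDenseGeluDense(1024, hidden_features=4096, checkpoint_lvl=0, heuristic=0,
                          device='cuda', dtype=torch.float16)
x = torch.randn(8, 1024, device='cuda', dtype=torch.float16)
y = mlp(x)   # same last dimension as x in the default out_features=in_features case
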
flash_attn/ops/fused_dense.py  (View file @ e68ebbe8)

# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
# We make it work with pytorch amp and with bfloat16.
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd

# import fused_dense_cuda  # from apex
...
@@ -11,126 +13,84 @@ import fused_dense_lib as fused_dense_cuda
from flash_attn.ops.gelu_activation import gelu_bwd


# implements fused GEMM+bias in forward pass using mlp_cuda from apex
class FusedDenseFuncTD(torch.autograd.Function):
class FusedDenseFunc(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight, bias):
    def forward(ctx, x, weight, bias, return_residual=False):
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
            x, weight = [a.to(dtype=dtype) for a in [x, weight]]
            bias = bias.to(dtype=dtype) if bias is not None else None
        ctx.return_residual = return_residual
        x = x.contiguous()
        weight = weight.contiguous()
        bias = bias.contiguous()
        ctx.save_for_backward(x, weight)
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
        output = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight, bias)
        return output.reshape(*batch_shape, output.shape[-1])
        output = F.linear(x, weight, bias)
        return output if not return_residual else (output, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
    def backward(ctx, grad_output, *args):
        grad_output = grad_output.contiguous()
        if ctx.return_residual:
            grad_input, = args
            grad_input = grad_input.contiguous()
        x, weight = ctx.saved_tensors
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        if ctx.needs_input_grad[0]:
            grad_input, grad_weight, grad_bias = fused_dense_cuda.linear_bias_backward(
                x.reshape(batch_dim, n), weight, grad_output.reshape(batch_dim, grad_output.shape[-1])
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        if ctx.needs_input_grad[1]:
            grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
                x.reshape(batch_dim, n), grad_output, ctx.needs_input_grad[2]
            )
        else:
            grad_weight = None
            grad_bias = grad_output if ctx.needs_input_grad[2] else None
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_output, weight.t())
            else:
                grad_input = torch.addmm(grad_input.reshape(batch_dim, n), grad_output, weight)
            grad_input = grad_input.reshape_as(x)
        else:
            grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
                x.reshape(batch_dim, n), grad_output.reshape(batch_dim, grad_output.shape[-1])
            )
            grad_input = None
        # print((grad_bias - grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)).abs().max())
        return grad_input, grad_weight, grad_bias
        # grad_input, grad_weight = None, None
        # grad_output_reshaped = grad_output.reshape(batch_dim, grad_output.shape[-1])
        # if ctx.needs_input_grad[0]:
        #     grad_input = (grad_output_reshaped @ weight.conj()).reshape(*batch_shape, n)
        # if ctx.needs_input_grad[1]:
        #     grad_weight = grad_output_reshaped.t() @ x.conj().reshape(batch_dim, n)
        # # We don't need to compute grad_bias explicitly, when we return grad_out Pytorch
        # # will sum over the batch dimension to get grad_bias.
        # return grad_input, grad_weight, grad_output
        return grad_input, grad_weight, grad_bias, None
fused_dense_function_td = FusedDenseFuncTD.apply
def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
                     return_residual: bool = False):
    batch_dim = x.shape[:-1].numel()
    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                      or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
    if (x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda)
            and batch_dim <= 64 * 1024 and dtype_eligible):
        return FusedDenseFunc.apply(x, weight, bias, return_residual)
    else:
        out = F.linear(x, weight, bias)
        return out if not return_residual else (out, x)
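
A hedged sketch of how the dispatcher above behaves: with CUDA fp16/bf16 tensors (and a flattened batch dimension up to 64k) it routes into FusedDenseFunc, otherwise it silently falls back to F.linear, so the call site looks the same either way. Requires the compiled extension and a GPU.

import torch
from flash_attn.ops.fused_dense import fused_dense_func

x = torch.randn(8, 16, device='cuda', dtype=torch.float16, requires_grad=True)
w = torch.randn(32, 16, device='cuda', dtype=torch.float16, requires_grad=True)
b = torch.randn(32, device='cuda', dtype=torch.float16, requires_grad=True)

out = fused_dense_func(x, w, b)     # fused path on eligible inputs
out.sum().backward()                # grads for x, w, b via FusedDenseFunc.backward

# CPU / fp32 inputs take the plain F.linear fallback with the same call signature.
out_cpu = fused_dense_func(x.detach().float().cpu(), w.detach().float().cpu())
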
class FusedDenseTD(nn.Linear):
class FusedDense(nn.Linear):

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
                 return_residual: bool = False, device=None, dtype=None) -> None:
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        self.return_residual = return_residual

    def forward(self, x):
        if x.is_cuda and self.bias is not None:
            return fused_dense_function_td(x, self.weight, self.bias)
        else:
            return F.linear(x, self.weight, self.bias)
        return fused_dense_func(x, self.weight, self.bias, return_residual=self.return_residual)
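
Hedged usage of the renamed FusedDense layer (the FusedDenseTD name disappears with this commit): it is a drop-in nn.Linear replacement, and with return_residual=True its forward returns (output, input) for downstream residual fusion.

import torch
from flash_attn.ops.fused_dense import FusedDense

proj = FusedDense(512, 1536, bias=True, return_residual=True,
                  device='cuda', dtype=torch.float16)
x = torch.randn(4, 128, 512, device='cuda', dtype=torch.float16)
out, residual = proj(x)    # out: (4, 128, 1536); residual is the (possibly cast) input
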
class FusedDenseResidualFunc(torch.autograd.Function):
class FusedDenseGeluDenseFunc(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight, bias):
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
            x = x.contiguous()
        x = x.contiguous()
        weight = weight.contiguous()
        bias = bias.contiguous()
        ctx.save_for_backward(x, weight)
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
        output = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight, bias)
        return output.reshape(*batch_shape, output.shape[-1]), x

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, grad_input):
        grad_output = grad_output.contiguous()
        grad_input = grad_input.contiguous()
        x, weight = ctx.saved_tensors
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        grad_input, grad_weight, grad_bias = fused_dense_cuda.linear_bias_residual_backward(
            x.reshape(batch_dim, n), weight, grad_output.reshape(batch_dim, grad_output.shape[-1]),
            grad_input.reshape(batch_dim, n)
        )
        return grad_input.reshape_as(x), grad_weight, grad_bias


fused_dense_residual_function = FusedDenseResidualFunc.apply


class FusedDenseResidual(nn.Linear):
    """Similar to FusedDense, but we return both the output and the input.
    This is so that in the backward pass, we can combine the input gradient from the residual branch
    with the input gradient from the matrix multiply, without having to do a separate addition.
    """

    def forward(self, x):
        if x.is_cuda and self.bias is not None:
            return fused_dense_residual_function(x, self.weight, self.bias)
        else:
            return F.linear(x, self.weight, self.bias), x
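
The docstring above explains why the input is returned alongside the output: the backward can then fold the residual-branch gradient into the GEMM instead of doing a separate add, which is exactly the torch.addmm call in the new FusedDenseFunc.backward. A small standalone illustration:

import torch

batch, n, out_features = 8, 16, 32
grad_residual = torch.randn(batch, n)            # gradient arriving through the residual branch
grad_output = torch.randn(batch, out_features)   # gradient w.r.t. the linear's output
weight = torch.randn(out_features, n)

grad_input = torch.addmm(grad_residual, grad_output, weight)   # one fused op
assert torch.allclose(grad_input, grad_residual + grad_output @ weight, atol=1e-5)
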
class FusedDenseGeluDenseFuncTD(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0, heuristic=0):
    def forward(ctx, x, weight1, bias1, weight2, bias2, save_gelu_in=True, return_residual=False,
                checkpoint_lvl=0, heuristic=0):
        """checkpoint_lvl:
        0: no recomputation in the bwd
        1: recompute gelu_out in the bwd
...
@@ -139,49 +99,53 @@ class FusedDenseGeluDenseFuncTD(torch.autograd.Function):
        assert -1 <= heuristic <= 4
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype)
                                                 for a in [x, weight1, bias1, weight2, bias2]]
            x, weight1, weight2 = [a.to(dtype=dtype) for a in [x, weight1, weight2]]
            bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
        if not save_gelu_in:
            checkpoint_lvl = 2
        assert checkpoint_lvl in [0, 1, 2]
        ctx.return_residual = return_residual
        x = x.contiguous()
        weight1 = weight1.contiguous()
        bias1 = bias1.contiguous()
        bias1 = bias1.contiguous() if bias1 is not None else None
        weight2 = weight2.contiguous()
        bias2 = bias2.contiguous()
        bias2 = bias2.contiguous() if bias2 is not None else None
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
        # output1, output2, gelu_in = fused_dense_cuda.linear_gelu_linear_forward(
        #     x.reshape(batch_dim, n), weight1, bias1, weight2, bias2
        # )
        if heuristic == -1:
            gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
            gelu_in = F.linear(x, weight1, bias1)
            output1 = F.gelu(gelu_in, approximate='tanh')
            # gelu_in = F.linear(x.reshape(batch_dim, n), weight1)  # This is before adding bias1
            # with torch.jit.fuser('fuser2'):
            #     output1 = bias_gelu(gelu_in, bias1)
        else:
            save_gelu_in = checkpoint_lvl != 2
            output1, *rest = fused_dense_cuda.linear_gelu_forward(
                x.reshape(batch_dim, n), weight1, bias1, save_gelu_in, heuristic
            )
            if save_gelu_in:
                gelu_in = rest[0]
        output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2)
        output2 = F.linear(output1, weight2, bias2)
        ctx.checkpoint_lvl = checkpoint_lvl
        ctx.heuristic = heuristic
        if checkpoint_lvl == 0:
            ctx.save_for_backward(x, weight1, bias1, weight2, gelu_in, output1)
            ctx.save_for_backward(x, weight1, weight2, gelu_in, output1)
        elif checkpoint_lvl == 1:
            ctx.save_for_backward(x, weight1, bias1, weight2, gelu_in)
            ctx.save_for_backward(x, weight1, weight2, gelu_in)
        elif checkpoint_lvl == 2:
            ctx.save_for_backward(x, weight1, bias1, weight2)
        return output2.reshape(*batch_shape, output2.shape[-1])
            ctx.save_for_backward(x, weight1, weight2, bias1)
        output2 = output2.reshape(*batch_shape, output2.shape[-1])
        return output2 if not return_residual else (output2, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
    def backward(ctx, grad_output, *args):
        grad_output = grad_output.contiguous()
        checkpoint_lvl = ctx.checkpoint_lvl
        x, weight1, bias1, weight2, *rest = ctx.saved_tensors
        if ctx.return_residual:
            grad_input, = args
            grad_input = grad_input.contiguous()
        x, weight1, weight2, *rest = ctx.saved_tensors
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        if checkpoint_lvl == 0:
...
@@ -190,55 +154,88 @@ class FusedDenseGeluDenseFuncTD(torch.autograd.Function):
            gelu_in, = rest
            output1 = F.gelu(gelu_in, approximate='tanh')
        elif checkpoint_lvl == 2:
            # bias1, = rest
            bias1, = rest
            if ctx.heuristic == -1:
                gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
                gelu_in = F.linear(x, weight1, bias1)
                output1 = F.gelu(gelu_in, approximate='tanh')
            else:
                output1, gelu_in = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n), weight1,
                                                                        bias1, True, ctx.heuristic)
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        output1 = output1.reshape(batch_dim, output1.shape[-1])
        gelu_in = gelu_in.reshape(batch_dim, gelu_in.shape[-1])
        if ctx.needs_input_grad[3]:
            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output,
                                                                          ctx.needs_input_grad[4])
        else:
            grad_weight2 = None
            grad_bias2 = grad_output if ctx.needs_input_grad[4] else None
        if ctx.heuristic == -1:
            grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
            # grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output)
            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
            # grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
            grad_output1 = grad_output @ weight2
            grad_output1 = F.linear(grad_output, weight2.t())
            with torch.jit.fuser('fuser2'):
                grad_gelu = gelu_bwd(grad_output1, gelu_in)
            grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
                x.reshape(batch_dim, n), weight1, grad_gelu)
            # with torch.jit.fuser('fuser2'):
            #     grad_gelu, grad_bias1 = bias_gelu_back(grad_output1, gelu_in, bias1)
            # grad_input = grad_gelu @ weight1
            # grad_weight1 = grad_gelu.reshape(batch_dim, -1).T @ x.reshape(batch_dim, n)
            # grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
            #     x.reshape(batch_dim, n), weight1, grad_gelu
            # )
            if ctx.needs_input_grad[1]:
                grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(x.reshape(batch_dim, n),
                                                                              grad_gelu,
                                                                              ctx.needs_input_grad[2])
            else:
                grad_weight1 = None
                grad_bias1 = grad_gelu if ctx.needs_input_grad[2] else None
        else:
            grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_gelu_linear_backward(
                x.reshape(batch_dim, n), gelu_in, output1, weight1, weight2,
                grad_output.reshape(batch_dim, grad_output.shape[-1]), ctx.heuristic
            # The cublasLt epilogue has to compute both gelu grad and bias grad, we can't
            # just compute gelu grad
            grad_gelu, grad_bias1 = fused_dense_cuda.bias_gelu_linear_dgrad_bgrad(weight2, grad_output,
                                                                                  gelu_in, ctx.heuristic)
            # grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
            # # grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output)
            # grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
            # grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
            # grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
            #     x.reshape(batch_dim, n), weight1, grad_gelu
            # )
        return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None
fused_dense_gelu_dense_function_td = FusedDenseGeluDenseFuncTD.apply
            if not ctx.needs_input_grad[2]:
                grad_bias1 = None
        if ctx.needs_input_grad[1]:
            grad_weight1 = F.linear(grad_gelu.t(), x.reshape(batch_dim, n).t())
        else:
            grad_weight1 = None
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_gelu, weight1.t())
            else:
                grad_input = torch.addmm(grad_input.reshape(batch_dim, n), grad_gelu, weight1)
            grad_input = grad_input.reshape_as(x)
        else:
            grad_input = None
        return grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None, None, None


def fused_dense_gelu_dense_func(x: Tensor, weight1: Tensor, weight2: Tensor,
                                bias1: Optional[Tensor] = None, bias2: Optional[Tensor] = None,
                                save_gelu_in: bool = True, return_residual: bool = False,
                                checkpoint_lvl: int = 0, heuristic: int = 0):
    batch_dim = x.shape[:-1].numel()
    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                      or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
    if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
            and (bias2 is None or bias2.is_cuda) and batch_dim <= 64 * 1024 and dtype_eligible):
        return FusedDenseGeluDenseFunc.apply(x, weight1, bias1, weight2, bias2, save_gelu_in,
                                             return_residual, checkpoint_lvl, heuristic)
    else:
        gelu_in = F.linear(x, weight1, bias1)
        output1 = F.gelu(gelu_in, approximate='tanh')
        output2 = F.linear(output1, weight2, bias2)
        return output2 if not return_residual else (output2, x)
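For orientation, a minimal sketch of calling the functional entry point defined above (based only on the signature in this diff; it needs a CUDA device, fp16/bf16 tensors, and the compiled fused_dense_lib extension, otherwise it falls back to the unfused F.linear/F.gelu path):

import torch
from flash_attn.ops.fused_dense import fused_dense_gelu_dense_func

x  = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)
w1 = torch.randn(4096, 1024, device='cuda', dtype=torch.float16)  # (hidden, in)
w2 = torch.randn(1024, 4096, device='cuda', dtype=torch.float16)  # (out, hidden)
b1 = torch.randn(4096, device='cuda', dtype=torch.float16)
b2 = torch.randn(1024, device='cuda', dtype=torch.float16)

# checkpoint_lvl trades backward-pass memory for recomputation; per the
# docstring below, heuristic=0 is the recommended algo choice for CUDA >= 11.8.
out = fused_dense_gelu_dense_func(x, w1, w2, b1, b2, checkpoint_lvl=1, heuristic=0)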
class FusedDenseGeluDenseTD(nn.Module):
class FusedDenseGeluDense(nn.Module):

    def __init__(self, in_features, intermediate_features, out_features=None, bias=True,
                 checkpoint_lvl=0, heuristic=0, device=None, dtype=None):
    def __init__(self, in_features, hidden_features, out_features=None, bias1=True, bias2=True,
                 return_residual=False, checkpoint_lvl=0, heuristic=0, device=None, dtype=None):
        """
        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
...
...
@@ -247,110 +244,26 @@ class FusedDenseGeluDenseTD(nn.Module):
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
            For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
            For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        assert checkpoint_lvl in [0, 1, 2]
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if out_features is None:
            out_features = in_features
        assert bias == True, "DenseGeluDense module without bias is currently not supported"
        self.return_residual = return_residual
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic
        self.fc1 = nn.Linear(in_features, intermediate_features, bias=bias, **factory_kwargs)
        self.fc2 = nn.Linear(intermediate_features, out_features, bias=bias, **factory_kwargs)
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        return fused_dense_gelu_dense_function_td(x, self.fc1.weight, self.fc1.bias,
                                                  self.fc2.weight, self.fc2.bias,
                                                  self.checkpoint_lvl, self.heuristic)


class FusedDenseResGeluDenseFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0, heuristic=0):
        """checkpoint_lvl:
        0: no recomputation in the bwd
        1: recompute gelu_out in the bwd
        2: recompute gelu_in and gelu_out in the bwd
        """
        assert -1 <= heuristic <= 4
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype) for a in [x, weight1, bias1, weight2, bias2]]
        assert checkpoint_lvl in [0, 1, 2]
        x = x.contiguous()
        weight1 = weight1.contiguous()
        bias1 = bias1.contiguous()
        weight2 = weight2.contiguous()
        bias2 = bias2.contiguous()
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
        # output1, output2, gelu_in = fused_dense_cuda.linear_gelu_linear_forward(
        #     x.reshape(batch_dim, n), weight1, bias1, weight2, bias2
        # )
        # gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
        # output1 = F.gelu(gelu_in, approximate='tanh')
        save_gelu_in = checkpoint_lvl != 2
        output1, *rest = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n), weight1, bias1,
                                                              save_gelu_in, heuristic)
        if save_gelu_in:
            gelu_in = rest[0]
        output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2)
        ctx.checkpoint_lvl = checkpoint_lvl
        ctx.heuristic = heuristic
        if checkpoint_lvl == 0:
            ctx.save_for_backward(x, weight1, weight2, gelu_in, output1)
        elif checkpoint_lvl == 1:
            ctx.save_for_backward(x, weight1, weight2, gelu_in)
        elif checkpoint_lvl == 2:
            ctx.save_for_backward(x, weight1, weight2, bias1)
        return output2.reshape(*batch_shape, output2.shape[-1]), x

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, grad_input):
        grad_output = grad_output.contiguous()
        grad_input = grad_input.contiguous()
        checkpoint_lvl = ctx.checkpoint_lvl
        x, weight1, weight2, *rest = ctx.saved_tensors
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        if checkpoint_lvl == 0:
            gelu_in, output1 = rest
        elif checkpoint_lvl == 1:
            gelu_in, = rest
            output1 = F.gelu(gelu_in, approximate='tanh')
        elif checkpoint_lvl == 2:
            bias1, = rest
            output1, gelu_in = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n), weight1,
                                                                    bias1, True, ctx.heuristic)
        grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_residual_gelu_linear_backward(
            x.reshape(batch_dim, n), gelu_in, output1, weight1, weight2,
            grad_output.reshape(batch_dim, grad_output.shape[-1]),
            grad_input.reshape(batch_dim, n), ctx.heuristic
        return fused_dense_gelu_dense_func(x, self.fc1.weight, self.fc2.weight,
                                           self.fc1.bias, self.fc2.bias,
                                           save_gelu_in=self.training,
                                           return_residual=self.return_residual,
                                           checkpoint_lvl=self.checkpoint_lvl,
                                           heuristic=self.heuristic)
        # grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        # # grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output)
        # grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
        # grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
        # grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_residual_backward(
        #     x.reshape(batch_dim, n), weight1, grad_gelu,
        #     grad_input.reshape(batch_dim, n)
        # )
        return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None


fused_dense_res_gelu_dense_function_td = FusedDenseResGeluDenseFunc.apply


class FusedDenseResGeluDense(FusedDenseGeluDenseTD):
    def forward(self, x):
        return fused_dense_res_gelu_dense_function_td(x, self.fc1.weight, self.fc1.bias,
                                                      self.fc2.weight, self.fc2.bias,
                                                      self.checkpoint_lvl, False, self.heuristic)
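A short usage sketch of the simplified module interface introduced by this commit (based on the __init__/forward signatures and the test imports in this diff; it assumes a CUDA GPU and the compiled extensions):

import torch
from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense

x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)

# FusedDense is used like nn.Linear; FusedDenseGeluDense is a Linear-GeLU-Linear MLP.
proj = FusedDense(1024, 1024, bias=True, device='cuda', dtype=torch.float16)
mlp = FusedDenseGeluDense(1024, 4096, 1024, bias1=True, bias2=True,
                          checkpoint_lvl=0, heuristic=0,
                          device='cuda', dtype=torch.float16)

y = mlp(proj(x))  # with return_residual=True each call would also return its input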
tests/ops/test_fused_dense.py
View file @
e68ebbe8
...
...
@@ -6,29 +6,44 @@ import pytest
from einops import rearrange

from flash_attn.ops.fused_dense import FusedDenseTD, FusedDenseGeluDenseTD
from flash_attn.ops.fused_dense import FusedDenseResidual, FusedDenseResGeluDense
from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense


@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('return_residual', [False, True])
@pytest.mark.parametrize('has_bias', [True, False])
@pytest.mark.parametrize('out_features', [1024, 4096])
@pytest.mark.parametrize('in_features', [1024, 4096])
def test_fused_linear_bias(in_features, out_features, dtype):
def test_fused_linear_bias(in_features, out_features, has_bias, return_residual, dtype):
    device = 'cuda'
    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
    x = x_pt.detach().clone().requires_grad_()
    model_pt = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
    model = FusedDenseTD(in_features, out_features, device=device, dtype=dtype)
    model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
    model = FusedDense(in_features, out_features, bias=has_bias, return_residual=return_residual,
                       device=device, dtype=dtype)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        if has_bias:
            model.bias.copy_(model_pt.bias)
    out_pt = model_pt(x_pt)
    out = model(x)
    if not return_residual:
        out = model(x)
    else:
        out, x_copy = model(x)
        x_copy = (x_copy[..., :out_features] if out_features < in_features
                  else F.pad(x_copy, (0, out_features - in_features)))
        x_pt_copy = (x_pt[..., :out_features] if out_features < in_features
                     else F.pad(x_pt, (0, out_features - in_features)))
        # Just add some random function of the residual
        out_pt = out_pt + F.gelu(x_pt_copy)
        out = out + F.gelu(x_copy)
    # with torch.no_grad():
    #     out_fl = F.linear(x_pt.float(), model.weight.float(), model.bias.float()).half()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
...
...
@@ -40,66 +55,52 @@ def test_fused_linear_bias(in_features, out_features, dtype):
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10)
    assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)


@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('out_features,in_features', [(1024, 1024), (4096, 4096)])
def test_fused_linear_bias_residual(in_features, out_features, dtype):
    device = 'cuda'
    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
    x = x_pt.detach().clone().requires_grad_()
    model_pt = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
    model = FusedDenseResidual(in_features, out_features, device=device, dtype=dtype)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
    out_pt = model_pt(x_pt) + F.gelu(x_pt)
    # Just add some random function of the residual x_pt
    out, x_copy = model(x)
    out = out + F.gelu(x_copy)
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol * 2)
    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(out) / 32
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10)
    assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)
    if has_bias:
        assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)


@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('heuristic', [1, -1])
@pytest.mark.parametrize('heuristic', [0, -1])
@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
@pytest.mark.parametrize('return_residual', [False, True])
@pytest.mark.parametrize('has_bias2', [True, False])
@pytest.mark.parametrize('has_bias1', [True, False])
@pytest.mark.parametrize('out_features', [1024, 4096])
@pytest.mark.parametrize('in_features', [1024, 4096])
def test_fused_dense_gelu_dense(in_features, out_features, checkpoint_lvl, heuristic, dtype):
def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2, return_residual,
                                checkpoint_lvl, heuristic, dtype):
    device = 'cuda'
    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
    x = x_pt.detach().clone().requires_grad_()
    model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
    model_pt_fc2 = torch.nn.Linear(out_features, in_features, device=device, dtype=dtype)
    model = FusedDenseGeluDenseTD(in_features, out_features, in_features, checkpoint_lvl=checkpoint_lvl,
                                  heuristic=heuristic, device=device, dtype=dtype)
    model_pt_fc1 = torch.nn.Linear(in_features, out_features, bias=has_bias1, device=device, dtype=dtype)
    model_pt_fc2 = torch.nn.Linear(out_features, in_features, bias=has_bias2, device=device, dtype=dtype)
    model = FusedDenseGeluDense(in_features, out_features, in_features, bias1=has_bias1,
                                bias2=has_bias2, return_residual=return_residual,
                                checkpoint_lvl=checkpoint_lvl, heuristic=heuristic,
                                device=device, dtype=dtype)
    with torch.no_grad():
        model.fc1.weight.copy_(model_pt_fc1.weight)
        model.fc1.bias.copy_(model_pt_fc1.bias)
        if has_bias1:
            model.fc1.bias.copy_(model_pt_fc1.bias)
        model.fc2.weight.copy_(model_pt_fc2.weight)
        model.fc2.bias.copy_(model_pt_fc2.bias)
        if has_bias2:
            model.fc2.bias.copy_(model_pt_fc2.bias)
    out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh'))
    out = model(x)
    if not return_residual:
        out = model(x)
    else:
        out, x_copy = model(x)
        # Just add some random function of the residual
        out_pt = out_pt + F.gelu(x_pt)
        out = out + F.gelu(x_copy)
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
# If we don't divide by batch_size, the gradient gets a bit too large.
...
...
@@ -109,46 +110,8 @@ def test_fused_dense_gelu_dense(in_features, out_features, checkpoint_lvl, heuri
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10)
    assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
    assert torch.allclose(model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10)
    assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)


@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
@pytest.mark.parametrize('out_features', [1024, 4096])
@pytest.mark.parametrize('in_features', [1024, 4096])
def test_fused_dense_residual_gelu_dense(in_features, out_features, checkpoint_lvl, dtype):
    device = 'cuda'
    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
    x = x_pt.detach().clone().requires_grad_()
    model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
    model_pt_fc2 = torch.nn.Linear(out_features, in_features, device=device, dtype=dtype)
    model = FusedDenseResGeluDense(in_features, out_features, in_features,
                                   checkpoint_lvl=checkpoint_lvl, device=device, dtype=dtype)
    with torch.no_grad():
        model.fc1.weight.copy_(model_pt_fc1.weight)
        model.fc1.bias.copy_(model_pt_fc1.bias)
        model.fc2.weight.copy_(model_pt_fc2.weight)
        model.fc2.bias.copy_(model_pt_fc2.bias)
    out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh')) + F.gelu(x_pt)
    out, x_copy = model(x)
    out = out + F.gelu(x_copy)
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol * 2)
    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(out) / 32
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10)
    assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
    if has_bias1:
        assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
    assert torch.allclose(model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10)
    assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
    if has_bias2:
        assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
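For reference, the unfused computation these tests compare against can be reproduced standalone (CPU/fp32 here purely for illustration; the tests themselves run on CUDA in fp16/bf16 and allow looser tolerances for the weight and bias gradients):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(8, 512, 1024)
fc1 = torch.nn.Linear(1024, 4096)
fc2 = torch.nn.Linear(4096, 1024)

# fc1 -> tanh-approximate GeLU -> fc2: the reference path used in
# test_fused_dense_gelu_dense, against which the fused module's output
# is checked with torch.allclose.
out_ref = fc2(F.gelu(fc1(x), approximate='tanh'))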