gaoqiong / flash-attention / Commits / e68ebbe8

Commit e68ebbe8, authored Dec 22, 2022 by Tri Dao

Simplify FusedDense

parent 1bc6e5b0

Showing 9 changed files with 318 additions and 1063 deletions (+318, -1063)
csrc/fused_dense_lib/fused_dense.cpp      +72   -243
csrc/fused_dense_lib/fused_dense_cuda.cu  +18   -431
flash_attn/layers/patch_embed.py          +4    -4
flash_attn/models/bert.py                 +10   -8
flash_attn/models/gpt.py                  +2    -0
flash_attn/modules/mha.py                 +7    -6
flash_attn/modules/mlp.py                 +2    -44
flash_attn/ops/fused_dense.py             +146  -233
tests/ops/test_fused_dense.py             +57   -94
csrc/fused_dense_lib/fused_dense.cpp  (view file @ e68ebbe8)

@@ -6,6 +6,8 @@
 #include <stdio.h>

+#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
+
 // https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
 // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 #define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...) \
@@ -24,14 +26,6 @@
...
@@ -24,14 +26,6 @@
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
}
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
template
<
typename
T
>
int
linear_bias_forward_cuda
(
at
::
Tensor
input
,
T
*
weight
,
at
::
Tensor
bias
,
int
in_features
,
int
batch_size
,
int
out_features
,
at
::
Tensor
output
,
void
*
lt_workspace
);
template
<
typename
T
>
int
linear_bias_backward_cuda
(
T
*
input
,
T
*
weight
,
T
*
d_output
,
int
in_features
,
int
batch_size
,
int
out_features
,
T
*
d_weight
,
T
*
d_bias
,
T
*
d_input
,
bool
residual
,
void
*
lt_workspace
);
template
<
typename
T
>
template
<
typename
T
>
int
linear_bias_wgrad_cuda
(
T
*
input
,
T
*
d_output
,
int
in_features
,
int
batch_size
,
int
out_features
,
T
*
d_weight
,
T
*
d_bias
,
void
*
lt_workspace
);
int
linear_bias_wgrad_cuda
(
T
*
input
,
T
*
d_output
,
int
in_features
,
int
batch_size
,
int
out_features
,
T
*
d_weight
,
T
*
d_bias
,
void
*
lt_workspace
);
@@ -39,103 +33,34 @@ template <typename T>
 int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int batch_size, int out_features, int heuristic, T *output, T *gelu_in, void *lt_workspace);

 template <typename T>
-int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, bool residual, void *lt_workspace);
-
-at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor bias) {
-
-  auto batch_size = input.size(0);
-  auto in_features = input.size(1);
-  int out_features = weight.size(0);
-
-  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
-
-  // create output/workspace tensor
-  auto out = at::empty({batch_size, out_features}, at::dtype(input.dtype()).device(input.device()));
-  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
-  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
-  auto lt_workspace = at::empty({1 << 22}, at::dtype(input.dtype()).device(input.device()));
-
-  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_forward", [&] {
-    scalar_t* w_ptr = weight.data_ptr<scalar_t>();
-    auto result = linear_bias_forward_cuda<scalar_t>(
-        input, w_ptr, bias, in_features, batch_size, out_features, out,
-        //out.data_ptr<scalar_t>(),
-        // reserved_space.data_ptr<scalar_t>(),
-        (void*) (lt_workspace.data_ptr<scalar_t>()));
-    TORCH_CHECK(result == 0, "linear_bias_forward failed.")
-  });
-
-  return {out};
-}
-
-std::vector<at::Tensor> linear_bias_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output) {
-
-  auto batch_size = input.size(0);
-  auto in_features = input.size(1);
-  int out_features = weight.size(0);
-
-  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
-
-  // create output/workspace tensor
-  auto opts = input.options();
-  auto d_weight = at::empty({out_features, in_features}, opts);
-#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
-  auto d_bias = d_output.view({-1, out_features}).sum(0, false);
-#else
-  auto d_bias = at::empty({out_features}, opts);
-#endif
-  auto d_input = at::empty({batch_size, in_features}, opts);
-  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
-  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
-  auto lt_workspace = at::empty({1 << 22}, opts);
-
-  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
-    scalar_t* w_ptr = weight.data_ptr<scalar_t>();
-    auto result = linear_bias_backward_cuda<scalar_t>(
-        input.data_ptr<scalar_t>(), w_ptr, d_output.data_ptr<scalar_t>(),
-        in_features, batch_size, out_features,
-        d_weight.data_ptr<scalar_t>(), d_bias.data_ptr<scalar_t>(), d_input.data_ptr<scalar_t>(),
-        // reserved_space.data_ptr<scalar_t>(),
-        /*residual=*/false,
-        (void*) (lt_workspace.data_ptr<scalar_t>()));
-    TORCH_CHECK(result == 0, "linear_bias_backward failed.")
-  });
-
-  return {d_input, d_weight, d_bias};
-}
-
-std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output) {
-
-  auto batch_size = input.size(0);
-  auto in_features = input.size(1);
-  int out_features = d_output.size(1);
-
-  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
+int bias_gelu_linear_dgrad_bgrad_cuda(T *weight, T *d_output, T *gelu_in, int in_features, int batch_size, int out_features, int heuristic, T *d_input, T *d_bias, void *lt_workspace);

+std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output, bool has_d_bias) {
+  int batch_size = input.size(0);
+  int in_features = input.size(1);
+  int out_features = d_output.size(1);
+  TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
+  TORCH_CHECK(input.dtype() == d_output.dtype());
+  TORCH_CHECK(input.is_cuda());
+  TORCH_CHECK(d_output.is_cuda());
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(d_output.is_contiguous());
+  CHECK_SHAPE(input, batch_size, in_features);
+  CHECK_SHAPE(d_output, batch_size, out_features);

   // create output/workspace tensor
   auto opts = input.options();
   auto d_weight = at::empty({out_features, in_features}, opts);
+  at::Tensor d_bias;
+  if (has_d_bias) {
 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
-  auto d_bias = d_output.view({-1, out_features}).sum(0, false);
+    d_bias = d_output.view({-1, out_features}).sum(0, false);
 #else
-  auto d_bias = at::empty({out_features}, opts);
+    d_bias = at::empty({out_features}, opts);
 #endif
-  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
+  }
   // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
   auto lt_workspace = at::empty({1 << 22}, opts);
@@ -147,93 +72,59 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output)
         batch_size,
         out_features,
         d_weight.data_ptr<scalar_t>(),
-        d_bias.data_ptr<scalar_t>(),
+        has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr,
         // reserved_space.data_ptr<scalar_t>(),
         (void*) (lt_workspace.data_ptr<scalar_t>()));
-    TORCH_CHECK(result == 0, "linear_bias_wgrad failed.")
+    TORCH_CHECK(result == 0, "linear_bias_wgrad failed.");
   });

   return {d_weight, d_bias};
 }

-std::vector<at::Tensor> linear_bias_residual_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output, at::Tensor d_input) {
-
-  auto batch_size = input.size(0);
-  auto in_features = input.size(1);
-  int out_features = weight.size(0);
-
-  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
-
-  // create output/workspace tensor
-  auto opts = input.options();
-  auto d_weight = at::empty({out_features, in_features}, opts);
-#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
-  auto d_bias = d_output.view({-1, out_features}).sum(0, false);
-#else
-  auto d_bias = at::empty({out_features}, opts);
-#endif
-  CHECK_SHAPE(d_input, batch_size, in_features);
-  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
-  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
-  auto lt_workspace = at::empty({1 << 22}, opts);
-
-  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
-    scalar_t* w_ptr = weight.data_ptr<scalar_t>();
-    auto result = linear_bias_backward_cuda<scalar_t>(
-        input.data_ptr<scalar_t>(), w_ptr, d_output.data_ptr<scalar_t>(),
-        in_features, batch_size, out_features,
-        d_weight.data_ptr<scalar_t>(), d_bias.data_ptr<scalar_t>(), d_input.data_ptr<scalar_t>(),
-        // reserved_space.data_ptr<scalar_t>(),
-        /*residual=*/true,
-        (void*) (lt_workspace.data_ptr<scalar_t>()));
-    TORCH_CHECK(result == 0, "linear_bias_residual_backward failed.")
-  });
-
-  return {d_input, d_weight, d_bias};
-}
-
-std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight, at::Tensor bias,
+std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight, c10::optional<at::Tensor> bias_,
                                             bool save_gelu_in, int heuristic) {

-  auto batch_size = input.size(0);
-  auto in_features = input.size(1);
+  int batch_size = input.size(0);
+  int in_features = input.size(1);
   int out_features = weight.size(0);
-
-  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
+  TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
+  TORCH_CHECK(input.dtype() == weight.dtype());
+  TORCH_CHECK(input.is_cuda());
+  TORCH_CHECK(weight.is_cuda());
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(weight.is_contiguous());
+  CHECK_SHAPE(input, batch_size, in_features);
+  CHECK_SHAPE(weight, out_features, in_features);
+  if (bias_.has_value()) {
+    auto bias = bias_.value();
+    TORCH_CHECK(bias.dtype() == input.dtype());
+    TORCH_CHECK(bias.is_cuda());
+    TORCH_CHECK(bias.is_contiguous());
+    CHECK_SHAPE(bias, out_features);
+  }

   // create output/workspace tensor
   auto opts = input.options();
   auto output = at::empty({batch_size, out_features}, opts);
   at::Tensor gelu_in;
   if (save_gelu_in) { gelu_in = at::empty({batch_size, out_features}, opts); }
-  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
   // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
   auto lt_workspace = at::empty({1 << 22}, opts);

   DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_gelu_forward", [&] {
-    scalar_t* w_ptr = weight.data_ptr<scalar_t>();
-    scalar_t* b_ptr = bias.data_ptr<scalar_t>();
     auto result = linear_gelu_forward_cuda<scalar_t>(
         input.data_ptr<scalar_t>(),
-        w_ptr,
-        b_ptr,
+        weight.data_ptr<scalar_t>(),
+        bias_.has_value() ? bias_.value().data_ptr<scalar_t>() : nullptr,
         in_features,
         batch_size,
         out_features,
         heuristic,
         output.data_ptr<scalar_t>(),
         save_gelu_in ? gelu_in.data_ptr<scalar_t>() : nullptr,
         // reserved_space.data_ptr<scalar_t>(),
         (void*) (lt_workspace.data_ptr<scalar_t>()));
-    TORCH_CHECK(result == 0, "linear_gelu_forward failed.")
+    TORCH_CHECK(result == 0, "linear_gelu_forward failed.");
   });

   std::vector<at::Tensor> result = {output};
@@ -241,116 +132,54 @@ std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
   return result;
 }

-std::vector<at::Tensor> linear_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2, int heuristic) {
-
-  auto batch_size = input.size(0);
-  auto in_features = input.size(1);
-  int hidden_features = weight1.size(0);
-  int out_features = weight2.size(0);
-
-  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
-
-  // create output/workspace tensor
-  auto opts = input.options();
-  auto d_weight1 = at::empty({hidden_features, in_features}, opts);
-  auto d_weight2 = at::empty({out_features, hidden_features}, opts);
-  auto d_bias1 = at::empty({hidden_features}, opts);
-  auto d_bias2 = at::empty({out_features}, opts);
-  auto d_input = at::empty({batch_size, in_features}, opts);
-  auto d_output1 = at::empty({batch_size, hidden_features}, opts);
-  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
-  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
-  auto lt_workspace = at::empty({1 << 22}, opts);
-
-  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
-    //scalar_t* w_ptr = weight.data_ptr<scalar_t>();
-    //scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
-    auto result = linear_gelu_linear_backward_cuda<scalar_t>(
-        input.data_ptr<scalar_t>(), gelu_in.data_ptr<scalar_t>(), output1.data_ptr<scalar_t>(),
-        weight1.data_ptr<scalar_t>(), weight2.data_ptr<scalar_t>(),
-        d_output1.data_ptr<scalar_t>(), d_output2.data_ptr<scalar_t>(),
-        in_features, batch_size, hidden_features, out_features, heuristic,
-        d_weight1.data_ptr<scalar_t>(), d_weight2.data_ptr<scalar_t>(),
-        d_bias1.data_ptr<scalar_t>(), d_bias2.data_ptr<scalar_t>(),
-        d_input.data_ptr<scalar_t>(),
-        // reserved_space.data_ptr<scalar_t>(),
-        /*residual=*/false,
-        (void*) (lt_workspace.data_ptr<scalar_t>()));
-    TORCH_CHECK(result == 0, "linear_gelu_linear_backward failed.")
-  });
-
-  return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
-}
-
-std::vector<at::Tensor> linear_residual_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2, at::Tensor d_input, int heuristic) {
-
-  auto batch_size = input.size(0);
-  auto in_features = input.size(1);
-  int hidden_features = weight1.size(0);
-  int out_features = weight2.size(0);
-
-  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
-
-  // create output/workspace tensor
-  auto opts = input.options();
-  auto d_weight1 = at::empty({hidden_features, in_features}, opts);
-  auto d_weight2 = at::empty({out_features, hidden_features}, opts);
-  auto d_bias1 = at::empty({hidden_features}, opts);
-  auto d_bias2 = at::empty({out_features}, opts);
-  CHECK_SHAPE(d_input, batch_size, in_features);
-  auto d_output1 = at::empty({batch_size, hidden_features}, opts);
-  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
+std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(at::Tensor weight, at::Tensor d_output, at::Tensor gelu_in, int heuristic) {
+  int batch_size = d_output.size(0);
+  int out_features = d_output.size(1);
+  int in_features = weight.size(1);
+  TORCH_CHECK(weight.dtype() == torch::kFloat16 || weight.dtype() == torch::kBFloat16);
+  TORCH_CHECK(weight.dtype() == d_output.dtype());
+  TORCH_CHECK(weight.dtype() == gelu_in.dtype());
+  TORCH_CHECK(weight.is_cuda());
+  TORCH_CHECK(d_output.is_cuda());
+  TORCH_CHECK(gelu_in.is_cuda());
+  TORCH_CHECK(weight.is_contiguous());
+  TORCH_CHECK(d_output.is_contiguous());
+  TORCH_CHECK(gelu_in.is_contiguous());
+  CHECK_SHAPE(weight, out_features, in_features);
+  CHECK_SHAPE(d_output, batch_size, out_features);
+  CHECK_SHAPE(gelu_in, batch_size, in_features);
+
+  // create output/workspace tensor
+  auto opts = weight.options();
+  auto d_bias = at::empty({in_features}, opts);
+  auto d_input = at::empty({batch_size, in_features}, opts);
   // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
   auto lt_workspace = at::empty({1 << 22}, opts);

-  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
-    //scalar_t* w_ptr = weight.data_ptr<scalar_t>();
-    //scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
-    auto result = linear_gelu_linear_backward_cuda<scalar_t>(
-        input.data_ptr<scalar_t>(), gelu_in.data_ptr<scalar_t>(), output1.data_ptr<scalar_t>(),
-        weight1.data_ptr<scalar_t>(), weight2.data_ptr<scalar_t>(),
-        d_output1.data_ptr<scalar_t>(), d_output2.data_ptr<scalar_t>(),
-        in_features, batch_size, hidden_features, out_features, heuristic,
-        d_weight1.data_ptr<scalar_t>(), d_weight2.data_ptr<scalar_t>(),
-        d_bias1.data_ptr<scalar_t>(), d_bias2.data_ptr<scalar_t>(),
-        d_input.data_ptr<scalar_t>(),
-        // reserved_space.data_ptr<scalar_t>(),
-        /*residual=*/true,
-        (void*) (lt_workspace.data_ptr<scalar_t>()));
-    TORCH_CHECK(result == 0, "linear_residual_gelu_linear_backward failed.")
-  });
-
-  return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
+  DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_gelu_linear_dgrad_bgrad", [&] {
+    auto result = bias_gelu_linear_dgrad_bgrad_cuda<scalar_t>(
+        weight.data_ptr<scalar_t>(),
+        d_output.data_ptr<scalar_t>(),
+        gelu_in.data_ptr<scalar_t>(),
+        in_features,
+        batch_size,
+        out_features,
+        heuristic,
+        d_input.data_ptr<scalar_t>(),
+        d_bias.data_ptr<scalar_t>(),
+        (void*) (lt_workspace.data_ptr<scalar_t>()));
+    TORCH_CHECK(result == 0, "bias_gelu_linear_dgrad_bgrad failed.");
+  });
+
+  return {d_input, d_bias};
 }

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("linear_bias_forward", &linear_bias_forward, "linear bias forward");
-  m.def("linear_bias_backward", &linear_bias_backward, "linear bias backward");
   m.def("linear_bias_wgrad", &linear_bias_wgrad, "linear bias wgrad");
-  m.def("linear_bias_residual_backward", &linear_bias_residual_backward, "linear bias residual backward");
   m.def("linear_gelu_forward", &linear_gelu_forward, "linear gelu forward");
-  m.def("linear_gelu_linear_backward", &linear_gelu_linear_backward, "linear gelu linear backward");
-  m.def("linear_residual_gelu_linear_backward", &linear_residual_gelu_linear_backward, "linear residual gelu linear backward");
+  m.def("bias_gelu_linear_dgrad_bgrad", &bias_gelu_linear_dgrad_bgrad, "bias gelu linear dgrad bgrad");
 }
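The binding above now exposes exactly three ops: linear_bias_wgrad, linear_gelu_forward, and bias_gelu_linear_dgrad_bgrad. For orientation, here is a minimal hedged sketch of how a two-layer MLP backward could be assembled from them. The import name fused_dense_lib, the tensor sizes, and the assumption that linear_gelu_forward also returns the saved gelu_in when save_gelu_in is true are illustrative guesses rather than something these hunks show; the real caller lives in flash_attn/ops/fused_dense.py.

import torch
import fused_dense_lib as fused_dense_cuda  # assumed build name for csrc/fused_dense_lib

B, D, H = 8, 1024, 4096                      # batch, in_features, hidden_features (illustrative)
dtype, dev = torch.float16, 'cuda'
x  = torch.randn(B, D, dtype=dtype, device=dev)
w1 = torch.randn(H, D, dtype=dtype, device=dev)
b1 = torch.randn(H, dtype=dtype, device=dev)
w2 = torch.randn(D, H, dtype=dtype, device=dev)
dout = torch.randn(B, D, dtype=dtype, device=dev)   # gradient flowing into fc2's output

# fused fc1 + bias + GELU; save_gelu_in=True keeps the pre-GELU activation for the backward
out1, gelu_in = fused_dense_cuda.linear_gelu_forward(x, w1, b1, True, 0)

# fc2 weight and bias gradients from fc1's output (has_d_bias=True)
dw2, db2 = fused_dense_cuda.linear_bias_wgrad(out1, dout, True)

# gradient through fc2 and the GELU in one call, plus fc1's bias gradient
dgelu_in, db1 = fused_dense_cuda.bias_gelu_linear_dgrad_bgrad(w2, dout, gelu_in, 0)

# fc1's weight gradient is then an ordinary wgrad with the pre-GELU gradient
dw1 = fused_dense_cuda.linear_bias_wgrad(x, dgelu_in, True)[0]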
csrc/fused_dense_lib/fused_dense_cuda.cu  (view file @ e68ebbe8)

@@ -94,226 +94,6 @@ cublasStatus_t gemm_bias(
 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
-int gemm_bias_lt(
-    cublasLtHandle_t ltHandle, cublasOperation_t transa, cublasOperation_t transb,
-    int m, int n, int k,
-    const float *alpha, /* host pointer */
-    at::Half* A, int lda,
-    at::Half* B, int ldb,
-    const float *beta, /* host pointer */
-    at::Half* C, int ldc,
-    void *workspace, size_t workspaceSize, cudaStream_t stream,
-    bool use_bias, const void* bias) {
-  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
-
-  cublasLtMatmulDescOpaque_t operationDesc = {};
-  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
-  cublasLtMatmulPreferenceOpaque_t preference = {};
-
-  int returnedResults = 0;
-  cublasLtMatmulHeuristicResult_t heuristicResult = {};
-  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
-
-  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
-  // for details about defaults; here we just set the transforms for
-  // A and B.
-  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  if (use_bias) {
-    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
-    if (status != CUBLAS_STATUS_SUCCESS) {
-      goto CLEANUP;
-    }
-    epilogue = CUBLASLT_EPILOGUE_BIAS;
-  }
-  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
-  if (status != CUBLAS_STATUS_SUCCESS) {
-    goto CLEANUP;
-  }
-
-  // Create matrix descriptors. Not setting any extra attributes.
-  status = cublasLtMatrixLayoutInit(&Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  status = cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-
-  // Create preference handle; In general, extra attributes can be
-  // used here to disable tensor ops or to make sure algo selected
-  // will work with badly aligned A, B, C. However, for simplicity
-  // here we assume A,B,C are always well aligned (e.g., directly
-  // come from cudaMalloc)
-  status = cublasLtMatmulPreferenceInit(&preference);
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-
-  // We just need the best available heuristic to try and run matmul.
-  // There is no guarantee that this will work. For example, if A is
-  // badly aligned, you can request more (e.g. 32) algos and try to
-  // run them one by one until something works.
-  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  if (returnedResults == 0) {
-    status = CUBLAS_STATUS_NOT_SUPPORTED;
-    goto CLEANUP;
-  }
-  status = cublasLtMatmul(ltHandle, &operationDesc, alpha, A, &Adesc, B, &Bdesc, beta, C, &Cdesc, C, &Cdesc,
-                          //&heuristicResult.algo,
-                          NULL, workspace, workspaceSize, stream);
-
-CLEANUP:
-  // Descriptors are no longer needed as all GPU work was already
-  // enqueued.
-  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
-}
-
-int gemm_bias_lt(
-    cublasLtHandle_t ltHandle, cublasOperation_t transa, cublasOperation_t transb,
-    int m, int n, int k,
-    const float *alpha, /* host pointer */
-    at::BFloat16* A, int lda,
-    at::BFloat16* B, int ldb,
-    const float *beta, /* host pointer */
-    at::BFloat16* C, int ldc,
-    void *workspace, size_t workspaceSize, cudaStream_t stream,
-    bool use_bias, const void* bias) {
-  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
-
-  cublasLtMatmulDescOpaque_t operationDesc = {};
-  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
-  cublasLtMatmulPreferenceOpaque_t preference = {};
-
-  int returnedResults = 0;
-  cublasLtMatmulHeuristicResult_t heuristicResult = {};
-  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
-
-  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
-  // for details about defaults; here we just set the transforms for
-  // A and B.
-  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  if (use_bias) {
-    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
-    if (status != CUBLAS_STATUS_SUCCESS) {
-      goto CLEANUP;
-    }
-    epilogue = CUBLASLT_EPILOGUE_BIAS;
-  }
-  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
-  if (status != CUBLAS_STATUS_SUCCESS) {
-    goto CLEANUP;
-  }
-
-  // Create matrix descriptors. Not setting any extra attributes.
-  status = cublasLtMatrixLayoutInit(&Adesc, CUDA_R_16BF, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  status = cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_16BF, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16BF, m, n, ldc);
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-
-  // Create preference handle; In general, extra attributes can be
-  // used here to disable tensor ops or to make sure algo selected
-  // will work with badly aligned A, B, C. However, for simplicity
-  // here we assume A,B,C are always well aligned (e.g., directly
-  // come from cudaMalloc)
-  status = cublasLtMatmulPreferenceInit(&preference);
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-
-  // We just need the best available heuristic to try and run matmul.
-  // There is no guarantee that this will work. For example, if A is
-  // badly aligned, you can request more (e.g. 32) algos and try to
-  // run them one by one until something works.
-  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  if (returnedResults == 0) {
-    status = CUBLAS_STATUS_NOT_SUPPORTED;
-    goto CLEANUP;
-  }
-  status = cublasLtMatmul(ltHandle, &operationDesc, alpha, A, &Adesc, B, &Bdesc, beta, C, &Cdesc, C, &Cdesc,
-                          //&heuristicResult.algo,
-                          NULL, workspace, workspaceSize, stream);
-
-CLEANUP:
-  // Descriptors are no longer needed as all GPU work was already
-  // enqueued.
-  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
-}
-
 int gemm_bias_gelu_lt(
     cublasLtHandle_t ltHandle,
     cublasOperation_t transa,
@@ -332,7 +112,6 @@ int gemm_bias_gelu_lt(
     void *workspace,
     size_t workspaceSize,
     cudaStream_t stream,
-    bool use_bias,
     int heuristic,
     const void *gelu_in,
     const void* bias) {

@@ -363,12 +142,14 @@ int gemm_bias_gelu_lt(
     status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
   }

-  if (use_bias) {
+  if (bias != nullptr) {
     status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
     if (status != CUBLAS_STATUS_SUCCESS) {
       goto CLEANUP;
     }
     epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_BIAS;
+  } else {
+    epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU;
   }

   status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
@@ -453,7 +234,6 @@ int gemm_bias_gelu_lt(
     void *workspace,
     size_t workspaceSize,
     cudaStream_t stream,
-    bool use_bias,
     int heuristic,
     const void *gelu_in,
     const void* bias) {

@@ -484,12 +264,14 @@ int gemm_bias_gelu_lt(
     status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
   }

-  if (use_bias) {
+  if (bias != nullptr) {
     status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
     if (status != CUBLAS_STATUS_SUCCESS) {
       goto CLEANUP;
     }
     epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_BIAS;
+  } else {
+    epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU;
   }

   status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
@@ -574,7 +356,6 @@ int gemm_bgradb_lt(
     void *workspace,
     size_t workspaceSize,
     cudaStream_t stream,
-    bool use_bias,
     const void* bgrad) {
   cublasStatus_t status = CUBLAS_STATUS_SUCCESS;

@@ -596,7 +377,7 @@ int gemm_bgradb_lt(
   status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
   if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  if (use_bias) {
+  if (bgrad != nullptr) {
     status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
     if (status != CUBLAS_STATUS_SUCCESS) {
       goto CLEANUP;

@@ -684,7 +465,6 @@ int gemm_bgradb_lt(
     void *workspace,
     size_t workspaceSize,
     cudaStream_t stream,
-    bool use_bias,
     const void* bgrad) {
   cublasStatus_t status = CUBLAS_STATUS_SUCCESS;

@@ -706,7 +486,7 @@ int gemm_bgradb_lt(
   status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
   if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  if (use_bias) {
+  if (bgrad != nullptr) {
     status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
     if (status != CUBLAS_STATUS_SUCCESS) {
       goto CLEANUP;
@@ -1008,132 +788,6 @@ CLEANUP:
 #endif

-template <typename T>
-int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace) {
-  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
-  // Get the stream from cublas handle to reuse for biasReLU kernel.
-  cudaStream_t stream;
-  cublasGetStream(handle, &stream);
-  const float alpha = 1.0;
-  const float beta_zero = 0.0;
-  const float beta_one = 1.0;
-  int status = 1;
-#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
-  status = gemm_bias_lt(
-      (cublasLtHandle_t)handle, CUBLAS_OP_T, CUBLAS_OP_N,
-      out_features, batch_size, in_features,
-      &alpha, /* host pointer */
-      weight, in_features,
-      input.data_ptr<T>(), in_features,
-      &beta_zero, /* host pointer */
-      output.data_ptr<T>(), out_features,
-      lt_workspace, 1 << 22, stream, true,
-      static_cast<const void*>(bias.data_ptr<T>()));
-#endif
-  if (status != 0){
-    output.copy_(bias);
-    status = gemm_bias(
-        handle, CUBLAS_OP_T, CUBLAS_OP_N, out_features, batch_size, in_features,
-        &alpha, weight, in_features, input.data_ptr<T>(), in_features,
-        &beta_one, output.data_ptr<T>(), out_features);
-  }
-  return status;
-}
-
-template <typename T>
-int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input, bool residual, void *lt_workspace) {
-  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
-  // Get the stream from cublas handle to reuse for biasReLU kernel.
-  cudaStream_t stream;
-  cublasGetStream(handle, &stream);
-  const float alpha = 1.0;
-  const float beta_zero = 0.0;
-  const float beta = residual ? 1.0 : 0.0;
-  int status = 1;
-#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
-  status = gemm_bgradb_lt(
-      (cublasLtHandle_t)handle, CUBLAS_OP_N, CUBLAS_OP_T,
-      in_features, out_features, batch_size,
-      &alpha, /* host pointer */
-      input, in_features,
-      d_output, out_features,
-      &beta_zero, /* host pointer */
-      d_weight, in_features,
-      lt_workspace, 1 << 22, stream, true,
-      static_cast<const void*>(d_bias));
-#endif
-  if (status != 0){
-    status = gemm_bias(
-        handle, CUBLAS_OP_N, CUBLAS_OP_T, in_features, out_features, batch_size,
-        &alpha, input, in_features, d_output, out_features,
-        &beta_zero, d_weight, in_features);
-  }
-  status = gemm_bias(
-      handle, CUBLAS_OP_N, CUBLAS_OP_N, in_features, batch_size, out_features,
-      &alpha, weight, in_features, d_output, out_features,
-      &beta, d_input, in_features);
-  return status;
-}
-
 template <typename T>
 int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace) {
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();

@@ -1162,13 +816,10 @@ int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_siz
       lt_workspace,
       1 << 22,
       stream,
-      true,
       static_cast<const void*>(d_bias));
 #endif
   if (status != 0){
     status = gemm_bias(
         handle,
         CUBLAS_OP_N,
@@ -1217,7 +868,6 @@ int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int
       lt_workspace,
       1 << 22,
       stream,
-      true,
       heuristic,
       static_cast<const void*>(gelu_in),
       static_cast<const void*>(bias));

@@ -1228,109 +878,46 @@ int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int
 }

 template <typename T>
-int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, bool residual, void *lt_workspace) {
+int bias_gelu_linear_dgrad_bgrad_cuda(T *weight, T *d_output, T *gelu_in, int in_features, int batch_size, int out_features, int heuristic, T *d_input, T *d_bias, void *lt_workspace) {
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   // Get the stream from cublas handle to reuse for biasReLU kernel.
   cudaStream_t stream;
   cublasGetStream(handle, &stream);
   const float alpha = 1.0;
   const float beta_zero = 0.0;
-  const float beta = residual ? 1.0 : 0.0;
   int status = 1;
 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
-  //wgrad for first gemm
-  status = gemm_bgradb_lt(
-      (cublasLtHandle_t)handle, CUBLAS_OP_N, CUBLAS_OP_T,
-      hidden_features, out_features, batch_size,
-      &alpha, /* host pointer */
-      output1, hidden_features,
-      d_output2, out_features,
-      &beta_zero, /* host pointer */
-      d_weight2, hidden_features,
-      lt_workspace, 1 << 22, stream, true,
-      static_cast<const void*>(d_bias2));
-
-  //dgrad for second GEMM
   status = gemm_dgelu_bgradb_lt(
       (cublasLtHandle_t)handle,
       CUBLAS_OP_N,
       CUBLAS_OP_N,
-      hidden_features,
+      in_features,
       batch_size,
      out_features,
       &alpha, /* host pointer */
-      weight2,
-      hidden_features,
-      d_output2,
+      weight,
+      in_features,
+      d_output,
       out_features,
       &beta_zero, /* host pointer */
-      d_output1,
-      hidden_features,
+      d_input,
+      in_features,
       lt_workspace,
       1 << 22,
       stream,
       heuristic,
       static_cast<const void*>(gelu_in),
-      static_cast<const void*>(d_bias1));
-
-  //wgrad for the first GEMM
-  status = gemm_bias(
-      handle, CUBLAS_OP_N, CUBLAS_OP_T, in_features, hidden_features, batch_size,
-      &alpha, input, in_features, d_output1, hidden_features,
-      &beta_zero, d_weight1, in_features);
-
-  //dgrad for the first GEMM
-  status = gemm_bias(
-      handle, CUBLAS_OP_N, CUBLAS_OP_N, in_features, batch_size, hidden_features,
-      &alpha, weight1, in_features, d_output1, hidden_features,
-      &beta, d_input, in_features);
+      static_cast<const void*>(d_bias));
 #endif
   return status;
 }

-template int linear_bias_forward_cuda<at::Half>(at::Tensor input, at::Half *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
-template int linear_bias_forward_cuda<at::BFloat16>(at::Tensor input, at::BFloat16 *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
-
-template int linear_bias_backward_cuda<at::Half>(at::Half *input, at::Half *weight, at::Half *d_output, int in_features, int batch_size, int out_features, at::Half *d_weight, at::Half *d_bias, at::Half *d_input, bool residual, void *lt_workspace);
-template int linear_bias_backward_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *weight, at::BFloat16 *d_output, int in_features, int batch_size, int out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias, at::BFloat16 *d_input, bool residual, void *lt_workspace);
-
 template int linear_bias_wgrad_cuda<at::Half>(at::Half *input, at::Half *d_output, int in_features, int batch_size, int out_features, at::Half *d_weight, at::Half *d_bias, void *lt_workspace);
 template int linear_bias_wgrad_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, int in_features, int batch_size, int out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias, void *lt_workspace);

 template int linear_gelu_forward_cuda<at::Half>(at::Half *input, at::Half *weight, at::Half *bias, int in_features, int batch_size, int out_features, int heuristic, at::Half *output, at::Half *gelu_in, void *lt_workspace);
 template int linear_gelu_forward_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *weight, at::BFloat16 *bias, int in_features, int batch_size, int out_features, int heuristic, at::BFloat16 *output, at::BFloat16 *gelu_in, void *lt_workspace);

-template int linear_gelu_linear_backward_cuda<at::Half>(at::Half *input, at::Half *gelu_in, at::Half *output1, at::Half *weight1, at::Half *weight2, at::Half *d_output1, at::Half *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, at::Half *d_weight1, at::Half *d_weight2, at::Half *d_bias1, at::Half *d_bias2, at::Half *d_input, bool residual, void *lt_workspace);
-template int linear_gelu_linear_backward_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *gelu_in, at::BFloat16 *output1, at::BFloat16 *weight1, at::BFloat16 *weight2, at::BFloat16 *d_output1, at::BFloat16 *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, at::BFloat16 *d_weight1, at::BFloat16 *d_weight2, at::BFloat16 *d_bias1, at::BFloat16 *d_bias2, at::BFloat16 *d_input, bool residual, void *lt_workspace);
+template int bias_gelu_linear_dgrad_bgrad_cuda<at::Half>(at::Half *weight, at::Half *d_output, at::Half *gelu_in, int in_features, int batch_size, int out_features, int heuristic, at::Half *d_input, at::Half *d_bias, void *lt_workspace);
+template int bias_gelu_linear_dgrad_bgrad_cuda<at::BFloat16>(at::BFloat16 *weight, at::BFloat16 *d_output, at::BFloat16 *gelu_in, int in_features, int batch_size, int out_features, int heuristic, at::BFloat16 *d_input, at::BFloat16 *d_bias, void *lt_workspace);
\ No newline at end of file
flash_attn/layers/patch_embed.py  (view file @ e68ebbe8)

@@ -10,9 +10,9 @@ from torch.nn.modules.utils import _pair
 from einops import rearrange

 try:
-    from flash_attn.ops.fused_dense import FusedDenseTD
+    from flash_attn.ops.fused_dense import FusedDense
 except ImportError:
-    FusedDenseTD = None
+    FusedDense = None


 class PatchEmbed(nn.Module):

@@ -37,10 +37,10 @@ class PatchEmbed(nn.Module):
         self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
         self.num_patches = self.grid_size[0] * self.grid_size[1]
         self.flatten = flatten
-        if fused_bias_fc and FusedDenseTD is None:
+        if fused_bias_fc and FusedDense is None:
             raise ImportError('fused_dense is not installed')

-        linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDenseTD
+        linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense
         self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias)
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
flash_attn/models/bert.py  (view file @ e68ebbe8)

@@ -30,9 +30,9 @@ from flash_attn.bert_padding import unpad_input, pad_input
 from flash_attn.bert_padding import index_first_axis, index_first_axis_residual

 try:
-    from flash_attn.ops.fused_dense import FusedDenseTD
+    from flash_attn.ops.fused_dense import FusedDense
 except ImportError:
-    FusedDenseTD = None
+    FusedDense = None

 try:
     from flash_attn.ops.layer_norm import dropout_add_layer_norm, layer_norm

@@ -70,6 +70,8 @@ def create_mlp_cls(config, layer_idx=None, return_residual=False):
                       activation=partial(F.gelu, approximate=approximate),
                       return_residual=return_residual)
     else:
+        if FusedDenseGeluDense is None:
+            raise ImportError('fused_dense is not installed')
         mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
         # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
         if isinstance(mlp_checkpoint_lvl, Sequence):

@@ -168,9 +170,9 @@ class BertPooler(nn.Module):
     def __init__(self, config):
         super().__init__()
         fused_bias_fc = getattr(config, 'fused_bias_fc', False)
-        if fused_bias_fc and FusedDenseTD is None:
+        if fused_bias_fc and FusedDense is None:
             raise ImportError('fused_dense is not installed')
-        linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
+        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         self.dense = linear_cls(config.hidden_size, config.hidden_size)
         self.activation = nn.Tanh()

@@ -188,12 +190,12 @@ class BertPredictionHeadTransform(nn.Module):
     def __init__(self, config):
         super().__init__()
         fused_bias_fc = getattr(config, 'fused_bias_fc', False)
-        if fused_bias_fc and FusedDenseTD is None:
+        if fused_bias_fc and FusedDense is None:
             raise ImportError('fused_dense is not installed')
         self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
         if self.fused_dropout_add_ln and layer_norm is None:
             raise ImportError('dropout_add_layer_norm is not installed')
-        linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
+        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         self.dense = linear_cls(config.hidden_size, config.hidden_size)
         approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
         self.transform_act_fn = nn.GELU(approximate=approximate)

@@ -215,9 +217,9 @@ class BertLMPredictionHead(nn.Module):
     def __init__(self, config):
         super().__init__()
         fused_bias_fc = getattr(config, 'fused_bias_fc', False)
-        if fused_bias_fc and FusedDenseTD is None:
+        if fused_bias_fc and FusedDense is None:
             raise ImportError('fused_dense is not installed')
-        linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
+        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         self.transform = BertPredictionHeadTransform(config)
flash_attn/models/gpt.py  (view file @ e68ebbe8)

@@ -61,6 +61,8 @@ def create_mlp_cls(config, layer_idx=None):
             assert layer_idx is not None
             mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
         if fused_dense_gelu_dense:
+            if FusedDenseGeluDense is None:
+                raise ImportError('fused_dense is not installed')
             mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim,
                               checkpoint_lvl=mlp_checkpoint_lvl)
         elif fused_dense_sqrelu_dense:
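gpt.py now guards the fused MLP path and keeps building it as partial(FusedDenseGeluDense, hidden_features=inner_dim, checkpoint_lvl=mlp_checkpoint_lvl). As a rough orientation, a minimal sketch of instantiating that module directly is below; the sizes and dtype are illustrative, and it assumes the relocated class keeps the constructor documented in the mlp.py docstring further down this page, which this commit does not show directly.

import torch
from flash_attn.ops.fused_dense import FusedDenseGeluDense

# assumed constructor, mirroring the docstring that used to live in flash_attn/modules/mlp.py
mlp = FusedDenseGeluDense(1024, hidden_features=4096, checkpoint_lvl=0, heuristic=0,
                          device='cuda', dtype=torch.float16)
x = torch.randn(8, 1024, device='cuda', dtype=torch.float16)
y = mlp(x)  # fused fc1 -> GELU -> fc2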
flash_attn/modules/mha.py
...
@@ -21,9 +21,9 @@ except ImportError:
     flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None

 try:
-    from flash_attn.ops.fused_dense import FusedDenseTD, FusedDenseResidual
+    from flash_attn.ops.fused_dense import FusedDense
 except ImportError:
-    FusedDenseTD, FusedDenseResidual = None, None
+    FusedDense = None

 try:
     from flash_attn.layers.rotary import RotaryEmbedding
...
@@ -270,7 +270,7 @@ class CrossAttention(nn.Module):
 class LinearResidual(nn.Linear):
-    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDenseResidual.
+    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.
     """

     def forward(self, input: torch.Tensor) -> torch.Tensor:
...
@@ -311,10 +311,11 @@ class MHA(nn.Module):
             assert RotaryEmbedding is not None, 'rotary_emb is not installed'
             self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base)

-        if fused_bias_fc and FusedDenseTD is None:
+        if fused_bias_fc and FusedDense is None:
             raise ImportError('fused_dense is not installed')
-        linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
-        linear_resid_cls = LinearResidual if not fused_bias_fc else FusedDenseResidual
+        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
+        linear_resid_cls = (LinearResidual if not fused_bias_fc
+                            else partial(FusedDense, return_residual=True))
         if not self.cross_attn:
             if not self.return_residual:
                 self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
...
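For reference, a minimal pure-PyTorch sketch of the return_residual convention used above. The forward body here is one plausible implementation consistent with the LinearResidual docstring, and partial(FusedDense, return_residual=True) is assumed to return the same (output, input) pair when the fused extension is available; the layer sizes are arbitrary.

import torch
import torch.nn as nn

class LinearResidualSketch(nn.Linear):
    """Return the input alongside the output, like LinearResidual above."""
    def forward(self, input: torch.Tensor):
        return super().forward(input), input

layer = LinearResidualSketch(256, 3 * 256)
x = torch.randn(2, 16, 256)
qkv, residual = layer(x)
assert residual is x and qkv.shape == (2, 16, 768)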
flash_attn/modules/mlp.py
...
@@ -5,11 +5,9 @@ import torch.nn as nn
 import torch.nn.functional as F

 try:
-    from flash_attn.ops.fused_dense import fused_dense_gelu_dense_function_td
-    from flash_attn.ops.fused_dense import fused_dense_res_gelu_dense_function_td
+    from flash_attn.ops.fused_dense import FusedDenseGeluDense
 except ImportError:
-    fused_dense_gelu_dense_function_td = None
-    fused_dense_res_gelu_dense_function_td = None
+    FusedDenseGeluDense = None


 class Mlp(nn.Module):
...
@@ -30,43 +28,3 @@ class Mlp(nn.Module):
         y = self.activation(y)
         y = self.fc2(y)
         return y if not self.return_residual else (y, x)
-
-
- class FusedDenseGeluDense(nn.Module):
-
-     def __init__(self, in_features, hidden_features=None, out_features=None, bias=True,
-                  checkpoint_lvl=0, heuristic=0, return_residual=False, device=None, dtype=None):
-         """
-         checkpoint_lvl (increasing lvl means slower but more memory saving):
-             0: no recomputation in the bwd
-             1: recompute gelu_out in the bwd
-             2: recompute gelu_in and gelu_out in the bwd
-         heuristic:
-             -1: don't fuse gemm + gelu (separate kernel)
-             0..4: use this heuristic for the algo section in the fused gemm + gelu
-             For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
-             For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
-         return_residual: whether to return the input x along with the output. This is for
-             performance reason: for post-norm architecture, returning the input allows us
-             to fuse the backward of nn.Linear with the residual connection.
-         """
-         assert checkpoint_lvl in [0, 1, 2]
-         factory_kwargs = {'device': device, 'dtype': dtype}
-         super().__init__()
-         out_features = out_features or in_features
-         hidden_features = hidden_features or in_features
-         assert bias == True, "DenseGeluDense module without bias is currently not supported"
-         assert (fused_dense_gelu_dense_function_td is not None
-                 and fused_dense_res_gelu_dense_function_td is not None), 'fused_dense_lib is not installed'
-         self.checkpoint_lvl = checkpoint_lvl
-         self.heuristic = heuristic
-         self.return_residual = return_residual
-         self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, **factory_kwargs)
-         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
-
-     def forward(self, x):
-         assert x.is_cuda
-         fn = (fused_dense_gelu_dense_function_td if not self.return_residual
-               else fused_dense_res_gelu_dense_function_td)
-         return fn(x, self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias,
-                   self.checkpoint_lvl, self.heuristic)
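For reference, the unfused computation that the fused module performs can be sketched with plain nn.Linear layers; this mirrors the reference path used in tests/ops/test_fused_dense.py further down (the layer sizes here are arbitrary).

import torch
import torch.nn as nn
import torch.nn.functional as F

fc1 = nn.Linear(1024, 4096)
fc2 = nn.Linear(4096, 1024)
x = torch.randn(8, 512, 1024)
y = fc2(F.gelu(fc1(x), approximate='tanh'))  # what FusedDenseGeluDense fuses into fewer kernels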
flash_attn/ops/fused_dense.py

 # Adapted from https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
 # We make it work with pytorch amp and with bfloat16.
+from typing import Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
 # import fused_dense_cuda  # from apex
...
@@ -11,126 +13,84 @@ import fused_dense_lib as fused_dense_cuda
 from flash_attn.ops.gelu_activation import gelu_bwd


-# implements fused GEMM+bias in forward pass using mlp_cuda from apex
-class FusedDenseFuncTD(torch.autograd.Function):
+class FusedDenseFunc(torch.autograd.Function):

     @staticmethod
     @custom_fwd
-    def forward(ctx, x, weight, bias):
+    def forward(ctx, x, weight, bias, return_residual=False):
         if torch.is_autocast_enabled():
             dtype = torch.get_autocast_gpu_dtype()
-            x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
+            x, weight = [a.to(dtype=dtype) for a in [x, weight]]
+            bias = bias.to(dtype=dtype) if bias is not None else None
+        ctx.return_residual = return_residual
         x = x.contiguous()
         weight = weight.contiguous()
-        bias = bias.contiguous()
         ctx.save_for_backward(x, weight)
         batch_shape, n = x.shape[:-1], x.shape[-1]
         batch_dim = batch_shape.numel()
         assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
-        output = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight, bias)
-        return output.reshape(*batch_shape, output.shape[-1])
+        output = F.linear(x, weight, bias)
+        return output if not return_residual else (output, x)

     @staticmethod
     @custom_bwd
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, *args):
         grad_output = grad_output.contiguous()
+        if ctx.return_residual:
+            grad_input, = args
+            grad_input = grad_input.contiguous()
         x, weight = ctx.saved_tensors
         batch_shape, n = x.shape[:-1], x.shape[-1]
         batch_dim = batch_shape.numel()
-        if ctx.needs_input_grad[0]:
-            grad_input, grad_weight, grad_bias = fused_dense_cuda.linear_bias_backward(
-                x.reshape(batch_dim, n), weight, grad_output.reshape(batch_dim, grad_output.shape[-1])
-            )
-            grad_input = grad_input.reshape_as(x)
-        else:
-            grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
-                x.reshape(batch_dim, n), grad_output.reshape(batch_dim, grad_output.shape[-1])
-            )
-            grad_input = None
-        # print((grad_bias - grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)).abs().max())
-        return grad_input, grad_weight, grad_bias
-        # grad_input, grad_weight = None, None
-        # grad_output_reshaped = grad_output.reshape(batch_dim, grad_output.shape[-1])
-        # if ctx.needs_input_grad[0]:
-        #     grad_input = (grad_output_reshaped @ weight.conj()).reshape(*batch_shape, n)
-        # if ctx.needs_input_grad[1]:
-        #     grad_weight = grad_output_reshaped.t() @ x.conj().reshape(batch_dim, n)
-        # # We don't need to compute grad_bias explicitly, when we return grad_out Pytorch
-        # # will sum over the batch dimension to get grad_bias.
-        # return grad_input, grad_weight, grad_output
+        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
+        if ctx.needs_input_grad[1]:
+            grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
+                x.reshape(batch_dim, n), grad_output, ctx.needs_input_grad[2]
+            )
+        else:
+            grad_weight = None
+            grad_bias = grad_output if ctx.needs_input_grad[2] else None
+        if ctx.needs_input_grad[0]:
+            if not ctx.return_residual:
+                grad_input = F.linear(grad_output, weight.t())
+            else:
+                grad_input = torch.addmm(grad_input.reshape(batch_dim, n), grad_output, weight)
+            grad_input = grad_input.reshape_as(x)
+        else:
+            grad_input = None
+        return grad_input, grad_weight, grad_bias, None
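The linear_bias_wgrad call above returns the standard linear-layer weight and bias gradients. A hedged pure-PyTorch sketch of what the fused kernel is expected to compute (the real implementation lives in the fused_dense_lib CUDA extension):

import torch
import torch.nn.functional as F

def linear_bias_wgrad_ref(x, grad_output, has_d_bias=True):
    # x: (batch_dim, in_features), grad_output: (batch_dim, out_features)
    grad_weight = grad_output.t() @ x                            # (out_features, in_features)
    grad_bias = grad_output.sum(dim=0) if has_d_bias else None   # (out_features,)
    return grad_weight, grad_bias

x = torch.randn(32, 64, dtype=torch.float64)
dout = torch.randn(32, 16, dtype=torch.float64)
w = torch.randn(16, 64, dtype=torch.float64, requires_grad=True)
b = torch.randn(16, dtype=torch.float64, requires_grad=True)
F.linear(x, w, b).backward(dout)
dw, db = linear_bias_wgrad_ref(x, dout)
assert torch.allclose(dw, w.grad) and torch.allclose(db, b.grad)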
-fused_dense_function_td = FusedDenseFuncTD.apply
+def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
+                     return_residual: bool = False):
+    batch_dim = x.shape[:-1].numel()
+    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
+                      or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
+    if (x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda)
+            and batch_dim <= 64 * 1024 and dtype_eligible):
+        return FusedDenseFunc.apply(x, weight, bias, return_residual)
+    else:
+        out = F.linear(x, weight, bias)
+        return out if not return_residual else (out, x)


-class FusedDenseTD(nn.Linear):
+class FusedDense(nn.Linear):

-    def __init__(self, in_features: int, out_features: int, bias: bool = True,
-                 device=None, dtype=None) -> None:
+    def __init__(self, in_features: int, out_features: int, bias: bool = True,
+                 return_residual: bool = False, device=None, dtype=None) -> None:
         super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
+        self.return_residual = return_residual

     def forward(self, x):
-        if x.is_cuda and self.bias is not None:
-            return fused_dense_function_td(x, self.weight, self.bias)
-        else:
-            return F.linear(x, self.weight, self.bias)
+        return fused_dense_func(x, self.weight, self.bias, return_residual=self.return_residual)
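A hedged usage sketch of the renamed module (it assumes a CUDA device and that the fused_dense_lib extension is built; the tolerances follow the test file below). FusedDense behaves like nn.Linear, and with return_residual=True it additionally returns the input so the residual add can be fused into the backward.

import torch
import torch.nn.functional as F
from flash_attn.ops.fused_dense import FusedDense

layer = FusedDense(1024, 1024, return_residual=True, device='cuda', dtype=torch.float16)
x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)
out, residual = layer(x)  # residual is the input tensor
ref = F.linear(x, layer.weight, layer.bias)
assert torch.allclose(out, ref, rtol=3e-3, atol=1e-3)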
-class FusedDenseResidualFunc(torch.autograd.Function):
-
-    @staticmethod
-    @custom_fwd
-    def forward(ctx, x, weight, bias):
-        if torch.is_autocast_enabled():
-            dtype = torch.get_autocast_gpu_dtype()
-            x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
-        x = x.contiguous()
-        weight = weight.contiguous()
-        bias = bias.contiguous()
-        ctx.save_for_backward(x, weight)
-        batch_shape, n = x.shape[:-1], x.shape[-1]
-        batch_dim = batch_shape.numel()
-        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
-        output = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight, bias)
-        return output.reshape(*batch_shape, output.shape[-1]), x
-
-    @staticmethod
-    @custom_bwd
-    def backward(ctx, grad_output, grad_input):
-        grad_output = grad_output.contiguous()
-        grad_input = grad_input.contiguous()
-        x, weight = ctx.saved_tensors
-        batch_shape, n = x.shape[:-1], x.shape[-1]
-        batch_dim = batch_shape.numel()
-        grad_input, grad_weight, grad_bias = fused_dense_cuda.linear_bias_residual_backward(
-            x.reshape(batch_dim, n), weight, grad_output.reshape(batch_dim, grad_output.shape[-1]),
-            grad_input.reshape(batch_dim, n)
-        )
-        return grad_input.reshape_as(x), grad_weight, grad_bias
-
-
-fused_dense_residual_function = FusedDenseResidualFunc.apply
-
-
-class FusedDenseResidual(nn.Linear):
-    """Similar to FusedDense, but we return both the output and the input.
-    This is so that in the backward pass, we can combine the input gradient from the residual branch
-    with the input gradient from the matrix multiply, without having to do a separate addition.
-    """
-
-    def forward(self, x):
-        if x.is_cuda and self.bias is not None:
-            return fused_dense_residual_function(x, self.weight, self.bias)
-        else:
-            return F.linear(x, self.weight, self.bias), x
-
-
-class FusedDenseGeluDenseFuncTD(torch.autograd.Function):
+class FusedDenseGeluDenseFunc(torch.autograd.Function):

     @staticmethod
     @custom_fwd
-    def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0, heuristic=0):
+    def forward(ctx, x, weight1, bias1, weight2, bias2, save_gelu_in=True, return_residual=False,
+                checkpoint_lvl=0, heuristic=0):
         """checkpoint_lvl:
         0: no recomputation in the bwd
         1: recompute gelu_out in the bwd
...
@@ -139,49 +99,53 @@ class FusedDenseGeluDenseFuncTD(torch.autograd.Function):
         assert -1 <= heuristic <= 4
         if torch.is_autocast_enabled():
             dtype = torch.get_autocast_gpu_dtype()
-            x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype)
-                                                 for a in [x, weight1, bias1, weight2, bias2]]
+            x, weight1, weight2 = [a.to(dtype=dtype)
+                                   for a in [x, weight1, weight2]]
+            bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
+            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
+        if not save_gelu_in:
+            checkpoint_lvl = 2
         assert checkpoint_lvl in [0, 1, 2]
+        ctx.return_residual = return_residual
         x = x.contiguous()
         weight1 = weight1.contiguous()
-        bias1 = bias1.contiguous()
+        bias1 = bias1.contiguous() if bias1 is not None else None
         weight2 = weight2.contiguous()
-        bias2 = bias2.contiguous()
+        bias2 = bias2.contiguous() if bias2 is not None else None
         batch_shape, n = x.shape[:-1], x.shape[-1]
         batch_dim = batch_shape.numel()
         assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
-        # output1, output2, gelu_in = fused_dense_cuda.linear_gelu_linear_forward(
-        #     x.reshape(batch_dim, n), weight1, bias1, weight2, bias2
-        # )
         if heuristic == -1:
-            gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
+            gelu_in = F.linear(x, weight1, bias1)
             output1 = F.gelu(gelu_in, approximate='tanh')
             # gelu_in = F.linear(x.reshape(batch_dim, n), weight1)  # This is before adding bias1
             # with torch.jit.fuser('fuser2'):
             #     output1 = bias_gelu(gelu_in, bias1)
         else:
-            save_gelu_in = checkpoint_lvl != 2
             output1, *rest = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n), weight1,
                                                                   bias1, save_gelu_in, heuristic)
             if save_gelu_in:
                 gelu_in = rest[0]
-        output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2)
+        output2 = F.linear(output1, weight2, bias2)
         ctx.checkpoint_lvl = checkpoint_lvl
         ctx.heuristic = heuristic
         if checkpoint_lvl == 0:
-            ctx.save_for_backward(x, weight1, bias1, weight2, gelu_in, output1)
+            ctx.save_for_backward(x, weight1, weight2, gelu_in, output1)
         elif checkpoint_lvl == 1:
-            ctx.save_for_backward(x, weight1, bias1, weight2, gelu_in)
+            ctx.save_for_backward(x, weight1, weight2, gelu_in)
         elif checkpoint_lvl == 2:
-            ctx.save_for_backward(x, weight1, bias1, weight2)
-        return output2.reshape(*batch_shape, output2.shape[-1])
+            ctx.save_for_backward(x, weight1, weight2, bias1)
+        output2 = output2.reshape(*batch_shape, output2.shape[-1])
+        return output2 if not return_residual else (output2, x)

     @staticmethod
     @custom_bwd
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, *args):
         grad_output = grad_output.contiguous()
         checkpoint_lvl = ctx.checkpoint_lvl
+        if ctx.return_residual:
+            grad_input, = args
+            grad_input = grad_input.contiguous()
-        x, weight1, bias1, weight2, *rest = ctx.saved_tensors
+        x, weight1, weight2, *rest = ctx.saved_tensors
         batch_shape, n = x.shape[:-1], x.shape[-1]
         batch_dim = batch_shape.numel()
         if checkpoint_lvl == 0:
...
@@ -190,55 +154,88 @@ class FusedDenseGeluDenseFuncTD(torch.autograd.Function):
             gelu_in, = rest
             output1 = F.gelu(gelu_in, approximate='tanh')
         elif checkpoint_lvl == 2:
-            # bias1, = rest
+            bias1, = rest
             if ctx.heuristic == -1:
-                gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
+                gelu_in = F.linear(x, weight1, bias1)
                 output1 = F.gelu(gelu_in, approximate='tanh')
             else:
-                output1, gelu_in = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n),
-                                                                        weight1, bias1, True, ctx.heuristic)
+                output1, gelu_in = fused_dense_cuda.linear_gelu_forward(
+                    x.reshape(batch_dim, n), weight1, bias1, True, ctx.heuristic)
+        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
+        output1 = output1.reshape(batch_dim, output1.shape[-1])
+        gelu_in = gelu_in.reshape(batch_dim, gelu_in.shape[-1])
+        if ctx.needs_input_grad[3]:
+            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(
+                output1, grad_output, ctx.needs_input_grad[4]
+            )
+        else:
+            grad_weight2 = None
+            grad_bias2 = grad_output if ctx.needs_input_grad[4] else None
         if ctx.heuristic == -1:
-            grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
-            # grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output)
-            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
             # grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
-            grad_output1 = grad_output @ weight2
+            grad_output1 = F.linear(grad_output, weight2.t())
             with torch.jit.fuser('fuser2'):
                 grad_gelu = gelu_bwd(grad_output1, gelu_in)
-            grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
-                x.reshape(batch_dim, n), weight1, grad_gelu
-            )
-            # with torch.jit.fuser('fuser2'):
-            #     grad_gelu, grad_bias1 = bias_gelu_back(grad_output1, gelu_in, bias1)
-            #     grad_input = grad_gelu @ weight1
-            #     grad_weight1 = grad_gelu.reshape(batch_dim, -1).T @ x.reshape(batch_dim, n)
-            # grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
-            #     x.reshape(batch_dim, n), weight1, grad_gelu
-            # )
+            if ctx.needs_input_grad[1]:
+                grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
+                    x.reshape(batch_dim, n), grad_gelu, ctx.needs_input_grad[2]
+                )
+            else:
+                grad_weight1 = None
+                grad_bias1 = grad_gelu if ctx.needs_input_grad[2] else None
         else:
-            grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_gelu_linear_backward(
-                x.reshape(batch_dim, n), gelu_in, output1, weight1, weight2,
-                grad_output.reshape(batch_dim, grad_output.shape[-1]),
-                ctx.heuristic
-            )
-            # grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
-            # # grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output)
-            # grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
-            # grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
-            # grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
-            #     x.reshape(batch_dim, n), weight1, grad_gelu
-            # )
+            # The cublasLt epilogue has to compute both gelu grad and bias grad, we can't
+            # just compute gelu grad
+            grad_gelu, grad_bias1 = fused_dense_cuda.bias_gelu_linear_dgrad_bgrad(
+                weight2, grad_output, gelu_in, ctx.heuristic
+            )
+            if not ctx.needs_input_grad[2]:
+                grad_bias1 = None
+            if ctx.needs_input_grad[1]:
+                grad_weight1 = F.linear(grad_gelu.t(), x.reshape(batch_dim, n).t())
+            else:
+                grad_weight1 = None
-        return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None
+        if ctx.needs_input_grad[0]:
+            if not ctx.return_residual:
+                grad_input = F.linear(grad_gelu, weight1.t())
+            else:
+                grad_input = torch.addmm(grad_input.reshape(batch_dim, n), grad_gelu, weight1)
+            grad_input = grad_input.reshape_as(x)
+        else:
+            grad_input = None
+        return grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None, None, None


-fused_dense_gelu_dense_function_td = FusedDenseGeluDenseFuncTD.apply
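The gelu_bwd helper used above is the backward of the tanh-approximate GELU. Below is a standalone sketch of that derivative, checked against autograd; the actual version in flash_attn.ops.gelu_activation is a jit-fused kernel, so this is for illustration only.

import math
import torch
import torch.nn.functional as F

def gelu_tanh_bwd_ref(grad_output, x):
    # d/dx [0.5 * x * (1 + tanh(c * (x + 0.044715 * x^3)))],  c = sqrt(2/pi)
    c = math.sqrt(2.0 / math.pi)
    t = torch.tanh(c * (x + 0.044715 * x.pow(3)))
    du_dx = c * (1.0 + 3 * 0.044715 * x.pow(2))
    return grad_output * (0.5 * (1.0 + t) + 0.5 * x * (1.0 - t.pow(2)) * du_dx)

x = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
g = torch.randn(4, 8, dtype=torch.float64)
(ref,) = torch.autograd.grad(F.gelu(x, approximate='tanh'), x, g)
assert torch.allclose(gelu_tanh_bwd_ref(g, x), ref)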
+def fused_dense_gelu_dense_func(
+    x: Tensor, weight1: Tensor, weight2: Tensor, bias1: Optional[Tensor] = None,
+    bias2: Optional[Tensor] = None, save_gelu_in: bool = True, return_residual: bool = False,
+    checkpoint_lvl: int = 0, heuristic: int = 0
+):
+    batch_dim = x.shape[:-1].numel()
+    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
+                      or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
+    if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
+            and (bias2 is None or bias2.is_cuda) and batch_dim <= 64 * 1024 and dtype_eligible):
+        return FusedDenseGeluDenseFunc.apply(x, weight1, bias1, weight2, bias2,
+                                             save_gelu_in, return_residual,
+                                             checkpoint_lvl, heuristic)
+    else:
+        gelu_in = F.linear(x, weight1, bias1)
+        output1 = F.gelu(gelu_in, approximate='tanh')
+        output2 = F.linear(output1, weight2, bias2)
+        return output2 if not return_residual else (output2, x)


-class FusedDenseGeluDenseTD(nn.Module):
+class FusedDenseGeluDense(nn.Module):

-    def __init__(self, in_features, intermediate_features, out_features=None, bias=True,
-                 checkpoint_lvl=0, heuristic=0, device=None, dtype=None):
+    def __init__(self, in_features, hidden_features, out_features=None, bias1=True,
+                 bias2=True, return_residual=False, checkpoint_lvl=0, heuristic=0,
+                 device=None, dtype=None):
         """
         checkpoint_lvl (increasing lvl means slower but more memory saving):
             0: no recomputation in the bwd
...
@@ -247,110 +244,26 @@ class FusedDenseGeluDenseTD(nn.Module):
         heuristic:
             -1: don't fuse gemm + gelu (separate kernel)
             0..4: use this heuristic for the algo section in the fused gemm + gelu
+            For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
+            For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
+        return_residual: whether to return the input x along with the output. This is for
+            performance reason: for post-norm architecture, returning the input allows us
+            to fuse the backward of nn.Linear with the residual connection.
         """
         assert checkpoint_lvl in [0, 1, 2]
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         if out_features is None:
             out_features = in_features
-        assert bias == True, "DenseGeluDense module without bias is currently not supported"
+        self.return_residual = return_residual
         self.checkpoint_lvl = checkpoint_lvl
         self.heuristic = heuristic
-        self.fc1 = nn.Linear(in_features, intermediate_features, bias=bias, **factory_kwargs)
-        self.fc2 = nn.Linear(intermediate_features, out_features, bias=bias, **factory_kwargs)
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

     def forward(self, x):
-        return fused_dense_gelu_dense_function_td(x, self.fc1.weight, self.fc1.bias,
-                                                  self.fc2.weight, self.fc2.bias,
-                                                  self.checkpoint_lvl, self.heuristic)
+        return fused_dense_gelu_dense_func(
+            x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
+            save_gelu_in=self.training, return_residual=self.return_residual,
+            checkpoint_lvl=self.checkpoint_lvl, heuristic=self.heuristic
+        )
-
-
-class FusedDenseResGeluDenseFunc(torch.autograd.Function):
-
-    @staticmethod
-    @custom_fwd
-    def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0, heuristic=0):
-        """checkpoint_lvl:
-        0: no recomputation in the bwd
-        1: recompute gelu_out in the bwd
-        2: recompute gelu_in and gelu_out in the bwd
-        """
-        assert -1 <= heuristic <= 4
-        if torch.is_autocast_enabled():
-            dtype = torch.get_autocast_gpu_dtype()
-            x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype)
-                                                 for a in [x, weight1, bias1, weight2, bias2]]
-        assert checkpoint_lvl in [0, 1, 2]
-        x = x.contiguous()
-        weight1 = weight1.contiguous()
-        bias1 = bias1.contiguous()
-        weight2 = weight2.contiguous()
-        bias2 = bias2.contiguous()
-        batch_shape, n = x.shape[:-1], x.shape[-1]
-        batch_dim = batch_shape.numel()
-        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
-        # output1, output2, gelu_in = fused_dense_cuda.linear_gelu_linear_forward(
-        #     x.reshape(batch_dim, n), weight1, bias1, weight2, bias2
-        # )
-        # gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
-        # output1 = F.gelu(gelu_in, approximate='tanh')
-        save_gelu_in = checkpoint_lvl != 2
-        output1, *rest = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n), weight1,
-                                                              bias1, save_gelu_in, heuristic)
-        if save_gelu_in:
-            gelu_in = rest[0]
-        output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2)
-        ctx.checkpoint_lvl = checkpoint_lvl
-        ctx.heuristic = heuristic
-        if checkpoint_lvl == 0:
-            ctx.save_for_backward(x, weight1, weight2, gelu_in, output1)
-        elif checkpoint_lvl == 1:
-            ctx.save_for_backward(x, weight1, weight2, gelu_in)
-        elif checkpoint_lvl == 2:
-            ctx.save_for_backward(x, weight1, weight2, bias1)
-        return output2.reshape(*batch_shape, output2.shape[-1]), x
-
-    @staticmethod
-    @custom_bwd
-    def backward(ctx, grad_output, grad_input):
-        grad_output = grad_output.contiguous()
-        grad_input = grad_input.contiguous()
-        checkpoint_lvl = ctx.checkpoint_lvl
-        x, weight1, weight2, *rest = ctx.saved_tensors
-        batch_shape, n = x.shape[:-1], x.shape[-1]
-        batch_dim = batch_shape.numel()
-        if checkpoint_lvl == 0:
-            gelu_in, output1 = rest
-        elif checkpoint_lvl == 1:
-            gelu_in, = rest
-            output1 = F.gelu(gelu_in, approximate='tanh')
-        elif checkpoint_lvl == 2:
-            bias1, = rest
-            output1, gelu_in = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n),
-                                                                    weight1, bias1, True, ctx.heuristic)
-        grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_residual_gelu_linear_backward(
-            x.reshape(batch_dim, n), gelu_in, output1, weight1, weight2,
-            grad_output.reshape(batch_dim, grad_output.shape[-1]),
-            grad_input.reshape(batch_dim, n),
-            ctx.heuristic
-        )
-        # grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
-        # # grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output)
-        # grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
-        # grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
-        # grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_residual_backward(
-        #     x.reshape(batch_dim, n), weight1, grad_gelu,
-        #     grad_input.reshape(batch_dim, n)
-        # )
-        return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None
-
-
-fused_dense_res_gelu_dense_function_td = FusedDenseResGeluDenseFunc.apply
-
-
-class FusedDenseResGeluDense(FusedDenseGeluDenseTD):
-
-    def forward(self, x):
-        return fused_dense_res_gelu_dense_function_td(x, self.fc1.weight, self.fc1.bias,
-                                                      self.fc2.weight, self.fc2.bias,
-                                                      self.checkpoint_lvl, False, self.heuristic)
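A hedged usage sketch of the consolidated module (it assumes a CUDA device and the fused_dense_lib extension; hidden_features replaces the old intermediate_features argument, and heuristic=0 is the recommended setting for CUDA >= 11.8 per the docstring above):

import torch
from flash_attn.ops.fused_dense import FusedDenseGeluDense

mlp = FusedDenseGeluDense(1024, 4096, 1024, bias1=True, bias2=True, return_residual=False,
                          checkpoint_lvl=1, heuristic=0, device='cuda', dtype=torch.bfloat16)
x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.bfloat16)
y = mlp(x)  # numerically close to fc2(gelu(fc1(x), approximate='tanh'))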
tests/ops/test_fused_dense.py
...
@@ -6,29 +6,44 @@ import pytest
 from einops import rearrange

-from flash_attn.ops.fused_dense import FusedDenseTD, FusedDenseGeluDenseTD
-from flash_attn.ops.fused_dense import FusedDenseResidual, FusedDenseResGeluDense
+from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense


 @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize('return_residual', [False, True])
+@pytest.mark.parametrize('has_bias', [True, False])
 @pytest.mark.parametrize('out_features', [1024, 4096])
 @pytest.mark.parametrize('in_features', [1024, 4096])
-def test_fused_linear_bias(in_features, out_features, dtype):
+def test_fused_linear_bias(in_features, out_features, has_bias, return_residual, dtype):
     device = 'cuda'
     rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
     # set seed
     torch.random.manual_seed(0)
     batch_size = 8
     seqlen = 512
     x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype,
                        requires_grad=True)
     x = x_pt.detach().clone().requires_grad_()
-    model_pt = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
-    model = FusedDenseTD(in_features, out_features, device=device, dtype=dtype)
+    model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
+    model = FusedDense(in_features, out_features, bias=has_bias, return_residual=return_residual,
+                       device=device, dtype=dtype)
     with torch.no_grad():
         model.weight.copy_(model_pt.weight)
-        model.bias.copy_(model_pt.bias)
+        if has_bias:
+            model.bias.copy_(model_pt.bias)
     out_pt = model_pt(x_pt)
-    out = model(x)
+    if not return_residual:
+        out = model(x)
+    else:
+        out, x_copy = model(x)
+        x_copy = (x_copy[..., :out_features] if out_features < in_features
+                  else F.pad(x_copy, (0, out_features - in_features)))
+        x_pt_copy = (x_pt[..., :out_features] if out_features < in_features
+                     else F.pad(x_pt, (0, out_features - in_features)))
+        # Just add some random function of the residual
+        out_pt = out_pt + F.gelu(x_pt_copy)
+        out = out + F.gelu(x_copy)
     # with torch.no_grad():
     #     out_fl = F.linear(x_pt.float(), model.weight.float(), model.bias.float()).half()
     assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
...
@@ -40,66 +55,52 @@ def test_fused_linear_bias(in_features, out_features, dtype):
     assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
     # The error for d_weight and d_bias is quite a bit higher
     assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10)
-    assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)
+    if has_bias:
+        assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)
-
-
-@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize('out_features,in_features', [(1024, 1024), (4096, 4096)])
-def test_fused_linear_bias_residual(in_features, out_features, dtype):
-    device = 'cuda'
-    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
-    # set seed
-    torch.random.manual_seed(0)
-    batch_size = 8
-    seqlen = 512
-    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype,
-                       requires_grad=True)
-    x = x_pt.detach().clone().requires_grad_()
-    model_pt = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
-    model = FusedDenseResidual(in_features, out_features, device=device, dtype=dtype)
-    with torch.no_grad():
-        model.weight.copy_(model_pt.weight)
-        model.bias.copy_(model_pt.bias)
-    out_pt = model_pt(x_pt) + F.gelu(x_pt)  # Just add some random function of the residual x_pt
-    out, x_copy = model(x)
-    out = out + F.gelu(x_copy)
-    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol * 2)
-    # If we don't divide by batch_size, the gradient gets a bit too large.
-    g = torch.randn_like(out) / 32
-    out_pt.backward(g)
-    out.backward(g)
-    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
-    # The error for d_weight and d_bias is quite a bit higher
-    assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10)
-    assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)
 @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize('heuristic', [1, -1])
+@pytest.mark.parametrize('heuristic', [0, -1])
 @pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
+@pytest.mark.parametrize('return_residual', [False, True])
+@pytest.mark.parametrize('has_bias2', [True, False])
+@pytest.mark.parametrize('has_bias1', [True, False])
 @pytest.mark.parametrize('out_features', [1024, 4096])
 @pytest.mark.parametrize('in_features', [1024, 4096])
-def test_fused_dense_gelu_dense(in_features, out_features, checkpoint_lvl, heuristic, dtype):
+def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2, return_residual,
+                                checkpoint_lvl, heuristic, dtype):
     device = 'cuda'
-    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
+    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
     # set seed
     torch.random.manual_seed(0)
     batch_size = 8
     seqlen = 512
     x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype,
                        requires_grad=True)
     x = x_pt.detach().clone().requires_grad_()
-    model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
-    model_pt_fc2 = torch.nn.Linear(out_features, in_features, device=device, dtype=dtype)
-    model = FusedDenseGeluDenseTD(in_features, out_features, in_features,
-                                  checkpoint_lvl=checkpoint_lvl, heuristic=heuristic,
-                                  device=device, dtype=dtype)
+    model_pt_fc1 = torch.nn.Linear(in_features, out_features, bias=has_bias1, device=device,
+                                   dtype=dtype)
+    model_pt_fc2 = torch.nn.Linear(out_features, in_features, bias=has_bias2, device=device,
+                                   dtype=dtype)
+    model = FusedDenseGeluDense(in_features, out_features, in_features, bias1=has_bias1,
+                                bias2=has_bias2, return_residual=return_residual,
+                                checkpoint_lvl=checkpoint_lvl, heuristic=heuristic,
+                                device=device, dtype=dtype)
     with torch.no_grad():
         model.fc1.weight.copy_(model_pt_fc1.weight)
-        model.fc1.bias.copy_(model_pt_fc1.bias)
+        if has_bias1:
+            model.fc1.bias.copy_(model_pt_fc1.bias)
         model.fc2.weight.copy_(model_pt_fc2.weight)
-        model.fc2.bias.copy_(model_pt_fc2.bias)
+        if has_bias2:
+            model.fc2.bias.copy_(model_pt_fc2.bias)
     out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh'))
-    out = model(x)
+    if not return_residual:
+        out = model(x)
+    else:
+        out, x_copy = model(x)
+        # Just add some random function of the residual
+        out_pt = out_pt + F.gelu(x_pt)
+        out = out + F.gelu(x_copy)
     assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
     # If we don't divide by batch_size, the gradient gets a bit too large.
...
@@ -109,46 +110,8 @@ def test_fused_dense_gelu_dense(in_features, out_features, checkpoint_lvl, heuri
     assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
     # The error for d_weight and d_bias is quite a bit higher
     assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10)
-    assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
+    if has_bias1:
+        assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
     assert torch.allclose(model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10)
-    assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
+    if has_bias2:
+        assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
-
-
-@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
-@pytest.mark.parametrize('out_features', [1024, 4096])
-@pytest.mark.parametrize('in_features', [1024, 4096])
-def test_fused_dense_residual_gelu_dense(in_features, out_features, checkpoint_lvl, dtype):
-    device = 'cuda'
-    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
-    # set seed
-    torch.random.manual_seed(0)
-    batch_size = 8
-    seqlen = 512
-    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype,
-                       requires_grad=True)
-    x = x_pt.detach().clone().requires_grad_()
-    model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
-    model_pt_fc2 = torch.nn.Linear(out_features, in_features, device=device, dtype=dtype)
-    model = FusedDenseResGeluDense(in_features, out_features, in_features,
-                                   checkpoint_lvl=checkpoint_lvl, device=device, dtype=dtype)
-    with torch.no_grad():
-        model.fc1.weight.copy_(model_pt_fc1.weight)
-        model.fc1.bias.copy_(model_pt_fc1.bias)
-        model.fc2.weight.copy_(model_pt_fc2.weight)
-        model.fc2.bias.copy_(model_pt_fc2.bias)
-    out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh')) + F.gelu(x_pt)
-    out, x_copy = model(x)
-    out = out + F.gelu(x_copy)
-    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol * 2)
-    # If we don't divide by batch_size, the gradient gets a bit too large.
-    g = torch.randn_like(out) / 32
-    out_pt.backward(g)
-    out.backward(g)
-    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
-    # The error for d_weight and d_bias is quite a bit higher
-    assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10)
-    assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
-    assert torch.allclose(model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10)
-    assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)