gaoqiong / flash-attention / Commits

Commit e68ebbe8 authored Dec 22, 2022 by Tri Dao

Simplify FusedDense

parent 1bc6e5b0
Showing 9 changed files with 318 additions and 1063 deletions.
csrc/fused_dense_lib/fused_dense.cpp      +72   -243
csrc/fused_dense_lib/fused_dense_cuda.cu  +18   -431
flash_attn/layers/patch_embed.py          +4    -4
flash_attn/models/bert.py                 +10   -8
flash_attn/models/gpt.py                  +2    -0
flash_attn/modules/mha.py                 +7    -6
flash_attn/modules/mlp.py                 +2    -44
flash_attn/ops/fused_dense.py             +146  -233
tests/ops/test_fused_dense.py             +57   -94
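The thrust of the commit, visible across the per-file diffs below, is that the old FusedDenseTD / FusedDenseResidual / FusedDenseGeluDenseTD / FusedDenseResGeluDense variants collapse into two modules, FusedDense and FusedDenseGeluDense, with optional biases and a return_residual flag. A minimal usage sketch of the post-commit API (class and parameter names are taken from the diffs below; the defaults and exact shapes are an assumption):

# Sketch of the simplified flash_attn.ops.fused_dense API after this commit.
# Names (FusedDense, FusedDenseGeluDense, bias, bias1, bias2, checkpoint_lvl,
# heuristic) come from the diffs below; defaults are assumed.
import torch
from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense

x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)

# Drop-in replacement for nn.Linear; bias is now optional.
dense = FusedDense(1024, 1024, bias=True, device='cuda', dtype=torch.float16)
y = dense(x)

# Fused Linear -> GELU -> Linear MLP; biases are configurable per layer.
mlp = FusedDenseGeluDense(1024, 4096, 1024, bias1=True, bias2=True,
                          checkpoint_lvl=0, heuristic=0,
                          device='cuda', dtype=torch.float16)
z = mlp(x)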
csrc/fused_dense_lib/fused_dense.cpp (view file @ e68ebbe8)
@@ -6,6 +6,8 @@
 #include <stdio.h>
 
+#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
+
 // https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
 // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 #define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...) \
@@ -24,14 +26,6 @@
     AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
   }
 
-#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
-
-template <typename T>
-int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
-
-template <typename T>
-int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input, bool residual, void *lt_workspace);
-
 template <typename T>
 int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace);
@@ -39,103 +33,34 @@ template <typename T>
 int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int batch_size, int out_features, int heuristic, T *output, T *gelu_in, void *lt_workspace);
 
 template <typename T>
-int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, bool residual, void *lt_workspace);
+int bias_gelu_linear_dgrad_bgrad_cuda(T *weight, T *d_output, T *gelu_in, int in_features, int batch_size, int out_features, int heuristic, T *d_input, T *d_bias, void *lt_workspace);
 
-at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor bias) {
-  auto batch_size = input.size(0);
-  auto in_features = input.size(1);
-  int out_features = weight.size(0);
-  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
-  // create output/workspace tensor
-  auto out = at::empty({batch_size, out_features}, at::dtype(input.dtype()).device(input.device()));
-  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
-  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
-  auto lt_workspace = at::empty({1 << 22}, at::dtype(input.dtype()).device(input.device()));
-  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_forward", [&] {
-    scalar_t* w_ptr = weight.data_ptr<scalar_t>();
-    auto result = linear_bias_forward_cuda<scalar_t>(
-        input, w_ptr, bias, in_features, batch_size, out_features, out,
-        //out.data_ptr<scalar_t>(),
-        // reserved_space.data_ptr<scalar_t>(),
-        (void*) (lt_workspace.data_ptr<scalar_t>()));
-    TORCH_CHECK(result == 0, "linear_bias_forward failed.")
-  });
-  return {out};
-}
-
-std::vector<at::Tensor> linear_bias_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output) {
-  auto batch_size = input.size(0);
-  auto in_features = input.size(1);
-  int out_features = weight.size(0);
-  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
-  // create output/workspace tensor
-  auto opts = input.options();
-  auto d_weight = at::empty({out_features, in_features}, opts);
-#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
-  auto d_bias = d_output.view({-1, out_features}).sum(0, false);
-#else
-  auto d_bias = at::empty({out_features}, opts);
-#endif
-  auto d_input = at::empty({batch_size, in_features}, opts);
-  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
-  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
-  auto lt_workspace = at::empty({1 << 22}, opts);
-  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
-    scalar_t* w_ptr = weight.data_ptr<scalar_t>();
-    auto result = linear_bias_backward_cuda<scalar_t>(
-        input.data_ptr<scalar_t>(), w_ptr, d_output.data_ptr<scalar_t>(),
-        in_features, batch_size, out_features,
-        d_weight.data_ptr<scalar_t>(), d_bias.data_ptr<scalar_t>(), d_input.data_ptr<scalar_t>(),
-        // reserved_space.data_ptr<scalar_t>(),
-        /*residual=*/false,
-        (void*) (lt_workspace.data_ptr<scalar_t>()));
-    TORCH_CHECK(result == 0, "linear_bias_backward failed.")
-  });
-  return {d_input, d_weight, d_bias};
-}
-
-std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output) {
-  auto batch_size = input.size(0);
-  auto in_features = input.size(1);
-  int out_features = d_output.size(1);
-  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
+std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output, bool has_d_bias) {
+  int batch_size = input.size(0);
+  int in_features = input.size(1);
+  int out_features = d_output.size(1);
+  TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
+  TORCH_CHECK(input.dtype() == d_output.dtype());
+  TORCH_CHECK(input.is_cuda());
+  TORCH_CHECK(d_output.is_cuda());
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(d_output.is_contiguous());
+  CHECK_SHAPE(input, batch_size, in_features);
+  CHECK_SHAPE(d_output, batch_size, out_features);
   // create output/workspace tensor
   auto opts = input.options();
   auto d_weight = at::empty({out_features, in_features}, opts);
+  at::Tensor d_bias;
+  if (has_d_bias) {
 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
-  auto d_bias = d_output.view({-1, out_features}).sum(0, false);
+    d_bias = d_output.view({-1, out_features}).sum(0, false);
 #else
-  auto d_bias = at::empty({out_features}, opts);
+    d_bias = at::empty({out_features}, opts);
 #endif
-  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
+  }
   // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
   auto lt_workspace = at::empty({1 << 22}, opts);
@@ -147,93 +72,59 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output)
         batch_size,
         out_features,
         d_weight.data_ptr<scalar_t>(),
-        d_bias.data_ptr<scalar_t>(),
-        // reserved_space.data_ptr<scalar_t>(),
+        has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr,
         (void*) (lt_workspace.data_ptr<scalar_t>()));
-    TORCH_CHECK(result == 0, "linear_bias_wgrad failed.")
+    TORCH_CHECK(result == 0, "linear_bias_wgrad failed.");
   });
   return {d_weight, d_bias};
 }
 
-std::vector<at::Tensor> linear_bias_residual_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output, at::Tensor d_input) {
-  auto batch_size = input.size(0);
-  auto in_features = input.size(1);
-  int out_features = weight.size(0);
-  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
-  // create output/workspace tensor
-  auto opts = input.options();
-  auto d_weight = at::empty({out_features, in_features}, opts);
-#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
-  auto d_bias = d_output.view({-1, out_features}).sum(0, false);
-#else
-  auto d_bias = at::empty({out_features}, opts);
-#endif
-  CHECK_SHAPE(d_input, batch_size, in_features);
-  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
-  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
-  auto lt_workspace = at::empty({1 << 22}, opts);
-  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
-    scalar_t* w_ptr = weight.data_ptr<scalar_t>();
-    auto result = linear_bias_backward_cuda<scalar_t>(
-        input.data_ptr<scalar_t>(), w_ptr, d_output.data_ptr<scalar_t>(),
-        in_features, batch_size, out_features,
-        d_weight.data_ptr<scalar_t>(), d_bias.data_ptr<scalar_t>(), d_input.data_ptr<scalar_t>(),
-        // reserved_space.data_ptr<scalar_t>(),
-        /*residual=*/true,
-        (void*) (lt_workspace.data_ptr<scalar_t>()));
-    TORCH_CHECK(result == 0, "linear_bias_residual_backward failed.")
-  });
-  return {d_input, d_weight, d_bias};
-}
-
-std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight, at::Tensor bias,
+std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight, c10::optional<at::Tensor> bias_,
                                             bool save_gelu_in, int heuristic) {
-  auto batch_size = input.size(0);
-  auto in_features = input.size(1);
+  int batch_size = input.size(0);
+  int in_features = input.size(1);
   int out_features = weight.size(0);
-  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
+  TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
+  TORCH_CHECK(input.dtype() == weight.dtype());
+  TORCH_CHECK(input.is_cuda());
+  TORCH_CHECK(weight.is_cuda());
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(weight.is_contiguous());
+  CHECK_SHAPE(input, batch_size, in_features);
+  CHECK_SHAPE(weight, out_features, in_features);
+  if (bias_.has_value()) {
+    auto bias = bias_.value();
+    TORCH_CHECK(bias.dtype() == input.dtype());
+    TORCH_CHECK(bias.is_cuda());
+    TORCH_CHECK(bias.is_contiguous());
+    CHECK_SHAPE(bias, out_features);
+  }
   // create output/workspace tensor
   auto opts = input.options();
   auto output = at::empty({batch_size, out_features}, opts);
   at::Tensor gelu_in;
   if (save_gelu_in) { gelu_in = at::empty({batch_size, out_features}, opts); }
-  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
   // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
   auto lt_workspace = at::empty({1 << 22}, opts);
   DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_gelu_forward", [&] {
-    scalar_t* w_ptr = weight.data_ptr<scalar_t>();
-    scalar_t* b_ptr = bias.data_ptr<scalar_t>();
     auto result = linear_gelu_forward_cuda<scalar_t>(
         input.data_ptr<scalar_t>(),
-        w_ptr,
-        b_ptr,
+        weight.data_ptr<scalar_t>(),
+        bias_.has_value() ? bias_.value().data_ptr<scalar_t>() : nullptr,
         in_features,
         batch_size,
         out_features,
         heuristic,
         output.data_ptr<scalar_t>(),
         save_gelu_in ? gelu_in.data_ptr<scalar_t>() : nullptr,
-        // reserved_space.data_ptr<scalar_t>(),
        (void*) (lt_workspace.data_ptr<scalar_t>()));
-    TORCH_CHECK(result == 0, "linear_gelu_forward failed.")
+    TORCH_CHECK(result == 0, "linear_gelu_forward failed.");
   });
   std::vector<at::Tensor> result = {output};
@@ -241,116 +132,54 @@ std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
   return result;
 }
 
-std::vector<at::Tensor> linear_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2, int heuristic) {
-  auto batch_size = input.size(0);
-  auto in_features = input.size(1);
-  int hidden_features = weight1.size(0);
-  int out_features = weight2.size(0);
-  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
-  // create output/workspace tensor
-  auto opts = input.options();
-  auto d_weight1 = at::empty({hidden_features, in_features}, opts);
-  auto d_weight2 = at::empty({out_features, hidden_features}, opts);
-  auto d_bias1 = at::empty({hidden_features}, opts);
-  auto d_bias2 = at::empty({out_features}, opts);
-  auto d_input = at::empty({batch_size, in_features}, opts);
-  auto d_output1 = at::empty({batch_size, hidden_features}, opts);
-  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
-  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
-  auto lt_workspace = at::empty({1 << 22}, opts);
-  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
-    //scalar_t* w_ptr = weight.data_ptr<scalar_t>();
-    //scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
-    auto result = linear_gelu_linear_backward_cuda<scalar_t>(
-        input.data_ptr<scalar_t>(),
-        gelu_in.data_ptr<scalar_t>(),
-        output1.data_ptr<scalar_t>(),
-        weight1.data_ptr<scalar_t>(),
-        weight2.data_ptr<scalar_t>(),
-        d_output1.data_ptr<scalar_t>(),
-        d_output2.data_ptr<scalar_t>(),
-        in_features,
-        batch_size,
-        hidden_features,
-        out_features,
-        heuristic,
-        d_weight1.data_ptr<scalar_t>(),
-        d_weight2.data_ptr<scalar_t>(),
-        d_bias1.data_ptr<scalar_t>(),
-        d_bias2.data_ptr<scalar_t>(),
-        d_input.data_ptr<scalar_t>(),
-        // reserved_space.data_ptr<scalar_t>(),
-        /*residual=*/false,
-        (void*) (lt_workspace.data_ptr<scalar_t>()));
-    TORCH_CHECK(result == 0, "linear_gelu_linear_backward failed.")
-  });
-  return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
-}
-
-std::vector<at::Tensor> linear_residual_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2, at::Tensor d_input, int heuristic) {
-  auto batch_size = input.size(0);
-  auto in_features = input.size(1);
-  int hidden_features = weight1.size(0);
-  int out_features = weight2.size(0);
-  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
-  // create output/workspace tensor
-  auto opts = input.options();
-  auto d_weight1 = at::empty({hidden_features, in_features}, opts);
-  auto d_weight2 = at::empty({out_features, hidden_features}, opts);
-  auto d_bias1 = at::empty({hidden_features}, opts);
-  auto d_bias2 = at::empty({out_features}, opts);
-  CHECK_SHAPE(d_input, batch_size, in_features);
-  auto d_output1 = at::empty({batch_size, hidden_features}, opts);
-  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
+std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(at::Tensor weight, at::Tensor d_output, at::Tensor gelu_in, int heuristic) {
+  int batch_size = d_output.size(0);
+  int out_features = d_output.size(1);
+  int in_features = weight.size(1);
+  TORCH_CHECK(weight.dtype() == torch::kFloat16 || weight.dtype() == torch::kBFloat16);
+  TORCH_CHECK(weight.dtype() == d_output.dtype());
+  TORCH_CHECK(weight.dtype() == gelu_in.dtype());
+  TORCH_CHECK(weight.is_cuda());
+  TORCH_CHECK(d_output.is_cuda());
+  TORCH_CHECK(gelu_in.is_cuda());
+  TORCH_CHECK(weight.is_contiguous());
+  TORCH_CHECK(d_output.is_contiguous());
+  TORCH_CHECK(gelu_in.is_contiguous());
+  CHECK_SHAPE(weight, out_features, in_features);
+  CHECK_SHAPE(d_output, batch_size, out_features);
+  CHECK_SHAPE(gelu_in, batch_size, in_features);
+  // create output/workspace tensor
+  auto opts = weight.options();
+  auto d_bias = at::empty({in_features}, opts);
+  auto d_input = at::empty({batch_size, in_features}, opts);
   // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
   auto lt_workspace = at::empty({1 << 22}, opts);
-  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
-    //scalar_t* w_ptr = weight.data_ptr<scalar_t>();
-    //scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
-    auto result = linear_gelu_linear_backward_cuda<scalar_t>(
-        input.data_ptr<scalar_t>(),
-        gelu_in.data_ptr<scalar_t>(),
-        output1.data_ptr<scalar_t>(),
-        weight1.data_ptr<scalar_t>(),
-        weight2.data_ptr<scalar_t>(),
-        d_output1.data_ptr<scalar_t>(),
-        d_output2.data_ptr<scalar_t>(),
+  DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_gelu_linear_dgrad_bgrad", [&] {
+    auto result = bias_gelu_linear_dgrad_bgrad_cuda<scalar_t>(
+        weight.data_ptr<scalar_t>(),
+        d_output.data_ptr<scalar_t>(),
+        gelu_in.data_ptr<scalar_t>(),
         in_features,
         batch_size,
-        hidden_features,
         out_features,
         heuristic,
-        d_weight1.data_ptr<scalar_t>(),
-        d_weight2.data_ptr<scalar_t>(),
-        d_bias1.data_ptr<scalar_t>(),
-        d_bias2.data_ptr<scalar_t>(),
         d_input.data_ptr<scalar_t>(),
-        // reserved_space.data_ptr<scalar_t>(),
-        /*residual=*/true,
+        d_bias.data_ptr<scalar_t>(),
        (void*) (lt_workspace.data_ptr<scalar_t>()));
-    TORCH_CHECK(result == 0, "linear_residual_gelu_linear_backward failed.")
+    TORCH_CHECK(result == 0, "bias_gelu_linear_dgrad_bgrad failed.");
   });
-  return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
+  return {d_input, d_bias};
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("linear_bias_forward", &linear_bias_forward, "linear bias forward");
-  m.def("linear_bias_backward", &linear_bias_backward, "linear bias backward");
   m.def("linear_bias_wgrad", &linear_bias_wgrad, "linear bias wgrad");
-  m.def("linear_bias_residual_backward", &linear_bias_residual_backward, "linear bias residual backward");
   m.def("linear_gelu_forward", &linear_gelu_forward, "linear gelu forward");
-  m.def("linear_gelu_linear_backward", &linear_gelu_linear_backward, "linear gelu linear backward");
-  m.def("linear_residual_gelu_linear_backward", &linear_residual_gelu_linear_backward, "linear residual gelu linear backward");
+  m.def("bias_gelu_linear_dgrad_bgrad", &bias_gelu_linear_dgrad_bgrad, "bias gelu linear dgrad bgrad");
 }
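A hedged sketch of exercising the three bindings this file now exports (linear_bias_wgrad, linear_gelu_forward, bias_gelu_linear_dgrad_bgrad) from Python. The extension module name fused_dense_lib is an assumption (it is how the repo's setup.py names the CUDAExtension); argument order and return values follow the signatures and return statements in the diff above, and shapes follow the CHECK_SHAPE asserts.

# Assumed extension name; calls mirror the pybind11 defs above.
import torch
import fused_dense_lib as fused_dense_cuda

batch, d_model, d_hidden = 8 * 512, 1024, 4096
x  = torch.randn(batch, d_model, device='cuda', dtype=torch.float16)
w1 = torch.randn(d_hidden, d_model, device='cuda', dtype=torch.float16)
b1 = torch.randn(d_hidden, device='cuda', dtype=torch.float16)
w2 = torch.randn(d_model, d_hidden, device='cuda', dtype=torch.float16)

# fc1 forward: GEMM + bias + GELU; save_gelu_in=True also returns the pre-GELU values.
gelu_out, gelu_in = fused_dense_cuda.linear_gelu_forward(x, w1, b1, True, 0)

# Weight (and optionally bias) gradient of fc2, given the incoming gradient.
grad_out2 = torch.randn(batch, d_model, device='cuda', dtype=torch.float16)
d_w2, d_b2 = fused_dense_cuda.linear_bias_wgrad(gelu_out, grad_out2, True)

# dgrad through fc2 and the GELU, fused with the bias gradient of fc1.
d_gelu_in, d_b1 = fused_dense_cuda.bias_gelu_linear_dgrad_bgrad(w2, grad_out2, gelu_in, 0)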
csrc/fused_dense_lib/fused_dense_cuda.cu (view file @ e68ebbe8)
This diff is collapsed.
flash_attn/layers/patch_embed.py (view file @ e68ebbe8)

@@ -10,9 +10,9 @@ from torch.nn.modules.utils import _pair
 from einops import rearrange
 
 try:
-    from flash_attn.ops.fused_dense import FusedDenseTD
+    from flash_attn.ops.fused_dense import FusedDense
 except ImportError:
-    FusedDenseTD = None
+    FusedDense = None
 
 
 class PatchEmbed(nn.Module):
@@ -37,10 +37,10 @@ class PatchEmbed(nn.Module):
         self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
         self.num_patches = self.grid_size[0] * self.grid_size[1]
         self.flatten = flatten
-        if fused_bias_fc and FusedDenseTD is None:
+        if fused_bias_fc and FusedDense is None:
             raise ImportError('fused_dense is not installed')
-        linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDenseTD
+        linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense
         self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias)
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
flash_attn/models/bert.py (view file @ e68ebbe8)

@@ -30,9 +30,9 @@ from flash_attn.bert_padding import unpad_input, pad_input
 from flash_attn.bert_padding import index_first_axis, index_first_axis_residual
 
 try:
-    from flash_attn.ops.fused_dense import FusedDenseTD
+    from flash_attn.ops.fused_dense import FusedDense
 except ImportError:
-    FusedDenseTD = None
+    FusedDense = None
 
 try:
     from flash_attn.ops.layer_norm import dropout_add_layer_norm, layer_norm
@@ -70,6 +70,8 @@ def create_mlp_cls(config, layer_idx=None, return_residual=False):
                          activation=partial(F.gelu, approximate=approximate),
                          return_residual=return_residual)
     else:
+        if FusedDenseGeluDense is None:
+            raise ImportError('fused_dense is not installed')
         mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
         # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
         if isinstance(mlp_checkpoint_lvl, Sequence):
@@ -168,9 +170,9 @@ class BertPooler(nn.Module):
     def __init__(self, config):
         super().__init__()
         fused_bias_fc = getattr(config, 'fused_bias_fc', False)
-        if fused_bias_fc and FusedDenseTD is None:
+        if fused_bias_fc and FusedDense is None:
             raise ImportError('fused_dense is not installed')
-        linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
+        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         self.dense = linear_cls(config.hidden_size, config.hidden_size)
         self.activation = nn.Tanh()
@@ -188,12 +190,12 @@ class BertPredictionHeadTransform(nn.Module):
     def __init__(self, config):
         super().__init__()
         fused_bias_fc = getattr(config, 'fused_bias_fc', False)
-        if fused_bias_fc and FusedDenseTD is None:
+        if fused_bias_fc and FusedDense is None:
             raise ImportError('fused_dense is not installed')
         self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
         if self.fused_dropout_add_ln and layer_norm is None:
             raise ImportError('dropout_add_layer_norm is not installed')
-        linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
+        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         self.dense = linear_cls(config.hidden_size, config.hidden_size)
         approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
         self.transform_act_fn = nn.GELU(approximate=approximate)
@@ -215,9 +217,9 @@ class BertLMPredictionHead(nn.Module):
     def __init__(self, config):
         super().__init__()
         fused_bias_fc = getattr(config, 'fused_bias_fc', False)
-        if fused_bias_fc and FusedDenseTD is None:
+        if fused_bias_fc and FusedDense is None:
             raise ImportError('fused_dense is not installed')
-        linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
+        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         self.transform = BertPredictionHeadTransform(config)
flash_attn/models/gpt.py (view file @ e68ebbe8)

@@ -61,6 +61,8 @@ def create_mlp_cls(config, layer_idx=None):
             assert layer_idx is not None
             mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
         if fused_dense_gelu_dense:
+            if FusedDenseGeluDense is None:
+                raise ImportError('fused_dense is not installed')
             mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim,
                               checkpoint_lvl=mlp_checkpoint_lvl)
         elif fused_dense_sqrelu_dense:
flash_attn/modules/mha.py (view file @ e68ebbe8)

@@ -21,9 +21,9 @@ except ImportError:
     flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
 
 try:
-    from flash_attn.ops.fused_dense import FusedDenseTD, FusedDenseResidual
+    from flash_attn.ops.fused_dense import FusedDense
 except ImportError:
-    FusedDenseTD, FusedDenseResidual = None, None
+    FusedDense = None
 
 try:
     from flash_attn.layers.rotary import RotaryEmbedding
@@ -270,7 +270,7 @@ class CrossAttention(nn.Module):
 class LinearResidual(nn.Linear):
-    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDenseResidual.
+    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.
     """
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
@@ -311,10 +311,11 @@ class MHA(nn.Module):
             assert RotaryEmbedding is not None, 'rotary_emb is not installed'
             self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base)
 
-        if fused_bias_fc and FusedDenseTD is None:
+        if fused_bias_fc and FusedDense is None:
             raise ImportError('fused_dense is not installed')
-        linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
-        linear_resid_cls = LinearResidual if not fused_bias_fc else FusedDenseResidual
+        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
+        linear_resid_cls = (LinearResidual if not fused_bias_fc
+                            else partial(FusedDense, return_residual=True))
         if not self.cross_attn:
             if not self.return_residual:
                 self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
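A minimal sketch of the pattern the mha.py diff adopts: the residual-returning projection is now an ordinary FusedDense constructed with return_residual=True rather than a separate FusedDenseResidual class. Sizes and dtype below are illustrative, not taken from the repo's configs.

from functools import partial
import torch
from flash_attn.ops.fused_dense import FusedDense

linear_resid_cls = partial(FusedDense, return_residual=True)
Wqkv = linear_resid_cls(1024, 3 * 1024, bias=True, device='cuda', dtype=torch.float16)

x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)
qkv, x_residual = Wqkv(x)  # the module returns its input alongside the projection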
flash_attn/modules/mlp.py (view file @ e68ebbe8)

@@ -5,11 +5,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 try:
-    from flash_attn.ops.fused_dense import fused_dense_gelu_dense_function_td
-    from flash_attn.ops.fused_dense import fused_dense_res_gelu_dense_function_td
+    from flash_attn.ops.fused_dense import FusedDenseGeluDense
 except ImportError:
-    fused_dense_gelu_dense_function_td = None
-    fused_dense_res_gelu_dense_function_td = None
+    FusedDenseGeluDense = None
 
 
 class Mlp(nn.Module):
@@ -30,43 +28,3 @@ class Mlp(nn.Module):
         y = self.activation(y)
         y = self.fc2(y)
         return y if not self.return_residual else (y, x)
-
-
-class FusedDenseGeluDense(nn.Module):
-
-    def __init__(self, in_features, hidden_features=None, out_features=None, bias=True,
-                 checkpoint_lvl=0, heuristic=0, return_residual=False, device=None, dtype=None):
-        """
-        checkpoint_lvl (increasing lvl means slower but more memory saving):
-            0: no recomputation in the bwd
-            1: recompute gelu_out in the bwd
-            2: recompute gelu_in and gelu_out in the bwd
-        heuristic:
-            -1: don't fuse gemm + gelu (separate kernel)
-            0..4: use this heuristic for the algo section in the fused gemm + gelu
-            For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
-            For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
-        return_residual: whether to return the input x along with the output. This is for
-            performance reason: for post-norm architecture, returning the input allows us
-            to fuse the backward of nn.Linear with the residual connection.
-        """
-        assert checkpoint_lvl in [0, 1, 2]
-        factory_kwargs = {'device': device, 'dtype': dtype}
-        super().__init__()
-        out_features = out_features or in_features
-        hidden_features = hidden_features or in_features
-        assert bias == True, "DenseGeluDense module without bias is currently not supported"
-        assert (fused_dense_gelu_dense_function_td is not None
-                and fused_dense_res_gelu_dense_function_td is not None), 'fused_dense_lib is not installed'
-        self.checkpoint_lvl = checkpoint_lvl
-        self.heuristic = heuristic
-        self.return_residual = return_residual
-        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, **factory_kwargs)
-        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
-
-    def forward(self, x):
-        assert x.is_cuda
-        fn = (fused_dense_gelu_dense_function_td if not self.return_residual
-              else fused_dense_res_gelu_dense_function_td)
-        return fn(x, self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias,
-                  self.checkpoint_lvl, self.heuristic)
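The FusedDenseGeluDense class removed above (together with its checkpoint_lvl and heuristic docstring) moves into flash_attn/ops/fused_dense.py. A hedged sketch of picking the heuristic following that docstring's guidance: 0 on CUDA >= 11.8, otherwise 1 for fp16 and -1 for bf16. The construction below assumes a CUDA build of PyTorch and the post-commit constructor signature.

import torch
from flash_attn.ops.fused_dense import FusedDenseGeluDense

dtype = torch.float16
cuda_version = tuple(map(int, torch.version.cuda.split('.')))  # e.g. (11, 8)
if cuda_version >= (11, 8):
    heuristic = 0
else:
    heuristic = 1 if dtype == torch.float16 else -1

mlp = FusedDenseGeluDense(1024, 4096, 1024, checkpoint_lvl=1, heuristic=heuristic,
                          device='cuda', dtype=dtype)
y = mlp(torch.randn(8, 512, 1024, device='cuda', dtype=dtype))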
flash_attn/ops/fused_dense.py (view file @ e68ebbe8)
This diff is collapsed.
tests/ops/test_fused_dense.py (view file @ e68ebbe8)

@@ -6,29 +6,44 @@ import pytest
 from einops import rearrange
 
-from flash_attn.ops.fused_dense import FusedDenseTD, FusedDenseGeluDenseTD
-from flash_attn.ops.fused_dense import FusedDenseResidual, FusedDenseResGeluDense
+from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense
 
 
 @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize('return_residual', [False, True])
+@pytest.mark.parametrize('has_bias', [True, False])
 @pytest.mark.parametrize('out_features', [1024, 4096])
 @pytest.mark.parametrize('in_features', [1024, 4096])
-def test_fused_linear_bias(in_features, out_features, dtype):
+def test_fused_linear_bias(in_features, out_features, has_bias, return_residual, dtype):
     device = 'cuda'
     rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
     # set seed
     torch.random.manual_seed(0)
     batch_size = 8
     seqlen = 512
     x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
     x = x_pt.detach().clone().requires_grad_()
-    model_pt = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
-    model = FusedDenseTD(in_features, out_features, device=device, dtype=dtype)
+    model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
+    model = FusedDense(in_features, out_features, bias=has_bias, return_residual=return_residual,
+                       device=device, dtype=dtype)
     with torch.no_grad():
         model.weight.copy_(model_pt.weight)
-        model.bias.copy_(model_pt.bias)
+        if has_bias:
+            model.bias.copy_(model_pt.bias)
     out_pt = model_pt(x_pt)
-    out = model(x)
+    if not return_residual:
+        out = model(x)
+    else:
+        out, x_copy = model(x)
+        x_copy = (x_copy[..., :out_features] if out_features < in_features
+                  else F.pad(x_copy, (0, out_features - in_features)))
+        x_pt_copy = (x_pt[..., :out_features] if out_features < in_features
+                     else F.pad(x_pt, (0, out_features - in_features)))
+        # Just add some random function of the residual
+        out_pt = out_pt + F.gelu(x_pt_copy)
+        out = out + F.gelu(x_copy)
     # with torch.no_grad():
     #     out_fl = F.linear(x_pt.float(), model.weight.float(), model.bias.float()).half()
     assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
@@ -40,107 +55,53 @@ def test_fused_linear_bias(in_features, out_features, dtype):
     assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
     # The error for d_weight and d_bias is quite a bit higher
     assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10)
-    assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)
+    if has_bias:
+        assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)
 
 
 @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize('out_features,in_features', [(1024, 1024), (4096, 4096)])
-def test_fused_linear_bias_residual(in_features, out_features, dtype):
-    device = 'cuda'
-    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
-    # set seed
-    torch.random.manual_seed(0)
-    batch_size = 8
-    seqlen = 512
-    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
-    x = x_pt.detach().clone().requires_grad_()
-    model_pt = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
-    model = FusedDenseResidual(in_features, out_features, device=device, dtype=dtype)
-    with torch.no_grad():
-        model.weight.copy_(model_pt.weight)
-        model.bias.copy_(model_pt.bias)
-    out_pt = model_pt(x_pt) + F.gelu(x_pt)  # Just add some random function of the residual x_pt
-    out, x_copy = model(x)
-    out = out + F.gelu(x_copy)
-    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol * 2)
-    # If we don't divide by batch_size, the gradient gets a bit too large.
-    g = torch.randn_like(out) / 32
-    out_pt.backward(g)
-    out.backward(g)
-    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
-    # The error for d_weight and d_bias is quite a bit higher
-    assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10)
-    assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)
-
-
-@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize('heuristic', [1, -1])
+@pytest.mark.parametrize('heuristic', [0, -1])
 @pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
+@pytest.mark.parametrize('return_residual', [False, True])
+@pytest.mark.parametrize('has_bias2', [True, False])
+@pytest.mark.parametrize('has_bias1', [True, False])
 @pytest.mark.parametrize('out_features', [1024, 4096])
 @pytest.mark.parametrize('in_features', [1024, 4096])
-def test_fused_dense_gelu_dense(in_features, out_features, checkpoint_lvl, heuristic, dtype):
+def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2, return_residual,
+                                checkpoint_lvl, heuristic, dtype):
     device = 'cuda'
-    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
+    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
     # set seed
     torch.random.manual_seed(0)
     batch_size = 8
     seqlen = 512
     x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
     x = x_pt.detach().clone().requires_grad_()
-    model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
-    model_pt_fc2 = torch.nn.Linear(out_features, in_features, device=device, dtype=dtype)
-    model = FusedDenseGeluDenseTD(in_features, out_features, in_features,
-                                  checkpoint_lvl=checkpoint_lvl, heuristic=heuristic,
-                                  device=device, dtype=dtype)
+    model_pt_fc1 = torch.nn.Linear(in_features, out_features, bias=has_bias1, device=device,
+                                   dtype=dtype)
+    model_pt_fc2 = torch.nn.Linear(out_features, in_features, bias=has_bias2, device=device, dtype=dtype)
+    model = FusedDenseGeluDense(in_features, out_features, in_features,
+                                bias1=has_bias1, bias2=has_bias2, return_residual=return_residual,
+                                checkpoint_lvl=checkpoint_lvl, heuristic=heuristic,
+                                device=device, dtype=dtype)
     with torch.no_grad():
         model.fc1.weight.copy_(model_pt_fc1.weight)
-        model.fc1.bias.copy_(model_pt_fc1.bias)
+        if has_bias1:
+            model.fc1.bias.copy_(model_pt_fc1.bias)
         model.fc2.weight.copy_(model_pt_fc2.weight)
-        model.fc2.bias.copy_(model_pt_fc2.bias)
+        if has_bias2:
+            model.fc2.bias.copy_(model_pt_fc2.bias)
     out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh'))
-    out = model(x)
-    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
-
-
-@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
-@pytest.mark.parametrize('out_features', [1024, 4096])
-@pytest.mark.parametrize('in_features', [1024, 4096])
-def test_fused_dense_residual_gelu_dense(in_features, out_features, checkpoint_lvl, dtype):
-    device = 'cuda'
-    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
-    # set seed
-    torch.random.manual_seed(0)
-    batch_size = 8
-    seqlen = 512
-    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
-    x = x_pt.detach().clone().requires_grad_()
-    model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
-    model_pt_fc2 = torch.nn.Linear(out_features, in_features, device=device, dtype=dtype)
-    model = FusedDenseResGeluDense(in_features, out_features, in_features,
-                                   checkpoint_lvl=checkpoint_lvl, device=device, dtype=dtype)
-    with torch.no_grad():
-        model.fc1.weight.copy_(model_pt_fc1.weight)
-        model.fc1.bias.copy_(model_pt_fc1.bias)
-        model.fc2.weight.copy_(model_pt_fc2.weight)
-        model.fc2.bias.copy_(model_pt_fc2.bias)
-    out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh')) + F.gelu(x_pt)
-    out, x_copy = model(x)
-    out = out + F.gelu(x_copy)
-    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol * 2)
+    if not return_residual:
+        out = model(x)
+    else:
+        out, x_copy = model(x)
+        # Just add some random function of the residual
+        out_pt = out_pt + F.gelu(x_pt)
+        out = out + F.gelu(x_copy)
+    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
     # If we don't divide by batch_size, the gradient gets a bit too large.
     g = torch.randn_like(out) / 32
@@ -149,6 +110,8 @@ def test_fused_dense_gelu_dense(in_features, out_features, ...)
     assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
     # The error for d_weight and d_bias is quite a bit higher
     assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10)
-    assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
+    if has_bias1:
+        assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
     assert torch.allclose(model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10)
-    assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
+    if has_bias2:
+        assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)