gaoqiong / flash-attention · Commits

Commit dec4f2e9, authored Apr 06, 2023 by Tri Dao
Parent: d478eeec

[FusedDense] Set workspace size to 32M for Hopper and 4M for others
Showing 1 changed file with 7 additions and 3 deletions:
csrc/fused_dense_lib/fused_dense_cuda.cu (+7, -3)
@@ -122,7 +122,9 @@ int gemm_bias_act_lt(
     reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
   // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
   // setting this to 1M.
-  size_t workspaceSize = 1024 * 1024;
+  // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
+  // https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
+  size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
   void *workspace = at::empty(
     {static_cast<int64_t>(workspaceSize)},
     at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
@@ -296,7 +298,8 @@ int gemm_bgradb_lt(
     reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
   // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
   // setting this to 1M.
-  size_t workspaceSize = 1024 * 1024;
+  // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
+  size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
   void *workspace = at::empty(
     {static_cast<int64_t>(workspaceSize)},
     at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
@@ -449,7 +452,8 @@ int gemm_dact_bgradb_lt(
     reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
   // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
   // setting this to 1M.
-  size_t workspaceSize = 1024 * 1024;
+  // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
+  size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
   void *workspace = at::empty(
     {static_cast<int64_t>(workspaceSize)},
     at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
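All three hunks apply the same rule: query the current device's compute capability and hand cuBLASLt a 32 MiB workspace on Hopper (major version 9 or above) versus 4 MiB elsewhere, since cuBLASLt can pick faster algorithms when given more scratch space. Below is a minimal standalone sketch of that sizing rule; it assumes the plain CUDA runtime API (cudaGetDeviceProperties) in place of the at::cuda helpers the committed code uses, and the printf program around it is illustrative only, not part of the project.

// Sketch: architecture-dependent cuBLASLt workspace sizing, as an assumed
// standalone equivalent of the committed change (not the project's code).
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  int device = 0;
  cudaGetDevice(&device);
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device);
  // Hopper reports compute capability major version 9; it gets 32 MiB,
  // while every other architecture keeps the 4 MiB that Apex uses.
  size_t workspaceSize = 1024 * 1024 * (prop.major >= 9 ? 32 : 4);
  printf("sm_%d%d -> %zu-byte cuBLASLt workspace\n",
         prop.major, prop.minor, workspaceSize);
  return 0;
}

Keeping the choice as a single ternary expression means the three GEMM wrappers differ only in which call site owns the workspace; the resulting pointer and size are what get passed along to the cuBLASLt matmul call.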