Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
315fd31f
Unverified
Commit
315fd31f
authored
Apr 12, 2023
by
Kirthi Shankar Sivamani
Committed by
GitHub
Apr 12, 2023
Browse files
Merge branch 'HazyResearch:main' into enable_cuda_graph_capture
parents
31018c5f
5cee0714
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
12 additions
and
8 deletions
+12
-8
README.md
README.md
+1
-1
csrc/fused_dense_lib/fused_dense_cuda.cu
csrc/fused_dense_lib/fused_dense_cuda.cu
+7
-3
flash_attn/modules/mlp.py
flash_attn/modules/mlp.py
+1
-1
setup.py
setup.py
+1
-1
training/Dockerfile
training/Dockerfile
+2
-2
No files found.
README.md
View file @
315fd31f
...
@@ -38,7 +38,7 @@ and experiment with. The notations in the Triton implementation are also closer
...
@@ -38,7 +38,7 @@ and experiment with. The notations in the Triton implementation are also closer
to what's used in our paper.
to what's used in our paper.
##
Beta release (0.2).
##
Installation and features
Requirements:
Requirements:
-
CUDA 11.4 and above.
-
CUDA 11.4 and above.
...
...
csrc/fused_dense_lib/fused_dense_cuda.cu
View file @
315fd31f
...
@@ -122,7 +122,9 @@ int gemm_bias_act_lt(
...
@@ -122,7 +122,9 @@ int gemm_bias_act_lt(
reinterpret_cast
<
cublasLtHandle_t
>
(
at
::
cuda
::
getCurrentCUDABlasHandle
());
reinterpret_cast
<
cublasLtHandle_t
>
(
at
::
cuda
::
getCurrentCUDABlasHandle
());
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
// setting this to 1M.
// setting this to 1M.
size_t
workspaceSize
=
1024
*
1024
;
// However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
// https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
size_t
workspaceSize
=
1024
*
1024
*
(
at
::
cuda
::
getCurrentDeviceProperties
()
->
major
>=
9
?
32
:
4
);
void
*
workspace
=
at
::
empty
(
void
*
workspace
=
at
::
empty
(
{
static_cast
<
int64_t
>
(
workspaceSize
)},
{
static_cast
<
int64_t
>
(
workspaceSize
)},
at
::
device
({
at
::
kCUDA
,
at
::
cuda
::
current_device
()}).
dtype
(
at
::
kByte
)).
data_ptr
();
at
::
device
({
at
::
kCUDA
,
at
::
cuda
::
current_device
()}).
dtype
(
at
::
kByte
)).
data_ptr
();
...
@@ -296,7 +298,8 @@ int gemm_bgradb_lt(
...
@@ -296,7 +298,8 @@ int gemm_bgradb_lt(
reinterpret_cast
<
cublasLtHandle_t
>
(
at
::
cuda
::
getCurrentCUDABlasHandle
());
reinterpret_cast
<
cublasLtHandle_t
>
(
at
::
cuda
::
getCurrentCUDABlasHandle
());
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
// setting this to 1M.
// setting this to 1M.
size_t
workspaceSize
=
1024
*
1024
;
// However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
size_t
workspaceSize
=
1024
*
1024
*
(
at
::
cuda
::
getCurrentDeviceProperties
()
->
major
>=
9
?
32
:
4
);
void
*
workspace
=
at
::
empty
(
void
*
workspace
=
at
::
empty
(
{
static_cast
<
int64_t
>
(
workspaceSize
)},
{
static_cast
<
int64_t
>
(
workspaceSize
)},
at
::
device
({
at
::
kCUDA
,
at
::
cuda
::
current_device
()}).
dtype
(
at
::
kByte
)).
data_ptr
();
at
::
device
({
at
::
kCUDA
,
at
::
cuda
::
current_device
()}).
dtype
(
at
::
kByte
)).
data_ptr
();
...
@@ -449,7 +452,8 @@ int gemm_dact_bgradb_lt(
...
@@ -449,7 +452,8 @@ int gemm_dact_bgradb_lt(
reinterpret_cast
<
cublasLtHandle_t
>
(
at
::
cuda
::
getCurrentCUDABlasHandle
());
reinterpret_cast
<
cublasLtHandle_t
>
(
at
::
cuda
::
getCurrentCUDABlasHandle
());
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
// setting this to 1M.
// setting this to 1M.
size_t
workspaceSize
=
1024
*
1024
;
// However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
size_t
workspaceSize
=
1024
*
1024
*
(
at
::
cuda
::
getCurrentDeviceProperties
()
->
major
>=
9
?
32
:
4
);
void
*
workspace
=
at
::
empty
(
void
*
workspace
=
at
::
empty
(
{
static_cast
<
int64_t
>
(
workspaceSize
)},
{
static_cast
<
int64_t
>
(
workspaceSize
)},
at
::
device
({
at
::
kCUDA
,
at
::
cuda
::
current_device
()}).
dtype
(
at
::
kByte
)).
data_ptr
();
at
::
device
({
at
::
kCUDA
,
at
::
cuda
::
current_device
()}).
dtype
(
at
::
kByte
)).
data_ptr
();
...
...
flash_attn/modules/mlp.py
View file @
315fd31f
...
@@ -17,7 +17,7 @@ class Mlp(nn.Module):
...
@@ -17,7 +17,7 @@ class Mlp(nn.Module):
factory_kwargs
=
{
'device'
:
device
,
'dtype'
:
dtype
}
factory_kwargs
=
{
'device'
:
device
,
'dtype'
:
dtype
}
super
().
__init__
()
super
().
__init__
()
out_features
=
out_features
or
in_features
out_features
=
out_features
or
in_features
hidden_features
=
hidden_features
or
in_features
hidden_features
=
hidden_features
or
in_features
*
4
self
.
return_residual
=
return_residual
self
.
return_residual
=
return_residual
self
.
fc1
=
nn
.
Linear
(
in_features
,
hidden_features
,
**
factory_kwargs
)
self
.
fc1
=
nn
.
Linear
(
in_features
,
hidden_features
,
**
factory_kwargs
)
self
.
activation
=
activation
self
.
activation
=
activation
...
...
setup.py
View file @
315fd31f
...
@@ -162,7 +162,7 @@ ext_modules.append(
...
@@ -162,7 +162,7 @@ ext_modules.append(
setup
(
setup
(
name
=
"flash_attn"
,
name
=
"flash_attn"
,
version
=
"
0.2.8
"
,
version
=
"
1.0.1
"
,
packages
=
find_packages
(
packages
=
find_packages
(
exclude
=
(
"build"
,
"csrc"
,
"include"
,
"tests"
,
"dist"
,
"docs"
,
"benchmarks"
,
"flash_attn.egg-info"
,)
exclude
=
(
"build"
,
"csrc"
,
"include"
,
"tests"
,
"dist"
,
"docs"
,
"benchmarks"
,
"flash_attn.egg-info"
,)
),
),
...
...
training/Dockerfile
View file @
315fd31f
...
@@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
...
@@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
RUN
pip
install
git+https://github.com/mlcommons/logging.git@2.1.0
RUN
pip
install
git+https://github.com/mlcommons/logging.git@2.1.0
# Install FlashAttention
# Install FlashAttention
RUN
pip
install
flash-attn
==
0.2.8
RUN
pip
install
flash-attn
==
1.0.1
# Install CUDA extensions for cross-entropy, fused dense, layer norm
# Install CUDA extensions for cross-entropy, fused dense, layer norm
RUN
git clone https://github.com/HazyResearch/flash-attention
\
RUN
git clone https://github.com/HazyResearch/flash-attention
\
&&
cd
flash-attention
&&
git checkout v
0.2.8
\
&&
cd
flash-attention
&&
git checkout v
1.0.1
\
&&
cd
csrc/fused_softmax
&&
pip
install
.
&&
cd
../../
\
&&
cd
csrc/fused_softmax
&&
pip
install
.
&&
cd
../../
\
&&
cd
csrc/rotary
&&
pip
install
.
&&
cd
../../
\
&&
cd
csrc/rotary
&&
pip
install
.
&&
cd
../../
\
&&
cd
csrc/xentropy
&&
pip
install
.
&&
cd
../../
\
&&
cd
csrc/xentropy
&&
pip
install
.
&&
cd
../../
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment