OpenDAS / ColossalAI · Commits

Unverified commit c2488003, authored Nov 07, 2022 by Jiarui Fang, committed by GitHub on Nov 07, 2022

[kernel] skip tests of flash_attn and triton when they are not available (#1798)

parent e34e850a

Showing 3 changed files with 405 additions and 294 deletions (+405, -294)
colossalai/gemini/gemini_mgr.py (+1, -1)
colossalai/kernel/cuda_native/flash_attention.py (+387, -286)
tests/test_utils/test_flash_attention.py (+17, -7)
colossalai/gemini/gemini_mgr.py

@@ -61,7 +61,7 @@ class GeminiManager:
         self._comp_cuda_demand_time = 0

     def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
-        """ Adjust the layout of statefuil tensor according to the information provided
+        """ Adjust the layout of stateful tensors according to the information provided
         by mem_stats_collector, which should belongs to a Sharded Model.
         """
         # find stateful tensor in state COMPUTE
         ...
colossalai/kernel/cuda_native/flash_attention.py

This diff is collapsed.
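The collapsed diff is not reproduced here. Since the test file below imports HAS_FLASH_ATTN and HAS_TRITON from this module, the rewrite presumably introduces availability flags along these lines. The snippet below is only a minimal sketch of that common optional-import pattern, not the actual contents of flash_attention.py:

    # Hedged sketch -- not the real flash_attention.py. It shows how flags such as
    # HAS_FLASH_ATTN and HAS_TRITON are typically derived from optional imports.
    HAS_FLASH_ATTN = True
    try:
        import flash_attn  # optional dependency providing the CUDA flash-attention kernel
    except ImportError:
        HAS_FLASH_ATTN = False

    HAS_TRITON = True
    try:
        import triton  # optional dependency providing the Triton kernel backend
    except ImportError:
        HAS_TRITON = False

Downstream code can then branch on these flags instead of failing at import time when the optional backend is missing.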
tests/test_utils/test_flash_attention.py

-import torch
 import pytest
+import torch
 from einops import rearrange
-from colossalai.kernel.cuda_native.flash_attention import flash_attention, triton_flash_attention, TRITON_AVALIABLE
+from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON, TRITON_AVALIABLE
+
+if HAS_FLASH_ATTN:
+    from colossalai.kernel.cuda_native.flash_attention import flash_attention
+if HAS_TRITON:
+    from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention


 def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
 ...
@@ -14,7 +21,8 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
     ref_out = torch.matmul(p, v)
     return ref_out


+@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="triton is not available")
 @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
 def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
 ...
@@ -23,7 +31,7 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
     sm_scale = 0.3
     dout = torch.randn_like(q)

     ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale)
     ref_out.backward(dout)
     ref_dv, v.grad = v.grad.clone(), None
 ...
@@ -51,6 +59,7 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
         raise TypeError("Error type not match!")


+@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="triton is not available")
 @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
 def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
 ...
@@ -59,21 +68,22 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     v = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
     sm_scale = 0.3
     dout = torch.randn_like(q)

     # reference implementation
     ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale)
     ref_out.backward(dout)
     ref_dv, v.grad = v.grad.clone(), None
     ref_dk, k.grad = k.grad.clone(), None
     ref_dq, q.grad = q.grad.clone(), None

     # flash implementation
     q, k, v = map(lambda x: rearrange(x, 'z h n d -> (z n) h d'), [q, k, v])
     tri_out = flash_attention(q, k, v, sm_scale, Z, N_CTX)
     dout = rearrange(dout, 'z h n d -> (z n) h d').detach()
     tri_out.backward(dout, retain_graph=True)
     tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
     tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z), (tri_out, tri_dq, tri_dk, tri_dv))

     # compare
     assert torch.allclose(ref_out, tri_out, atol=1e-3)
 ...
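With these guards in place, running the test module on a machine that lacks flash-attn or triton should report the corresponding tests as skipped instead of failing at import time. A minimal sketch of how one might check this from the repository root (assuming pytest is installed):

    # Sketch: invoke the test module programmatically; "-rs" prints skip reasons,
    # so missing flash-attn/triton backends show up as skipped tests.
    import sys

    import pytest

    if __name__ == "__main__":
        sys.exit(pytest.main(["-rs", "tests/test_utils/test_flash_attention.py"]))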