gaoqiong / flash-attention
Commit bedcbd6a, authored Oct 31, 2022 by Tri Dao

Disable some autotune configs that give wrong results in Triton bwd

Parent: e78d509c
Showing 1 changed file with 6 additions and 5 deletions

flash_attn/flash_attn_triton.py (+6 -5)
@@ -385,11 +385,12 @@ def init_to_zero(name):
     configs=[
         triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
         triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
-        # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
-        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
-        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
-        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
-        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
+        # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now
+        # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
+        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
+        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
+        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
+        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
         # triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1),
         # triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1),
         # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1),
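The new comment ties the wrong results to `seqlen_q % 128 != 0` but does not state a cause. As a rough illustration only, the sketch below shows which query lengths meet that condition and how many rows are left over in the last tile for a given `BLOCK_M`. The helper names `hits_disabled_condition` and `tile_summary` are hypothetical and are not part of `flash_attn_triton.py`; this is plain arithmetic, not a diagnosis of the kernel bug.

```python
import math


def hits_disabled_condition(seqlen_q: int) -> bool:
    """True when seqlen_q matches the condition named in the commit comment."""
    return seqlen_q % 128 != 0


def tile_summary(seqlen_q: int, block_m: int) -> dict:
    """Describe how seqlen_q rows split into tiles of block_m rows.

    Hypothetical helper for illustration; it does not call Triton.
    """
    num_tiles = math.ceil(seqlen_q / block_m)
    tail_rows = seqlen_q % block_m  # 0 means every tile is full
    return {
        "num_tiles": num_tiles,
        "tail_rows": tail_rows,
        "has_partial_tile": tail_rows != 0,
    }


if __name__ == "__main__":
    for seqlen_q in (128, 200, 256, 1000):
        print(seqlen_q, hits_disabled_condition(seqlen_q),
              tile_summary(seqlen_q, 128), tile_summary(seqlen_q, 64))
```

A length such as 192 is a multiple of 64 but not of 128, so it meets the condition in the comment even though a `BLOCK_M` of 64 would tile it exactly.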