gaoqiong / flash-attention · Commits

Commit 0842ec0d
Authored Dec 23, 2023 by Tri Dao
Don't dispatch to local if window size >= seqlen_k
Parent: 73265458

Changes: 1 — showing 1 changed file with 20 additions and 5 deletions.

csrc/flash_attn/flash_api.cpp  (+20 / -5)
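Summary of the diff below: each of the five host entry points in flash_api.cpp (mha_fwd, mha_varlen_fwd, mha_bwd, mha_varlen_bwd, mha_fwd_kvcache) now clamps the sliding-window arguments before choosing a kernel. A window that already spans the whole key sequence is rewritten to the -1 sentinel ("no window on that side"), so the local-attention path is not dispatched when plain causal or full attention gives the same result; the window parameters lose their const qualifier so they can be rewritten in place. A minimal standalone sketch of that normalization follows; the helper name normalize_window_size is hypothetical and not part of the source:

// Hypothetical helper mirroring the checks added in this commit: a window that
// covers the entire key sequence is equivalent to no window at all, so it is
// mapped to the -1 sentinel before kernel dispatch.
inline void normalize_window_size(int &window_size_left,
                                  int &window_size_right,
                                  const int seqlen_k) {
    if (window_size_left >= seqlen_k) { window_size_left = -1; }
    if (window_size_right >= seqlen_k) { window_size_right = -1; }
}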
@@ -260,7 +260,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         const float p_dropout,
         const float softmax_scale,
         bool is_causal,
-        const int window_size_left,
+        int window_size_left,
         int window_size_right,
         const bool return_softmax,
         c10::optional<at::Generator> gen_) {
@@ -300,6 +300,9 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
+    if (window_size_left >= seqlen_k) { window_size_left = -1; }
+    if (window_size_right >= seqlen_k) { window_size_right = -1; }
+
     // causal=true is the same as causal=false in this case
     if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
     if (is_causal) { window_size_right = 0; }
@@ -465,7 +468,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
         const float softmax_scale,
         const bool zero_tensors,
         const bool is_causal,
-        const int window_size_left,
+        int window_size_left,
         int window_size_right,
         const bool return_softmax,
         c10::optional<at::Generator> gen_) {
@@ -512,6 +515,9 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
+    if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
+    if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
+
     CHECK_SHAPE(q, total_q, num_heads, head_size_og);
     CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
     CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
@@ -675,7 +681,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
         const float p_dropout,         // probability to drop
         const float softmax_scale,
         const bool is_causal,
-        const int window_size_left,
+        int window_size_left,
         int window_size_right,
         const bool deterministic,
         c10::optional<at::Generator> gen_,
@@ -738,6 +744,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
 
     TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
 
+    if (window_size_left >= seqlen_k) { window_size_left = -1; }
+    if (window_size_right >= seqlen_k) { window_size_right = -1; }
+
     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
     CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
     CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
@@ -912,7 +921,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
         const float softmax_scale,
         const bool zero_tensors,
         const bool is_causal,
-        const int window_size_left,
+        int window_size_left,
         int window_size_right,
         const bool deterministic,
         c10::optional<at::Generator> gen_,
@@ -979,6 +988,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
 
     TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
 
+    if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
+    if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
+
     CHECK_SHAPE(q, total_q, num_heads, head_size);
     CHECK_SHAPE(k, total_k, num_heads_k, head_size);
     CHECK_SHAPE(v, total_k, num_heads_k, head_size);
@@ -1160,7 +1172,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         c10::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
         const float softmax_scale,
         bool is_causal,
-        const int window_size_left,
+        int window_size_left,
         int window_size_right,
         bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
         int num_splits
@@ -1216,6 +1228,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         num_heads = num_heads_k;
     }
 
+    if (window_size_left >= seqlen_k) { window_size_left = -1; }
+    if (window_size_right >= seqlen_k) { window_size_right = -1; }
+
     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
     CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
     CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
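Why the rewrite is safe: under the sliding-window convention (sketched here for the seqlen_q == seqlen_k case, an assumption for this illustration), query position i attends to key positions in [i - window_size_left, i + window_size_right], with a negative value meaning unbounded on that side. A window at least as large as seqlen_k therefore never masks anything, so it is interchangeable with (-1, -1). A small self-contained check of that equivalence, written as an illustration rather than taken from the repository:

#include <cassert>

// Sliding-window visibility predicate for local attention: key j is visible to
// query i iff it lies within the window around i. A negative window size means
// "unbounded on that side".
static bool key_is_visible(int i, int j, int window_size_left, int window_size_right) {
    const bool left_ok  = (window_size_left  < 0) || (j >= i - window_size_left);
    const bool right_ok = (window_size_right < 0) || (j <= i + window_size_right);
    return left_ok && right_ok;
}

int main() {
    const int seqlen_k = 8;
    // A window of size >= seqlen_k admits every key, exactly like (-1, -1),
    // so dispatching the local-attention kernel would be redundant work.
    for (int i = 0; i < seqlen_k; ++i) {
        for (int j = 0; j < seqlen_k; ++j) {
            assert(key_is_visible(i, j, seqlen_k, seqlen_k)
                   == key_is_visible(i, j, -1, -1));
        }
    }
    return 0;
}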