zhangdong1 / Block-Sparse-Attention · Commits

Commit 0d23f715, authored Feb 18, 2025 by sxtyzhangzk
[major] support sm_120
Parent: 4d6d778f
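Summary of the change: the commit introduces an is_sm120 flag (dprops->major == 12 && dprops->minor == 0, i.e. compute capability 12.0, as reported by Blackwell-generation consumer GPUs such as the GeForce RTX 50 series) and adds it to the architecture gates in mha_fwd, mha_varlen_fwd, and mha_fwd_block, so that both the general "Ampere GPUs or newer" check and the bfloat16 check now accept sm_120 devices.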
Showing 1 changed file with 9 additions and 6 deletions.

csrc/block_sparse_attn/flash_api.cpp: +9 −6
csrc/block_sparse_attn/flash_api.cpp

@@ -301,7 +301,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
     bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
-    TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
+    bool is_sm120 = dprops->major == 12 && dprops->minor == 0;
+    TORCH_CHECK(is_sm120 || is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
     // We will support Turing in the near future
     // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
@@ -309,7 +310,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                 "FlashAttention only support fp16 and bf16 data type");
     if (q_dtype == torch::kBFloat16) {
-        TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
+        TORCH_CHECK(is_sm120 || is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
     }
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
@@ -514,7 +515,8 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
     bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
-    TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
+    bool is_sm120 = dprops->major == 12 && dprops->minor == 0;
+    TORCH_CHECK(is_sm120 || is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
     // We will support Turing in the near future
     // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
@@ -522,7 +524,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                 "FlashAttention only support fp16 and bf16 data type");
     if (q_dtype == torch::kBFloat16) {
-        TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
+        TORCH_CHECK(is_sm120 || is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
     }
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
@@ -1504,6 +1506,7 @@ mha_fwd_block(const at::Tensor &q,
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
     bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+    bool is_sm120 = dprops->major == 12 && dprops->minor == 0;
     const bool has_blockmask = row_blockmask_.has_value();
     const bool has_streaming_info = streaming_info_.has_value();
     at::Tensor row_blockmask, streaming_info;
@@ -1513,14 +1516,14 @@ mha_fwd_block(const at::Tensor &q,
     if (has_streaming_info){
         streaming_info = streaming_info_.value();
     }
-    TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
+    TORCH_CHECK(is_sm120 || is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
     auto q_dtype = q.dtype();
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                 "FlashAttention only support fp16 and bf16 data type");
     if (q_dtype == torch::kBFloat16) {
-        TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
+        TORCH_CHECK(is_sm120 || is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
     }
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
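All six hunks extend the same runtime compute-capability gate. As a standalone sketch (not part of this commit), the same test can be reproduced against the current device with the plain CUDA runtime API; main(), device index 0, and the printed message below are illustrative assumptions, not code from flash_api.cpp:

// Sketch: replicate the architecture gate from flash_api.cpp using the CUDA
// runtime API. Device 0 and the output format are illustrative assumptions.
#include <cuda_runtime.h>
#include <cstdio>

int main() {
    cudaDeviceProp prop;
    if (cudaGetDeviceProperties(&prop, /*device=*/0) != cudaSuccess) {
        std::fprintf(stderr, "no CUDA device found\n");
        return 1;
    }
    // Same predicates as the diff: any sm_8x, sm_90, or (new) sm_120 passes.
    const bool is_sm8x  = prop.major == 8 && prop.minor >= 0;
    const bool is_sm90  = prop.major == 9 && prop.minor == 0;
    const bool is_sm120 = prop.major == 12 && prop.minor == 0;
    std::printf("compute capability %d.%d -> %s\n", prop.major, prop.minor,
                (is_sm120 || is_sm90 || is_sm8x) ? "supported" : "unsupported");
    return 0;
}

Inside the extension itself no separate query is needed; dprops is presumably the pointer returned by at::cuda::getCurrentDeviceProperties() earlier in each function, which exposes the same major/minor fields.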