Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhangdong1
Block-Sparse-Attention
Commits
0d23f715
Commit
0d23f715
authored
Feb 18, 2025
by
sxtyzhangzk
Browse files
[major] support sm_120
parent
4d6d778f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
6 deletions
+9
-6
csrc/block_sparse_attn/flash_api.cpp
csrc/block_sparse_attn/flash_api.cpp
+9
-6
No files found.
csrc/block_sparse_attn/flash_api.cpp
View file @
0d23f715
...
@@ -301,7 +301,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
...
@@ -301,7 +301,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
TORCH_CHECK
(
is_sm90
||
is_sm8x
,
"FlashAttention only supports Ampere GPUs or newer."
);
bool is_sm120 = dprops->major == 12 && dprops->minor == 0;
TORCH_CHECK(is_sm120 || is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
// We will support Turing in the near future
// We will support Turing in the near future
// TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
// TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
...
@@ -309,7 +310,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
...
@@ -309,7 +310,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
"FlashAttention only support fp16 and bf16 data type");
if (q_dtype == torch::kBFloat16) {
if (q_dtype == torch::kBFloat16) {
TORCH_CHECK
(
is_sm90
||
is_sm8x
,
"bfloat16 is only supported on Ampere GPUs or newer"
);
TORCH_CHECK(
is_sm120 ||
is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
}
}
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
...
@@ -514,7 +515,8 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
...
@@ -514,7 +515,8 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
TORCH_CHECK
(
is_sm90
||
is_sm8x
,
"FlashAttention only supports Ampere GPUs or newer."
);
bool is_sm120 = dprops->major == 12 && dprops->minor == 0;
TORCH_CHECK(is_sm120 || is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
// We will support Turing in the near future
// We will support Turing in the near future
// TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
// TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
...
@@ -522,7 +524,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
...
@@ -522,7 +524,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
"FlashAttention only support fp16 and bf16 data type");
if (q_dtype == torch::kBFloat16) {
if (q_dtype == torch::kBFloat16) {
TORCH_CHECK
(
is_sm90
||
is_sm8x
,
"bfloat16 is only supported on Ampere GPUs or newer"
);
TORCH_CHECK(
is_sm120 ||
is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
}
}
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
...
@@ -1504,6 +1506,7 @@ mha_fwd_block(const at::Tensor &q,
...
@@ -1504,6 +1506,7 @@ mha_fwd_block(const at::Tensor &q,
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
bool is_sm120 = dprops->major == 12 && dprops->minor == 0;
const bool has_blockmask = row_blockmask_.has_value();
const bool has_blockmask = row_blockmask_.has_value();
const bool has_streaming_info = streaming_info_.has_value();
const bool has_streaming_info = streaming_info_.has_value();
at::Tensor row_blockmask, streaming_info;
at::Tensor row_blockmask, streaming_info;
...
@@ -1513,14 +1516,14 @@ mha_fwd_block(const at::Tensor &q,
...
@@ -1513,14 +1516,14 @@ mha_fwd_block(const at::Tensor &q,
if (has_streaming_info){
if (has_streaming_info){
streaming_info = streaming_info_.value();
streaming_info = streaming_info_.value();
}
}
TORCH_CHECK
(
is_sm90
||
is_sm8x
,
"FlashAttention only supports Ampere GPUs or newer."
);
TORCH_CHECK(
is_sm120 ||
is_sm90 || is_sm8x
, "FlashAttention only supports Ampere GPUs or newer.");
auto q_dtype = q.dtype();
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
"FlashAttention only support fp16 and bf16 data type");
if (q_dtype == torch::kBFloat16) {
if (q_dtype == torch::kBFloat16) {
TORCH_CHECK
(
is_sm90
||
is_sm8x
,
"bfloat16 is only supported on Ampere GPUs or newer"
);
TORCH_CHECK(
is_sm120 ||
is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
}
}
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment