Commit 0d23f715 authored by sxtyzhangzk

[major] support sm_120

parent 4d6d778f
@@ -301,7 +301,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
     bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
-    TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
+    bool is_sm120 = dprops->major == 12 && dprops->minor == 0;
+    TORCH_CHECK(is_sm120 || is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
     // We will support Turing in the near future
     // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
@@ -309,7 +310,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                 "FlashAttention only support fp16 and bf16 data type");
     if (q_dtype == torch::kBFloat16) {
-        TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
+        TORCH_CHECK(is_sm120 || is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
     }
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
@@ -514,7 +515,8 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
     bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
-    TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
+    bool is_sm120 = dprops->major == 12 && dprops->minor == 0;
+    TORCH_CHECK(is_sm120 || is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
     // We will support Turing in the near future
     // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
@@ -522,7 +524,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                 "FlashAttention only support fp16 and bf16 data type");
     if (q_dtype == torch::kBFloat16) {
-        TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
+        TORCH_CHECK(is_sm120 || is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
     }
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
@@ -1504,6 +1506,7 @@ mha_fwd_block(const at::Tensor &q,
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
     bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+    bool is_sm120 = dprops->major == 12 && dprops->minor == 0;
     const bool has_blockmask = row_blockmask_.has_value();
     const bool has_streaming_info = streaming_info_.has_value();
     at::Tensor row_blockmask, streaming_info;
@@ -1513,14 +1516,14 @@ mha_fwd_block(const at::Tensor &q,
     if (has_streaming_info){
         streaming_info = streaming_info_.value();
     }
-    TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
+    TORCH_CHECK(is_sm120 || is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
     auto q_dtype = q.dtype();
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                 "FlashAttention only support fp16 and bf16 data type");
     if (q_dtype == torch::kBFloat16) {
-        TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
+        TORCH_CHECK(is_sm120 || is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
     }
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
...
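
For reference, a minimal self-contained sketch of the compute-capability gate that this commit extends, assuming the usual FlashAttention setup where dprops comes from ATen's current CUDA device properties. The standalone helper name check_gpu_arch is illustrative only; in the actual source these checks are inlined in each entry point (mha_fwd, mha_varlen_fwd, mha_fwd_block).

// Sketch (not part of the commit): the arch/dtype gate after adding sm_120.
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

static void check_gpu_arch(const at::Tensor &q) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x  = dprops->major == 8 && dprops->minor >= 0;   // Ampere / Ada (SM 8.x)
    bool is_sm90  = dprops->major == 9 && dprops->minor == 0;   // Hopper (SM 9.0)
    bool is_sm120 = dprops->major == 12 && dprops->minor == 0;  // newly allowed: SM 12.0
    TORCH_CHECK(is_sm120 || is_sm90 || is_sm8x,
                "FlashAttention only supports Ampere GPUs or newer.");
    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only support fp16 and bf16 data type");
    if (q_dtype == torch::kBFloat16) {
        TORCH_CHECK(is_sm120 || is_sm90 || is_sm8x,
                    "bfloat16 is only supported on Ampere GPUs or newer");
    }
}

Note that this only relaxes the runtime check; the excerpt does not show whether the build's target architectures were also updated to generate code for SM 12.0.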