Commit 68971b5c authored by zhanghj2
Browse files

对接口进行架构检查

parent 68055db7
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <kerutils/supplemental/torch_tensors.h> #include <kerutils/supplemental/torch_tensors.h>
#include <string>
#include <cutlass/bfloat16.h> #include <cutlass/bfloat16.h>
static constexpr float LOG_2_E = 1.44269504f; static constexpr float LOG_2_E = 1.44269504f;
...@@ -22,6 +22,7 @@ struct Arch { ...@@ -22,6 +22,7 @@ struct Arch {
int major; int major;
int minor; int minor;
int num_sms; int num_sms;
std::string archName;
cudaDeviceProp* device_prop; cudaDeviceProp* device_prop;
Arch() { Arch() {
...@@ -29,6 +30,7 @@ struct Arch { ...@@ -29,6 +30,7 @@ struct Arch {
major = device_prop->major; major = device_prop->major;
minor = device_prop->minor; minor = device_prop->minor;
num_sms = device_prop->multiProcessorCount; num_sms = device_prop->multiProcessorCount;
archName = device_prop->gcnArchName;
} }
bool is_sm90a() const { bool is_sm90a() const {
...@@ -39,6 +41,22 @@ struct Arch { ...@@ -39,6 +41,22 @@ struct Arch {
bool is_sm100f() const { bool is_sm100f() const {
return major == 10; return major == 10;
} }
// True when the part of archName before the first ':' (or the whole name,
// if it contains no ':') is exactly "gfx938".
bool is_gfx938() const {
    const std::string::size_type base_len = archName.find(':');
    return archName.compare(0, base_len, "gfx938") == 0;
}
// True when the part of archName before the first ':' (or the whole name,
// if it contains no ':') is exactly "gfx936".
bool is_gfx936() const {
    const std::string::size_type base_len = archName.find(':');
    return archName.compare(0, base_len, "gfx936") == 0;
}
// True when the part of archName before the first ':' (or the whole name,
// if it contains no ':') is exactly "gfx928".
bool is_gfx928() const {
    const std::string::size_type base_len = archName.find(':');
    return archName.compare(0, base_len, "gfx928") == 0;
}
// True when the device matches either is_gfx936() or is_gfx938().
bool is_gfx93x() const {
    if (is_gfx936()) {
        return true;
    }
    return is_gfx938();
}
}; };
// Convert int64_t stride to int32_t, with overflow check. // Convert int64_t stride to int32_t, with overflow check.
......
...@@ -24,8 +24,8 @@ dense_attn_decode_interface( ...@@ -24,8 +24,8 @@ dense_attn_decode_interface(
) { ) {
// Check arch // Check arch
Arch arch = Arch(); Arch arch = Arch();
if (!arch.is_sm90a()) { if (!arch.is_gfx93x()) {
TORCH_CHECK(false, "Dense decode MLA is only supported on SM90a architecture"); TORCH_CHECK(false, "Dense decode MLA is only supported on gfx936 or gfx938 architecture");
} }
// Check data types // Check data types
......
...@@ -26,8 +26,8 @@ dense_attn_decode_kvfp8_interface( ...@@ -26,8 +26,8 @@ dense_attn_decode_kvfp8_interface(
) { ) {
// Check arch // Check arch
Arch arch = Arch(); Arch arch = Arch();
if (!arch.is_sm90a()) { if (!arch.is_gfx93x()) {
TORCH_CHECK(false, "Dense decode MLA is only supported on SM90a architecture"); TORCH_CHECK(false, "Dense decode MLA is only supported on gfx936 or gfx938 architecture");
} }
// Check data types // Check data types
......
...@@ -26,8 +26,8 @@ dense_attn_decode_qkvfp8_interface( ...@@ -26,8 +26,8 @@ dense_attn_decode_qkvfp8_interface(
) { ) {
// Check arch // Check arch
Arch arch = Arch(); Arch arch = Arch();
if (!arch.is_sm90a()) { if (!arch.is_gfx938()) {
TORCH_CHECK(false, "Dense decode MLA is only supported on SM90a architecture"); TORCH_CHECK(false, "Dense decode MLA is only supported on gfx938 architecture");
} }
// Check data types // Check data types
......
...@@ -101,7 +101,9 @@ sparse_attn_decode_interface( ...@@ -101,7 +101,9 @@ sparse_attn_decode_interface(
// Check the architecture // Check the architecture
Arch arch = Arch(); Arch arch = Arch();
if (!arch.is_gfx93x()) {
TORCH_CHECK(false, "Dense decode MLA is only supported on gfx936 or gfx938 architecture");
}
KU_CHECK_NDIM(q, 4); KU_CHECK_NDIM(q, 4);
KU_CHECK_NDIM(kv, 4); KU_CHECK_NDIM(kv, 4);
KU_CHECK_NDIM(indices, 3); KU_CHECK_NDIM(indices, 3);
......
...@@ -59,10 +59,9 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface( ...@@ -59,10 +59,9 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface(
using bf16 = cutlass::bfloat16_t; using bf16 = cutlass::bfloat16_t;
Arch arch = Arch(); Arch arch = Arch();
bool is_sm90a = arch.is_sm90a(); if (!arch.is_gfx93x()) {
bool is_sm100f = arch.is_sm100f(); TORCH_CHECK(false, "Dense decode MLA is only supported on gfx936 or gfx938 architecture");
TORCH_CHECK(is_sm90a || is_sm100f, "Sparse Attention Forward Kernel is only supported on SM90a and SM100f architectures."); }
KU_CHECK_NDIM(q, 3); KU_CHECK_NDIM(q, 3);
KU_CHECK_NDIM(kv, 3); KU_CHECK_NDIM(kv, 3);
KU_CHECK_NDIM(indices, 3); KU_CHECK_NDIM(indices, 3);
...@@ -161,7 +160,7 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface( ...@@ -161,7 +160,7 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface(
required_features.push_back(FwdFeatures::TOPK_LENGTH); required_features.push_back(FwdFeatures::TOPK_LENGTH);
} }
if (is_sm90a) { if (arch.is_gfx93x()) {
Fwd_Sm90_Impl fwd_impl; Fwd_Sm90_Impl fwd_impl;
fwd_impl.run(params, required_features); fwd_impl.run(params, required_features);
} else { } else {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment