Commit f298a271 authored by zhanghj2's avatar zhanghj2
Browse files

sm90改为gfx93

parent a8393a04
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "common.h" #include "common.h"
#include "params.h" #include "params.h"
#include "sm90/decode/dense/splitkv_mla.h" #include "gfx93/decode/dense/splitkv_mla.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h" #include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h" #include "smxx/decode/combine/combine.h"
...@@ -173,12 +173,12 @@ dense_attn_decode_interface( ...@@ -173,12 +173,12 @@ dense_attn_decode_interface(
params.stream = at::cuda::getCurrentCUDAStream().stream(); params.stream = at::cuda::getCurrentCUDAStream().stream();
if (q_dtype == torch::kBFloat16) { if (q_dtype == torch::kBFloat16) {
sm90::run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(params); gfx93::run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(params);
} else if (q_dtype == torch::kHalf) { } else if (q_dtype == torch::kHalf) {
#ifdef FLASH_MLA_DISABLE_FP16 #ifdef FLASH_MLA_DISABLE_FP16
TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. Please remove this flag from your environment and re-compile FlashMLA."); TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. Please remove this flag from your environment and re-compile FlashMLA.");
#else #else
sm90::run_flash_splitkv_mla_kernel<cutlass::half_t>(params); gfx93::run_flash_splitkv_mla_kernel<cutlass::half_t>(params);
#endif #endif
} else { } else {
TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90"); TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90");
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "common.h" #include "common.h"
#include "params.h" #include "params.h"
#include "sm90/decode/dense_kvfp8/splitkv_mla.h" #include "gfx93/decode/dense_kvfp8/splitkv_mla.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h" #include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h" #include "smxx/decode/combine/combine.h"
...@@ -188,7 +188,7 @@ dense_attn_decode_kvfp8_interface( ...@@ -188,7 +188,7 @@ dense_attn_decode_kvfp8_interface(
params.stream = at::cuda::getCurrentCUDAStream().stream(); params.stream = at::cuda::getCurrentCUDAStream().stream();
if (q_dtype == torch::kBFloat16) { if (q_dtype == torch::kBFloat16) {
sm90::run_flash_splitkv_mla_kvfp8_kernel<cutlass::bfloat16_t>(params); gfx93::run_flash_splitkv_mla_kvfp8_kernel<cutlass::bfloat16_t>(params);
} else { } else {
TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90"); TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90");
} }
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "common.h" #include "common.h"
#include "params.h" #include "params.h"
#include "sm90/decode/dense_qkvfp8/splitkv_mla.h" #include "gfx93/decode/dense_qkvfp8/splitkv_mla.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h" #include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h" #include "smxx/decode/combine/combine.h"
...@@ -188,7 +188,7 @@ dense_attn_decode_qkvfp8_interface( ...@@ -188,7 +188,7 @@ dense_attn_decode_qkvfp8_interface(
params.stream = at::cuda::getCurrentCUDAStream().stream(); params.stream = at::cuda::getCurrentCUDAStream().stream();
if (q_dtype == torch::kFloat8_e4m3fn) { if (q_dtype == torch::kFloat8_e4m3fn) {
sm90::run_flash_splitkv_mla_qkvfp8_kernel<cutlass::float_e4m3_t>(params); gfx93::run_flash_splitkv_mla_qkvfp8_kernel<cutlass::float_e4m3_t>(params);
} else { } else {
TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90"); TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90");
} }
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include "params.h" #include "params.h"
#include "sm90/decode/sparse_fp8/splitkv_mla.h" #include "gfx93/decode/sparse_fp8/splitkv_mla.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h" #include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h" #include "smxx/decode/combine/combine.h"
...@@ -76,7 +76,7 @@ protected: ...@@ -76,7 +76,7 @@ protected:
void run_(const SparseAttnDecodeParams &params, const std::vector<FeatureT> &required_features) override { void run_(const SparseAttnDecodeParams &params, const std::vector<FeatureT> &required_features) override {
DISPATCH_MODEL_TYPE(params.model_type, MODEL_TYPE, [&]() { DISPATCH_MODEL_TYPE(params.model_type, MODEL_TYPE, [&]() {
DISPATCH_NUM_HEADS(params.h_q, NUM_HEADS, [&]() { DISPATCH_NUM_HEADS(params.h_q, NUM_HEADS, [&]() {
sm90::decode::sparse_fp8::run_flash_splitkv_mla_fp8_sparse_kernel<MODEL_TYPE, NUM_HEADS>(params); gfx93::decode::sparse_fp8::run_flash_splitkv_mla_fp8_sparse_kernel<MODEL_TYPE, NUM_HEADS>(params);
}); });
}); });
} }
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include "params.h" #include "params.h"
#include "sm90/prefill/sparse/phase1.h" #include "gfx93/prefill/sparse/phase1.h"
enum class FwdFeatures : int { enum class FwdFeatures : int {
...@@ -41,7 +41,7 @@ protected: ...@@ -41,7 +41,7 @@ protected:
void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override { void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {
DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() { DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() {
DISPATCH_BOOLEAN_FLAG(params.topk_length != nullptr, HAVE_TOPK_LENGTH, [&]() { DISPATCH_BOOLEAN_FLAG(params.topk_length != nullptr, HAVE_TOPK_LENGTH, [&]() {
sm90::fwd::run_fwd_phase1_kernel<HEAD_DIM_QK, HAVE_TOPK_LENGTH>(params); gfx93::fwd::run_fwd_phase1_kernel<HEAD_DIM_QK, HAVE_TOPK_LENGTH>(params);
}); });
}); });
} }
......
#include "../splitkv_mla.cuh" #include "../splitkv_mla.cuh"
#include "../splitkv_mla.h" #include "../splitkv_mla.h"
namespace sm90 { namespace gfx93 {
template void run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(DenseAttnDecodeParams &params); template void run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(DenseAttnDecodeParams &params);
......
#include "../splitkv_mla.cuh" #include "../splitkv_mla.cuh"
#include "../splitkv_mla.h" #include "../splitkv_mla.h"
namespace sm90 { namespace gfx93 {
#ifndef FLASH_MLA_DISABLE_FP16 #ifndef FLASH_MLA_DISABLE_FP16
template void run_flash_splitkv_mla_kernel<cutlass::half_t>(DenseAttnDecodeParams &params); template void run_flash_splitkv_mla_kernel<cutlass::half_t>(DenseAttnDecodeParams &params);
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "softmax.h" #include "softmax.h"
using namespace cute; using namespace cute;
namespace sm90 { namespace gfx93 {
template<typename T> template<typename T>
__device__ void __device__ void
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "params.h" #include "params.h"
namespace sm90 { namespace gfx93 {
template<typename InputT> template<typename InputT>
void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams &params); void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams &params);
......
#include "../splitkv_mla.cuh" #include "../splitkv_mla.cuh"
#include "../splitkv_mla.h" #include "../splitkv_mla.h"
namespace sm90 { namespace gfx93 {
template void run_flash_splitkv_mla_kvfp8_kernel<cutlass::bfloat16_t>(DenseAttnDecodeParams_fp8 &params); template void run_flash_splitkv_mla_kvfp8_kernel<cutlass::bfloat16_t>(DenseAttnDecodeParams_fp8 &params);
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "softmax.h" #include "softmax.h"
using namespace cute; using namespace cute;
namespace sm90 { namespace gfx93 {
template<typename T> template<typename T>
__device__ void __device__ void
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "params.h" #include "params.h"
namespace sm90 { namespace gfx93 {
template<typename InputT> template<typename InputT>
void run_flash_splitkv_mla_kvfp8_kernel(DenseAttnDecodeParams_fp8 &params); void run_flash_splitkv_mla_kvfp8_kernel(DenseAttnDecodeParams_fp8 &params);
......
#include "../splitkv_mla.cuh" #include "../splitkv_mla.cuh"
#include "../splitkv_mla.h" #include "../splitkv_mla.h"
namespace sm90 { namespace gfx93 {
template void run_flash_splitkv_mla_qkvfp8_kernel<cutlass::float_e4m3_t>(DenseAttnDecodeParams_fp8 &params); template void run_flash_splitkv_mla_qkvfp8_kernel<cutlass::float_e4m3_t>(DenseAttnDecodeParams_fp8 &params);
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "softmax.h" #include "softmax.h"
using namespace cute; using namespace cute;
namespace sm90 { namespace gfx93 {
template<typename T> template<typename T>
__device__ void __device__ void
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "params.h" #include "params.h"
namespace sm90 { namespace gfx93 {
template<typename InputT> template<typename InputT>
void run_flash_splitkv_mla_qkvfp8_kernel(DenseAttnDecodeParams_fp8 &params); void run_flash_splitkv_mla_qkvfp8_kernel(DenseAttnDecodeParams_fp8 &params);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment