Commit f298a271 authored by zhanghj2's avatar zhanghj2
Browse files

sm90改为gfx93

parent a8393a04
......@@ -6,7 +6,7 @@
#include "common.h"
#include "params.h"
#include "sm90/decode/dense/splitkv_mla.h"
#include "gfx93/decode/dense/splitkv_mla.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h"
......@@ -173,12 +173,12 @@ dense_attn_decode_interface(
params.stream = at::cuda::getCurrentCUDAStream().stream();
if (q_dtype == torch::kBFloat16) {
sm90::run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(params);
gfx93::run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(params);
} else if (q_dtype == torch::kHalf) {
#ifdef FLASH_MLA_DISABLE_FP16
TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. Please remove this flag from your environment and re-compile FlashMLA.");
#else
sm90::run_flash_splitkv_mla_kernel<cutlass::half_t>(params);
gfx93::run_flash_splitkv_mla_kernel<cutlass::half_t>(params);
#endif
} else {
TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90");
......
......@@ -6,7 +6,7 @@
#include "common.h"
#include "params.h"
#include "sm90/decode/dense_kvfp8/splitkv_mla.h"
#include "gfx93/decode/dense_kvfp8/splitkv_mla.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h"
......@@ -188,7 +188,7 @@ dense_attn_decode_kvfp8_interface(
params.stream = at::cuda::getCurrentCUDAStream().stream();
if (q_dtype == torch::kBFloat16) {
sm90::run_flash_splitkv_mla_kvfp8_kernel<cutlass::bfloat16_t>(params);
gfx93::run_flash_splitkv_mla_kvfp8_kernel<cutlass::bfloat16_t>(params);
} else {
TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90");
}
......
......@@ -6,7 +6,7 @@
#include "common.h"
#include "params.h"
#include "sm90/decode/dense_qkvfp8/splitkv_mla.h"
#include "gfx93/decode/dense_qkvfp8/splitkv_mla.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h"
......@@ -188,7 +188,7 @@ dense_attn_decode_qkvfp8_interface(
params.stream = at::cuda::getCurrentCUDAStream().stream();
if (q_dtype == torch::kFloat8_e4m3fn) {
sm90::run_flash_splitkv_mla_qkvfp8_kernel<cutlass::float_e4m3_t>(params);
gfx93::run_flash_splitkv_mla_qkvfp8_kernel<cutlass::float_e4m3_t>(params);
} else {
TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90");
}
......
......@@ -4,7 +4,7 @@
#include "params.h"
#include "sm90/decode/sparse_fp8/splitkv_mla.h"
#include "gfx93/decode/sparse_fp8/splitkv_mla.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h"
......@@ -76,7 +76,7 @@ protected:
void run_(const SparseAttnDecodeParams &params, const std::vector<FeatureT> &required_features) override {
DISPATCH_MODEL_TYPE(params.model_type, MODEL_TYPE, [&]() {
DISPATCH_NUM_HEADS(params.h_q, NUM_HEADS, [&]() {
sm90::decode::sparse_fp8::run_flash_splitkv_mla_fp8_sparse_kernel<MODEL_TYPE, NUM_HEADS>(params);
gfx93::decode::sparse_fp8::run_flash_splitkv_mla_fp8_sparse_kernel<MODEL_TYPE, NUM_HEADS>(params);
});
});
}
......
......@@ -4,7 +4,7 @@
#include "params.h"
#include "sm90/prefill/sparse/phase1.h"
#include "gfx93/prefill/sparse/phase1.h"
enum class FwdFeatures : int {
......@@ -41,7 +41,7 @@ protected:
void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {
DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() {
DISPATCH_BOOLEAN_FLAG(params.topk_length != nullptr, HAVE_TOPK_LENGTH, [&]() {
sm90::fwd::run_fwd_phase1_kernel<HEAD_DIM_QK, HAVE_TOPK_LENGTH>(params);
gfx93::fwd::run_fwd_phase1_kernel<HEAD_DIM_QK, HAVE_TOPK_LENGTH>(params);
});
});
}
......
#include "../splitkv_mla.cuh"
#include "../splitkv_mla.h"
namespace sm90 {
namespace gfx93 {
template void run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(DenseAttnDecodeParams &params);
......
#include "../splitkv_mla.cuh"
#include "../splitkv_mla.h"
namespace sm90 {
namespace gfx93 {
#ifndef FLASH_MLA_DISABLE_FP16
template void run_flash_splitkv_mla_kernel<cutlass::half_t>(DenseAttnDecodeParams &params);
......
......@@ -8,7 +8,7 @@
#include "softmax.h"
using namespace cute;
namespace sm90 {
namespace gfx93 {
template<typename T>
__device__ void
......
......@@ -2,7 +2,7 @@
#include "params.h"
namespace sm90 {
namespace gfx93 {
template<typename InputT>
void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams &params);
......
#include "../splitkv_mla.cuh"
#include "../splitkv_mla.h"
namespace sm90 {
namespace gfx93 {
template void run_flash_splitkv_mla_kvfp8_kernel<cutlass::bfloat16_t>(DenseAttnDecodeParams_fp8 &params);
......
......@@ -8,7 +8,7 @@
#include "softmax.h"
using namespace cute;
namespace sm90 {
namespace gfx93 {
template<typename T>
__device__ void
......
......@@ -2,7 +2,7 @@
#include "params.h"
namespace sm90 {
namespace gfx93 {
template<typename InputT>
void run_flash_splitkv_mla_kvfp8_kernel(DenseAttnDecodeParams_fp8 &params);
......
#include "../splitkv_mla.cuh"
#include "../splitkv_mla.h"
namespace sm90 {
namespace gfx93 {
template void run_flash_splitkv_mla_qkvfp8_kernel<cutlass::float_e4m3_t>(DenseAttnDecodeParams_fp8 &params);
......
......@@ -8,7 +8,7 @@
#include "softmax.h"
using namespace cute;
namespace sm90 {
namespace gfx93 {
template<typename T>
__device__ void
......
......@@ -2,7 +2,7 @@
#include "params.h"
namespace sm90 {
namespace gfx93 {
template<typename InputT>
void run_flash_splitkv_mla_qkvfp8_kernel(DenseAttnDecodeParams_fp8 &params);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment