Commit 59487e20 authored by zhanghj2's avatar zhanghj2
Browse files

smxx修改为gfx9

parent f298a271
......@@ -7,8 +7,8 @@
#include "params.h"
#include "gfx93/decode/dense/splitkv_mla.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h"
#include "gfx9/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "gfx9/decode/combine/combine.h"
static std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>>
dense_attn_decode_interface(
......@@ -110,7 +110,7 @@ dense_attn_decode_interface(
num_sm_parts,
at::cuda::getCurrentCUDAStream().stream()
};
smxx::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params);
gfx9::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params);
} else {
KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32);
KU_CHECK_DTYPE(num_splits, torch::kInt32);
......@@ -207,10 +207,10 @@ dense_attn_decode_interface(
};
if (q_dtype == torch::kBFloat16) {
smxx::decode::run_flash_mla_combine_kernel<cutlass::bfloat16_t>(combine_params);
gfx9::decode::run_flash_mla_combine_kernel<cutlass::bfloat16_t>(combine_params);
} else if (q_dtype == torch::kHalf) {
#ifndef FLASH_MLA_DISABLE_FP16
smxx::decode::run_flash_mla_combine_kernel<cutlass::half_t>(combine_params);
gfx9::decode::run_flash_mla_combine_kernel<cutlass::half_t>(combine_params);
#endif
} else {
TORCH_CHECK(false, "Unsupported tensor dtype for query");
......
......@@ -7,8 +7,8 @@
#include "params.h"
#include "gfx93/decode/dense_kvfp8/splitkv_mla.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h"
#include "gfx9/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "gfx9/decode/combine/combine.h"
static std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>>
dense_attn_decode_kvfp8_interface(
......@@ -123,7 +123,7 @@ dense_attn_decode_kvfp8_interface(
num_sm_parts,
at::cuda::getCurrentCUDAStream().stream()
};
smxx::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params);
gfx9::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params);
} else {
KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32);
KU_CHECK_DTYPE(num_splits, torch::kInt32);
......@@ -215,7 +215,7 @@ dense_attn_decode_kvfp8_interface(
at::cuda::getCurrentCUDAStream().stream()
};
smxx::decode::run_flash_mla_combine_kernel<cutlass::bfloat16_t>(combine_params);
gfx9::decode::run_flash_mla_combine_kernel<cutlass::bfloat16_t>(combine_params);
out = out.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk, head_size_v}).transpose(1, 2)
......
......@@ -7,8 +7,8 @@
#include "params.h"
#include "gfx93/decode/dense_qkvfp8/splitkv_mla.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h"
#include "gfx9/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "gfx9/decode/combine/combine.h"
static std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>>
dense_attn_decode_qkvfp8_interface(
......@@ -123,7 +123,7 @@ dense_attn_decode_qkvfp8_interface(
num_sm_parts,
at::cuda::getCurrentCUDAStream().stream()
};
smxx::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params);
gfx9::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params);
} else {
KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32);
KU_CHECK_DTYPE(num_splits, torch::kInt32);
......@@ -215,7 +215,7 @@ dense_attn_decode_qkvfp8_interface(
at::cuda::getCurrentCUDAStream().stream()
};
smxx::decode::run_flash_mla_combine_kernel<cutlass::bfloat16_t>(combine_params);
gfx9::decode::run_flash_mla_combine_kernel<cutlass::bfloat16_t>(combine_params);
out = out.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk, head_size_v}).transpose(1, 2)
......
......@@ -5,8 +5,8 @@
#include "params.h"
#include "gfx93/decode/sparse_fp8/splitkv_mla.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h"
#include "gfx9/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "gfx9/decode/combine/combine.h"
// Feature set of sparse decoding kernels
enum class DecodeFeatures : int {
......@@ -328,7 +328,7 @@ sparse_attn_decode_interface(
impl_meta.num_sm_parts,
at::cuda::getCurrentCUDAStream().stream()
};
smxx::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params);
gfx9::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params);
}
// Stick the metadata pointers to `params`
KU_CHECK_DEVICE(tile_scheduler_metadata);
......@@ -379,7 +379,7 @@ sparse_attn_decode_interface(
ku::get_optional_tensor_ptr<float>(attn_sink),
at::cuda::getCurrentCUDAStream().stream()
};
smxx::decode::run_flash_mla_combine_kernel<bf16>(combine_params);
gfx9::decode::run_flash_mla_combine_kernel<bf16>(combine_params);
delete impl;
......
......@@ -14,7 +14,7 @@
using namespace cute;
namespace smxx::decode {
namespace gfx9::decode {
template<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS>
__global__ void __launch_bounds__(NUM_THREADS, 1)
......
......@@ -2,7 +2,7 @@
#include "params.h"
namespace smxx::decode {
namespace gfx9::decode {
template<typename ElementT>
void run_flash_mla_combine_kernel(CombineParams &params);
......
......@@ -6,7 +6,7 @@
#include "utils.h"
namespace smxx::decode {
namespace gfx9::decode {
__global__ void __launch_bounds__(64, 1)
get_mla_metadata_kernel(const GetDecodeSchedMetaParams params) {
......
......@@ -2,7 +2,7 @@
#include "params.h"
namespace smxx::decode {
namespace gfx9::decode {
void run_get_decoding_sched_meta_kernel(GetDecodeSchedMetaParams &params);
......
......@@ -52,8 +52,8 @@ ext_modules.append(
"csrc/api/api.cpp",
# # Misc kernels for decoding
"csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu",
"csrc/smxx/decode/combine/combine.cu",
"csrc/gfx9/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu",
"csrc/gfx9/decode/combine/combine.cu",
# # gfx93 dense decode
"csrc/gfx93/decode/dense/instantiations/fp16.cu",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment