Commit 59487e20 authored by zhanghj2
Browse files

smxx修改为gfx9 (rename smxx to gfx9)

parent f298a271
...@@ -7,8 +7,8 @@ ...@@ -7,8 +7,8 @@
#include "params.h" #include "params.h"
#include "gfx93/decode/dense/splitkv_mla.h" #include "gfx93/decode/dense/splitkv_mla.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h" #include "gfx9/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h" #include "gfx9/decode/combine/combine.h"
static std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>> static std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>>
dense_attn_decode_interface( dense_attn_decode_interface(
...@@ -110,7 +110,7 @@ dense_attn_decode_interface( ...@@ -110,7 +110,7 @@ dense_attn_decode_interface(
num_sm_parts, num_sm_parts,
at::cuda::getCurrentCUDAStream().stream() at::cuda::getCurrentCUDAStream().stream()
}; };
smxx::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params); gfx9::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params);
} else { } else {
KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32); KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32);
KU_CHECK_DTYPE(num_splits, torch::kInt32); KU_CHECK_DTYPE(num_splits, torch::kInt32);
...@@ -207,10 +207,10 @@ dense_attn_decode_interface( ...@@ -207,10 +207,10 @@ dense_attn_decode_interface(
}; };
if (q_dtype == torch::kBFloat16) { if (q_dtype == torch::kBFloat16) {
smxx::decode::run_flash_mla_combine_kernel<cutlass::bfloat16_t>(combine_params); gfx9::decode::run_flash_mla_combine_kernel<cutlass::bfloat16_t>(combine_params);
} else if (q_dtype == torch::kHalf) { } else if (q_dtype == torch::kHalf) {
#ifndef FLASH_MLA_DISABLE_FP16 #ifndef FLASH_MLA_DISABLE_FP16
smxx::decode::run_flash_mla_combine_kernel<cutlass::half_t>(combine_params); gfx9::decode::run_flash_mla_combine_kernel<cutlass::half_t>(combine_params);
#endif #endif
} else { } else {
TORCH_CHECK(false, "Unsupported tensor dtype for query"); TORCH_CHECK(false, "Unsupported tensor dtype for query");
......
...@@ -7,8 +7,8 @@ ...@@ -7,8 +7,8 @@
#include "params.h" #include "params.h"
#include "gfx93/decode/dense_kvfp8/splitkv_mla.h" #include "gfx93/decode/dense_kvfp8/splitkv_mla.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h" #include "gfx9/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h" #include "gfx9/decode/combine/combine.h"
static std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>> static std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>>
dense_attn_decode_kvfp8_interface( dense_attn_decode_kvfp8_interface(
...@@ -123,7 +123,7 @@ dense_attn_decode_kvfp8_interface( ...@@ -123,7 +123,7 @@ dense_attn_decode_kvfp8_interface(
num_sm_parts, num_sm_parts,
at::cuda::getCurrentCUDAStream().stream() at::cuda::getCurrentCUDAStream().stream()
}; };
smxx::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params); gfx9::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params);
} else { } else {
KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32); KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32);
KU_CHECK_DTYPE(num_splits, torch::kInt32); KU_CHECK_DTYPE(num_splits, torch::kInt32);
...@@ -215,7 +215,7 @@ dense_attn_decode_kvfp8_interface( ...@@ -215,7 +215,7 @@ dense_attn_decode_kvfp8_interface(
at::cuda::getCurrentCUDAStream().stream() at::cuda::getCurrentCUDAStream().stream()
}; };
smxx::decode::run_flash_mla_combine_kernel<cutlass::bfloat16_t>(combine_params); gfx9::decode::run_flash_mla_combine_kernel<cutlass::bfloat16_t>(combine_params);
out = out.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk, head_size_v}).transpose(1, 2) out = out.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk, head_size_v}).transpose(1, 2)
......
...@@ -7,8 +7,8 @@ ...@@ -7,8 +7,8 @@
#include "params.h" #include "params.h"
#include "gfx93/decode/dense_qkvfp8/splitkv_mla.h" #include "gfx93/decode/dense_qkvfp8/splitkv_mla.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h" #include "gfx9/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h" #include "gfx9/decode/combine/combine.h"
static std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>> static std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>>
dense_attn_decode_qkvfp8_interface( dense_attn_decode_qkvfp8_interface(
...@@ -123,7 +123,7 @@ dense_attn_decode_qkvfp8_interface( ...@@ -123,7 +123,7 @@ dense_attn_decode_qkvfp8_interface(
num_sm_parts, num_sm_parts,
at::cuda::getCurrentCUDAStream().stream() at::cuda::getCurrentCUDAStream().stream()
}; };
smxx::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params); gfx9::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params);
} else { } else {
KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32); KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32);
KU_CHECK_DTYPE(num_splits, torch::kInt32); KU_CHECK_DTYPE(num_splits, torch::kInt32);
...@@ -215,7 +215,7 @@ dense_attn_decode_qkvfp8_interface( ...@@ -215,7 +215,7 @@ dense_attn_decode_qkvfp8_interface(
at::cuda::getCurrentCUDAStream().stream() at::cuda::getCurrentCUDAStream().stream()
}; };
smxx::decode::run_flash_mla_combine_kernel<cutlass::bfloat16_t>(combine_params); gfx9::decode::run_flash_mla_combine_kernel<cutlass::bfloat16_t>(combine_params);
out = out.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk, head_size_v}).transpose(1, 2) out = out.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk, head_size_v}).transpose(1, 2)
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
#include "params.h" #include "params.h"
#include "gfx93/decode/sparse_fp8/splitkv_mla.h" #include "gfx93/decode/sparse_fp8/splitkv_mla.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h" #include "gfx9/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h" #include "gfx9/decode/combine/combine.h"
// Feature set of sparse decoding kernels // Feature set of sparse decoding kernels
enum class DecodeFeatures : int { enum class DecodeFeatures : int {
...@@ -328,7 +328,7 @@ sparse_attn_decode_interface( ...@@ -328,7 +328,7 @@ sparse_attn_decode_interface(
impl_meta.num_sm_parts, impl_meta.num_sm_parts,
at::cuda::getCurrentCUDAStream().stream() at::cuda::getCurrentCUDAStream().stream()
}; };
smxx::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params); gfx9::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params);
} }
// Stick the metadata pointers to `params` // Stick the metadata pointers to `params`
KU_CHECK_DEVICE(tile_scheduler_metadata); KU_CHECK_DEVICE(tile_scheduler_metadata);
...@@ -379,7 +379,7 @@ sparse_attn_decode_interface( ...@@ -379,7 +379,7 @@ sparse_attn_decode_interface(
ku::get_optional_tensor_ptr<float>(attn_sink), ku::get_optional_tensor_ptr<float>(attn_sink),
at::cuda::getCurrentCUDAStream().stream() at::cuda::getCurrentCUDAStream().stream()
}; };
smxx::decode::run_flash_mla_combine_kernel<bf16>(combine_params); gfx9::decode::run_flash_mla_combine_kernel<bf16>(combine_params);
delete impl; delete impl;
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
using namespace cute; using namespace cute;
namespace smxx::decode { namespace gfx9::decode {
template<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS> template<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS>
__global__ void __launch_bounds__(NUM_THREADS, 1) __global__ void __launch_bounds__(NUM_THREADS, 1)
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "params.h" #include "params.h"
namespace smxx::decode { namespace gfx9::decode {
template<typename ElementT> template<typename ElementT>
void run_flash_mla_combine_kernel(CombineParams &params); void run_flash_mla_combine_kernel(CombineParams &params);
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "utils.h" #include "utils.h"
namespace smxx::decode { namespace gfx9::decode {
__global__ void __launch_bounds__(64, 1) __global__ void __launch_bounds__(64, 1)
get_mla_metadata_kernel(const GetDecodeSchedMetaParams params) { get_mla_metadata_kernel(const GetDecodeSchedMetaParams params) {
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "params.h" #include "params.h"
namespace smxx::decode { namespace gfx9::decode {
void run_get_decoding_sched_meta_kernel(GetDecodeSchedMetaParams &params); void run_get_decoding_sched_meta_kernel(GetDecodeSchedMetaParams &params);
......
...@@ -52,8 +52,8 @@ ext_modules.append( ...@@ -52,8 +52,8 @@ ext_modules.append(
"csrc/api/api.cpp", "csrc/api/api.cpp",
# # Misc kernels for decoding # # Misc kernels for decoding
"csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu", "csrc/gfx9/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu",
"csrc/smxx/decode/combine/combine.cu", "csrc/gfx9/decode/combine/combine.cu",
# # gfx93 dense decode # # gfx93 dense decode
"csrc/gfx93/decode/dense/instantiations/fp16.cu", "csrc/gfx93/decode/dense/instantiations/fp16.cu",
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment