Commit f298a271 authored by zhanghj2's avatar zhanghj2
Browse files

sm90改为gfx93

parent a8393a04
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
using namespace cute; using namespace cute;
namespace sm90::decode::sparse_fp8 { namespace gfx93::decode::sparse_fp8 {
static constexpr int HEAD_DIM_K = 576; static constexpr int HEAD_DIM_K = 576;
static constexpr int HEAD_DIM_V = 512; static constexpr int HEAD_DIM_V = 512;
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include "defines.h" #include "defines.h"
namespace sm90::decode::sparse_fp8 { namespace gfx93::decode::sparse_fp8 {
// struct fp8x8 { // struct fp8x8 {
// // __nv_fp8x4_e4m3 lo; // // __nv_fp8x4_e4m3 lo;
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
using namespace cute; using namespace cute;
namespace sm90::decode::sparse_fp8 { namespace gfx93::decode::sparse_fp8 {
// // In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~1) to the actual row_idx // // In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~1) to the actual row_idx
// // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a // // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
using namespace cute; using namespace cute;
namespace sm90::decode::sparse_fp8 { namespace gfx93::decode::sparse_fp8 {
template<ModelType MODEL_TYPE, int NUM_HEADS> template<ModelType MODEL_TYPE, int NUM_HEADS>
class KernelTemplate { class KernelTemplate {
......
#include "../splitkv_mla.cuh" #include "../splitkv_mla.cuh"
namespace sm90::decode::sparse_fp8 { namespace gfx93::decode::sparse_fp8 {
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::MODEL1, 128>(const SparseAttnDecodeParams &params); template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::MODEL1, 128>(const SparseAttnDecodeParams &params);
......
#include "../splitkv_mla.cuh" #include "../splitkv_mla.cuh"
namespace sm90::decode::sparse_fp8 { namespace gfx93::decode::sparse_fp8 {
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::MODEL1, 16>(const SparseAttnDecodeParams &params); template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::MODEL1, 16>(const SparseAttnDecodeParams &params);
......
#include "../splitkv_mla.cuh" #include "../splitkv_mla.cuh"
namespace sm90::decode::sparse_fp8 { namespace gfx93::decode::sparse_fp8 {
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::MODEL1, 64>(const SparseAttnDecodeParams &params); template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::MODEL1, 64>(const SparseAttnDecodeParams &params);
......
#include "../splitkv_mla.cuh" #include "../splitkv_mla.cuh"
namespace sm90::decode::sparse_fp8 { namespace gfx93::decode::sparse_fp8 {
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::V32, 128>(const SparseAttnDecodeParams &params); template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::V32, 128>(const SparseAttnDecodeParams &params);
......
#include "../splitkv_mla.cuh" #include "../splitkv_mla.cuh"
namespace sm90::decode::sparse_fp8 { namespace gfx93::decode::sparse_fp8 {
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::V32, 16>(const SparseAttnDecodeParams &params); template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::V32, 16>(const SparseAttnDecodeParams &params);
......
#include "../splitkv_mla.cuh" #include "../splitkv_mla.cuh"
namespace sm90::decode::sparse_fp8 { namespace gfx93::decode::sparse_fp8 {
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::V32, 64>(const SparseAttnDecodeParams &params); template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::V32, 64>(const SparseAttnDecodeParams &params);
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include "softmax.h" #include "softmax.h"
using namespace cute; using namespace cute;
namespace sm90::decode::sparse_fp8 { namespace gfx93::decode::sparse_fp8 {
#define CUDART_L2E_F 1.442695041F #define CUDART_L2E_F 1.442695041F
static constexpr float MAX_INIT_VAL = -1e30; // Prevent (-inf) - (-inf) = nan static constexpr float MAX_INIT_VAL = -1e30; // Prevent (-inf) - (-inf) = nan
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "params.h" #include "params.h"
namespace sm90::decode::sparse_fp8 { namespace gfx93::decode::sparse_fp8 {
template<ModelType MODEL_TYPE, int NUM_HEADS> template<ModelType MODEL_TYPE, int NUM_HEADS>
void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams &params); void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams &params);
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include <cute/tensor.hpp> #include <cute/tensor.hpp>
// #include <cutlass/arch/barrier.h> // #include <cutlass/arch/barrier.h>
namespace sm90 { namespace gfx93 {
// __forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst) { // __forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst) {
// uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst); // uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst);
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "defines.h" #include "defines.h"
#include "params.h" #include "params.h"
namespace sm90::fwd { namespace gfx93::fwd {
using namespace cute; using namespace cute;
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include "phase1.h" #include "phase1.h"
namespace sm90 { namespace gfx93 {
void run_fwd_kernel(const SparseAttnFwdParams& params) { void run_fwd_kernel(const SparseAttnFwdParams& params) {
const bool have_topk_length = params.topk_length != nullptr; const bool have_topk_length = params.topk_length != nullptr;
...@@ -12,19 +12,19 @@ void run_fwd_kernel(const SparseAttnFwdParams& params) { ...@@ -12,19 +12,19 @@ void run_fwd_kernel(const SparseAttnFwdParams& params) {
// Dispatch based on d_qk dimension and presence of topk_length // Dispatch based on d_qk dimension and presence of topk_length
if (params.d_qk == 512) { if (params.d_qk == 512) {
if (have_topk_length) { if (have_topk_length) {
sm90::fwd::run_fwd_phase1_kernel<512, true>(params); gfx93::fwd::run_fwd_phase1_kernel<512, true>(params);
} else { } else {
sm90::fwd::run_fwd_phase1_kernel<512, false>(params); gfx93::fwd::run_fwd_phase1_kernel<512, false>(params);
} }
} else if (params.d_qk == 576) { } else if (params.d_qk == 576) {
if (have_topk_length) { if (have_topk_length) {
sm90::fwd::run_fwd_phase1_kernel<576, true>(params); gfx93::fwd::run_fwd_phase1_kernel<576, true>(params);
} else { } else {
sm90::fwd::run_fwd_phase1_kernel<576, false>(params); gfx93::fwd::run_fwd_phase1_kernel<576, false>(params);
} }
} else { } else {
throw std::runtime_error("Unsupported d_qk value in sparse attention fwd kernel"); throw std::runtime_error("Unsupported d_qk value in sparse attention fwd kernel");
} }
} }
} // namespace sm90 } // namespace gfx93
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "params.h" #include "params.h"
namespace sm90 { namespace gfx93 {
void run_fwd_kernel(const SparseAttnFwdParams& params); void run_fwd_kernel(const SparseAttnFwdParams& params);
......
#include "../phase1.h" #include "../phase1.h"
#include "../phase1.cuh" #include "../phase1.cuh"
namespace sm90::fwd { namespace gfx93::fwd {
// NOTE (intlsy): We instantiate run_fwd_phase1_kernel in two .cu files as functions with HAVE_TOPK_LENGTH // NOTE (intlsy): We instantiate run_fwd_phase1_kernel in two .cu files as functions with HAVE_TOPK_LENGTH
// = true / false respectively, to compile them in parallel. // = true / false respectively, to compile them in parallel.
......
#include "../phase1.h" #include "../phase1.h"
#include "../phase1.cuh" #include "../phase1.cuh"
namespace sm90::fwd { namespace gfx93::fwd {
// NOTE (intlsy): We instantiate run_fwd_phase1_kernel in two .cu files as functions with HAVE_TOPK_LENGTH // NOTE (intlsy): We instantiate run_fwd_phase1_kernel in two .cu files as functions with HAVE_TOPK_LENGTH
// = true / false respectively, to compile them in parallel. // = true / false respectively, to compile them in parallel.
......
#include "../phase1.h" #include "../phase1.h"
#include "../phase1.cuh" #include "../phase1.cuh"
namespace sm90::fwd { namespace gfx93::fwd {
template void run_fwd_phase1_kernel<576, false>(const SparseAttnFwdParams& params); template void run_fwd_phase1_kernel<576, false>(const SparseAttnFwdParams& params);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment