Commit f298a271 authored by zhanghj2's avatar zhanghj2
Browse files

sm90改为gfx93

parent a8393a04
......@@ -7,7 +7,7 @@
using namespace cute;
namespace sm90::decode::sparse_fp8 {
namespace gfx93::decode::sparse_fp8 {
static constexpr int HEAD_DIM_K = 576;
static constexpr int HEAD_DIM_V = 512;
......
......@@ -5,7 +5,7 @@
#include "defines.h"
namespace sm90::decode::sparse_fp8 {
namespace gfx93::decode::sparse_fp8 {
// struct fp8x8 {
// // __nv_fp8x4_e4m3 lo;
......
......@@ -7,7 +7,7 @@
using namespace cute;
namespace sm90::decode::sparse_fp8 {
namespace gfx93::decode::sparse_fp8 {
// // In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~1) to the actual row_idx
// // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
......
......@@ -10,7 +10,7 @@
using namespace cute;
namespace sm90::decode::sparse_fp8 {
namespace gfx93::decode::sparse_fp8 {
template<ModelType MODEL_TYPE, int NUM_HEADS>
class KernelTemplate {
......
#include "../splitkv_mla.cuh"
namespace sm90::decode::sparse_fp8 {
namespace gfx93::decode::sparse_fp8 {
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::MODEL1, 128>(const SparseAttnDecodeParams &params);
......
#include "../splitkv_mla.cuh"
namespace sm90::decode::sparse_fp8 {
namespace gfx93::decode::sparse_fp8 {
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::MODEL1, 16>(const SparseAttnDecodeParams &params);
......
#include "../splitkv_mla.cuh"
namespace sm90::decode::sparse_fp8 {
namespace gfx93::decode::sparse_fp8 {
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::MODEL1, 64>(const SparseAttnDecodeParams &params);
......
#include "../splitkv_mla.cuh"
namespace sm90::decode::sparse_fp8 {
namespace gfx93::decode::sparse_fp8 {
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::V32, 128>(const SparseAttnDecodeParams &params);
......
#include "../splitkv_mla.cuh"
namespace sm90::decode::sparse_fp8 {
namespace gfx93::decode::sparse_fp8 {
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::V32, 16>(const SparseAttnDecodeParams &params);
......
#include "../splitkv_mla.cuh"
namespace sm90::decode::sparse_fp8 {
namespace gfx93::decode::sparse_fp8 {
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::V32, 64>(const SparseAttnDecodeParams &params);
......
......@@ -18,7 +18,7 @@
#include "softmax.h"
using namespace cute;
namespace sm90::decode::sparse_fp8 {
namespace gfx93::decode::sparse_fp8 {
#define CUDART_L2E_F 1.442695041F
static constexpr float MAX_INIT_VAL = -1e30; // Prevent (-inf) - (-inf) = nan
......
......@@ -2,7 +2,7 @@
#include "params.h"
namespace sm90::decode::sparse_fp8 {
namespace gfx93::decode::sparse_fp8 {
template<ModelType MODEL_TYPE, int NUM_HEADS>
void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams &params);
......
......@@ -3,7 +3,7 @@
#include <cute/tensor.hpp>
// #include <cutlass/arch/barrier.h>
namespace sm90 {
namespace gfx93 {
// __forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst) {
// uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst);
......
......@@ -7,7 +7,7 @@
#include "defines.h"
#include "params.h"
namespace sm90::fwd {
namespace gfx93::fwd {
using namespace cute;
......
......@@ -4,7 +4,7 @@
#include "phase1.h"
namespace sm90 {
namespace gfx93 {
void run_fwd_kernel(const SparseAttnFwdParams& params) {
const bool have_topk_length = params.topk_length != nullptr;
......@@ -12,19 +12,19 @@ void run_fwd_kernel(const SparseAttnFwdParams& params) {
// Dispatch based on d_qk dimension and presence of topk_length
if (params.d_qk == 512) {
if (have_topk_length) {
sm90::fwd::run_fwd_phase1_kernel<512, true>(params);
gfx93::fwd::run_fwd_phase1_kernel<512, true>(params);
} else {
sm90::fwd::run_fwd_phase1_kernel<512, false>(params);
gfx93::fwd::run_fwd_phase1_kernel<512, false>(params);
}
} else if (params.d_qk == 576) {
if (have_topk_length) {
sm90::fwd::run_fwd_phase1_kernel<576, true>(params);
gfx93::fwd::run_fwd_phase1_kernel<576, true>(params);
} else {
sm90::fwd::run_fwd_phase1_kernel<576, false>(params);
gfx93::fwd::run_fwd_phase1_kernel<576, false>(params);
}
} else {
throw std::runtime_error("Unsupported d_qk value in sparse attention fwd kernel");
}
}
} // namespace sm90
} // namespace gfx93
......@@ -2,7 +2,7 @@
#include "params.h"
namespace sm90 {
namespace gfx93 {
void run_fwd_kernel(const SparseAttnFwdParams& params);
......
#include "../phase1.h"
#include "../phase1.cuh"
namespace sm90::fwd {
namespace gfx93::fwd {
// NOTE (intlsy): We instantiate run_fwd_phase1_kernel in two .cu files as functions with HAVE_TOPK_LENGTH
// = true / false respectively, to compile them in parallel.
......
#include "../phase1.h"
#include "../phase1.cuh"
namespace sm90::fwd {
namespace gfx93::fwd {
// NOTE (intlsy): We instantiate run_fwd_phase1_kernel in two .cu files as functions with HAVE_TOPK_LENGTH
// = true / false respectively, to compile them in parallel.
......
#include "../phase1.h"
#include "../phase1.cuh"
namespace sm90::fwd {
namespace gfx93::fwd {
template void run_fwd_phase1_kernel<576, false>(const SparseAttnFwdParams& params);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment