Commit f298a271 authored by zhanghj2
Browse files

sm90改为gfx93 (rename sm90 to gfx93)

parent a8393a04
#include "../phase1.h" #include "../phase1.h"
#include "../phase1.cuh" #include "../phase1.cuh"
namespace sm90::fwd { namespace gfx93::fwd {
template void run_fwd_phase1_kernel<576, true>(const SparseAttnFwdParams& params); template void run_fwd_phase1_kernel<576, true>(const SparseAttnFwdParams& params);
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include "softmax.h" #include "softmax.h"
#include "../../helpers.h" #include "../../helpers.h"
namespace sm90::fwd { namespace gfx93::fwd {
#define CUDART_L2E_F 1.442695041F #define CUDART_L2E_F 1.442695041F
using namespace cute; using namespace cute;
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "../../../params.h" #include "../../../params.h"
namespace sm90::fwd { namespace gfx93::fwd {
template<int D_QK, bool HAVE_TOPK_LENGTH> template<int D_QK, bool HAVE_TOPK_LENGTH>
void run_fwd_phase1_kernel(const SparseAttnFwdParams& params); void run_fwd_phase1_kernel(const SparseAttnFwdParams& params);
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
#include "common.h" #include "common.h"
// #include "sm80/intrinsics.cuh" // #include "sm80/intrinsics.cuh"
// #include "sm80/helpers.cuh" // #include "sm80/helpers.cuh"
// #include "sm90/intrinsics.cuh" // #include "gfx93/intrinsics.cuh"
// #include "sm90/helpers.cuh" // #include "gfx93/helpers.cuh"
// #include "sm100/intrinsics.cuh" // #include "sm100/intrinsics.cuh"
// #include "sm100/helpers.cuh" // #include "sm100/helpers.cuh"
// #include "sm100/gemm.cuh" // #include "sm100/gemm.cuh"
......
...@@ -55,30 +55,30 @@ ext_modules.append( ...@@ -55,30 +55,30 @@ ext_modules.append(
"csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu", "csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu",
"csrc/smxx/decode/combine/combine.cu", "csrc/smxx/decode/combine/combine.cu",
# # sm90 dense decode # # gfx93 dense decode
"csrc/sm90/decode/dense/instantiations/fp16.cu", "csrc/gfx93/decode/dense/instantiations/fp16.cu",
"csrc/sm90/decode/dense/instantiations/bf16.cu", "csrc/gfx93/decode/dense/instantiations/bf16.cu",
## sm90 dense qkvfp8 decode ## gfx93 dense qkvfp8 decode
"csrc/sm90/decode/dense_qkvfp8/instantiations/fp8e4m3.cu", "csrc/gfx93/decode/dense_qkvfp8/instantiations/fp8e4m3.cu",
## sm90 dense kvfp8 decode ## gfx93 dense kvfp8 decode
"csrc/sm90/decode/dense_kvfp8/instantiations/kvfp8.cu", "csrc/gfx93/decode/dense_kvfp8/instantiations/kvfp8.cu",
# # sm90 sparse decode # # gfx93 sparse decode
"csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h16.cu", "csrc/gfx93/decode/sparse_fp8/instantiations/model1_persistent_h16.cu",
"csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu", "csrc/gfx93/decode/sparse_fp8/instantiations/model1_persistent_h64.cu",
"csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu", "csrc/gfx93/decode/sparse_fp8/instantiations/model1_persistent_h128.cu",
"csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h16.cu", "csrc/gfx93/decode/sparse_fp8/instantiations/v32_persistent_h16.cu",
"csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu", "csrc/gfx93/decode/sparse_fp8/instantiations/v32_persistent_h64.cu",
"csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu", "csrc/gfx93/decode/sparse_fp8/instantiations/v32_persistent_h128.cu",
# # sm90 sparse prefill # # gfx93 sparse prefill
"csrc/sm90/prefill/sparse/fwd.cu", "csrc/gfx93/prefill/sparse/fwd.cu",
"csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu", "csrc/gfx93/prefill/sparse/instantiations/phase1_k512.cu",
"csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu", "csrc/gfx93/prefill/sparse/instantiations/phase1_k512_topklen.cu",
"csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu", "csrc/gfx93/prefill/sparse/instantiations/phase1_k576.cu",
"csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu", "csrc/gfx93/prefill/sparse/instantiations/phase1_k576_topklen.cu",
], ],
extra_compile_args={ extra_compile_args={
...@@ -99,7 +99,7 @@ ext_modules.append( ...@@ -99,7 +99,7 @@ ext_modules.append(
include_dirs=[ include_dirs=[
Path(this_dir) / "csrc", Path(this_dir) / "csrc",
Path(this_dir) / "csrc" / "kerutils" / "include", # TODO Remove me Path(this_dir) / "csrc" / "kerutils" / "include", # TODO Remove me
Path(this_dir) / "csrc" / "sm90", Path(this_dir) / "csrc" / "gfx93",
Path(this_dir) / "csrc" / "cutlass" / "include", Path(this_dir) / "csrc" / "cutlass" / "include",
], ],
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment