rocm_ops.cpp 2.57 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
// SPDX-License-Identifier: MIT
 
#include "activation.h"
#include "attention.h"
#include "attention_ragged.h"
#include "attention_ck.h"
#include "attention_asm.h"
#include "attention_asm_mla.h"
#include "cache.h"
#include "custom_all_reduce.h"
#include "communication_asm.h"
// #include "gemm_a8w8_blockscale.h"
#include "custom.h"
#include "moe_op.h"
#include "moe_sorting.h"
#include "moe_sum.h"
#include "moe_utils.h"
#include "norm.h"
#include "pos_encoding.h"
#include "rmsnorm.h"
#include "smoothquant.h"
#include "aiter_operator.h"
#include "asm_gemm_a8w8.h"
#include <torch/extension.h>
// #include "gemm_a8w8.h"
// #include "batched_gemm_a8w8.h"
#include "quant.h"
#include "moe_ck.h"
#include "moe_asm.h"
#include "awq_gemm_asm.h"
#include "awq_dq_asm.h"
#include "rope.h"
#include "rocsolgemm.cuh"
#include "hipbsolgemm.cuh"
#include "aiter_enum.h"
#include "aiter_unary.h"

#include "torch/mha_batch_prefill.h"
#include "torch/mha_varlen_fwd.h"
#include "torch/mha_varlen_bwd.h"
#include "torch/mha_bwd.h"
#include "torch/mha_fwd.h"
#include "torch/mha_v3_fwd.h"
#include "torch/mha_v3_bwd.h"
#include "torch/mha_v3_varlen_bwd.h"

#include "rocm_ops.hpp"

// Python module definition for the prebuilt kernel build.
// This block is compiled only when PREBUILD_KERNELS is defined; without it this
// translation unit defines no Python module at all.
//
// Each *_PYBIND statement below is a registration macro (presumably defined in
// rocm_ops.hpp, which this file includes — TODO confirm) that adds one family
// of ops/types to module `m`. The statements execute top to bottom when the
// module is imported.
//
// NOTE(review): AITER_ENUM_PYBIND is registered first. If any later binding
// uses an aiter enum value as a default argument, pybind11 needs that enum type
// registered before the corresponding def runs, so confirm before reordering.
//
// Entries left commented out are NOT exposed by this prebuilt module. The note
// at the top of the body covers the *TUNE* and MHA* entries. Why the other
// commented-out families (e.g. GEMM_A8W8, SMOOTHQUANT, ATTENTION*, CACHE,
// MOE_OP) are excluded is not recorded here — TODO document.
#ifdef PREBUILD_KERNELS
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
      // Deliberately excluded from the prebuilt module: every *TUNE* binding
      // and every MHA* binding (forward/backward, varlen, batch-prefill, ASM).
      // GEMM_A8W8_TUNE_PYBIND;
      AITER_ENUM_PYBIND;
      // ck module
      RMSNORM_PYBIND;
      // MHA_VARLEN_FWD_PYBIND;
      // MHA_VARLEN_BWD_PYBIND;
      // MHA_FWD_PYBIND;
      // MHA_BWD_PYBIND;
      // MHA_BATCH_PREFILL_PYBIND;
      // MHA_FWD_ASM_PYBIND;
      // MHA_BWD_ASM_PYBIND;
      // MHA_VARLEN_BWD_ASM_PYBIND;
      // GEMM_A8W8_PYBIND;
      // CUSTOM_PYBIND;
      // SMOOTHQUANT_PYBIND;
      // BATCHED_GEMM_A8W8_PYBIND;
      MOE_CK_PYBIND;
      MOE_C_PYBIND;
      // BATCHED_GEMM_A8W8_TUNE_PYBIND;
      // GEMM_A8W8_ASM_PYBIND;
      ACTIVATION_PYBIND;
      // ATTENTION_ASM_MLA_PYBIND;
      // ATTENTION_CK_PYBIND;
      MOE_SORTING_PYBIND;
      MOE_SUM_PYBIND;
      NORM_PYBIND;
      POS_ENCODING_PYBIND;
      // ATTENTION_PYBIND;
      // MOE_CK_2STAGES_PYBIND;
      MOE_UTILS_PYBIND;
      MOE_ASM_2STAGES_PYBIND;
      AWQ_GEMM_ASM_PYBIND;
      AWQ_DQ_ASM_PYBIND;
      QUANT_PYBIND;
      // ATTENTION_ASM_PYBIND;
      // ATTENTION_RAGGED_PYBIND;
      // MOE_OP_PYBIND;
      ROPE_GENERAL_FWD_PYBIND;
      ROPE_GENERAL_BWD_PYBIND;
      ROPE_POS_FWD_PYBIND;
      // GEMM_A8W8_BLOCKSCALE_TUNE_PYBIND;
      // GEMM_A8W8_BLOCKSCALE_PYBIND;
      AITER_OPERATOR_PYBIND;
      AITER_UNARY_PYBIND;
      CUSTOM_ALL_REDUCE_PYBIND;
      // CACHE_PYBIND;
      HIPBSOLGEMM_PYBIND;
      ROCSOLGEMM_PYBIND;
}
#endif // PREBUILD_KERNELS