#pragma once

#include <torch/library.h>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// Marlin relies on asynchronous copies (cp.async), which require compute
// capability >= 8.0, so the ops are not declared below sm_80.
#else

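// Repack an AWQ-packed quantized weight matrix of shape (size_k, size_n) with
// num_bits bits per weight into the Marlin tile layout.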
torch::Tensor awq_marlin_repack(torch::Tensor &b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits);

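// GPTQ Marlin GEMM: multiply fp16 activations `a` (size_m x size_k) by the
// Marlin-repacked quantized weight `b_q_weight` (size_k x size_n),
// dequantizing with `b_scales` and, when `has_zp` is true, the zero points in
// `b_zeros`. `g_idx` and `perm` describe the act-order permutation,
// `workspace` provides scratch space for cross-block reduction, and
// `is_k_full` indicates whether `size_k` spans the full reduction dimension.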
torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                               torch::Tensor &b_scales, torch::Tensor &b_zeros,
                               torch::Tensor &g_idx, torch::Tensor &perm,
                               torch::Tensor &workspace, int64_t num_bits,
                               int64_t size_m, int64_t size_n, int64_t size_k,
                               bool is_k_full, bool has_zp);

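// 2:4 structured-sparse variant of the GPTQ Marlin GEMM; `b_meta` carries the
// sparsity metadata for the quantized weight.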
torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                                  torch::Tensor &b_meta,
                                  torch::Tensor &b_scales,
                                  torch::Tensor &workspace, int64_t num_bits,
                                  int64_t size_m, int64_t size_n,
                                  int64_t size_k);

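// Repack a GPTQ-packed quantized weight matrix into the Marlin tile layout,
// applying the act-order permutation `perm`.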
torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits);

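// Original Marlin GEMM: fp16 activations times 4-bit symmetrically quantized
// weights, dequantized with the per-group `b_scales`.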
torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                          torch::Tensor &b_scales, torch::Tensor &workspace,
                          int64_t size_m, int64_t size_n, int64_t size_k);

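// FP8 weight-only Marlin GEMM: fp16 activations times fp8-quantized weights
// (`num_bits` is expected to be 8).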
torch::Tensor fp8_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                              torch::Tensor &b_scales, torch::Tensor &workspace,
                              int64_t num_bits, int64_t size_m, int64_t size_n,
                              int64_t size_k);

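// A minimal registration sketch, not part of this header: the accompanying
// translation unit is assumed to expose these kernels to Python through
// torch::library, roughly as follows (the "marlin_kernels" namespace is
// hypothetical):
//
//   TORCH_LIBRARY(marlin_kernels, m) {
//     m.def("marlin_gemm", &marlin_gemm);
//     m.def("gptq_marlin_gemm", &gptq_marlin_gemm);
//   }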
#endif