// ops.h
1
2
#pragma once

3
#include <optional>
4
#include <torch/library.h>
5

6
7
void paged_attention_v1(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
8
9
10
11
12
13
14
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);
15

16
17
18
void paged_attention_v2(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
19
20
21
22
23
24
25
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);
26
27

void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
28
              double epsilon);
29
30

void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
31
                        torch::Tensor& weight, double epsilon);
32
33

void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
34
                      torch::Tensor& key, int64_t head_size,
35
36
37
                      torch::Tensor& cos_sin_cache, bool is_neox);

void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
38
                              torch::Tensor& key, int64_t head_size,
39
                              torch::Tensor& cos_sin_cache, bool is_neox,
40
                              int64_t rot_dim,
41
42
43
44
45
46
47
48
49
50
51
                              torch::Tensor& cos_sin_cache_offsets);

// Declaration of the silu_and_mul op: takes `out` and `input` tensors by
// non-const reference and returns nothing.
// NOTE(review): presumably writes its result into `out` -- confirm.
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

// Declaration of the gelu_and_mul op: takes `out` and `input` tensors by
// non-const reference and returns nothing.
// NOTE(review): presumably writes its result into `out` -- confirm.
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

// Declaration of the gelu_tanh_and_mul op: takes `out` and `input` tensors by
// non-const reference and returns nothing.
// NOTE(review): presumably writes its result into `out` -- confirm.
void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

// Declaration of the gelu_new op: takes `out` and `input` tensors by
// non-const reference and returns nothing.
// NOTE(review): presumably writes its result into `out` -- confirm.
void gelu_new(torch::Tensor& out, torch::Tensor& input);

// Declaration of the gelu_fast op: takes `out` and `input` tensors by
// non-const reference and returns nothing.
// NOTE(review): presumably writes its result into `out` -- confirm.
void gelu_fast(torch::Tensor& out, torch::Tensor& input);
52

53
54
// Declaration of the gelu_quick op: takes `out` and `input` tensors by
// non-const reference and returns nothing.
// NOTE(review): presumably writes its result into `out` -- confirm.
void gelu_quick(torch::Tensor& out, torch::Tensor& input);

55
#ifndef USE_ROCM
56
57
58
59
60
61
62
63
64
65
66
67
// Declaration of the aqlm_gemm op; only visible when USE_ROCM is not defined.
// All tensor inputs are const references, `bias` is an optional tensor, and
// the result is returned as a new torch::Tensor value.
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                        const torch::Tensor& codebooks,
                        const torch::Tensor& scales,
                        const torch::Tensor& codebook_partition_sizes,
                        const std::optional<torch::Tensor>& bias);

// Declaration of the aqlm_dequant op; only visible when USE_ROCM is not
// defined. Takes `codes`, `codebooks` and `codebook_partition_sizes` by const
// reference and returns a torch::Tensor by value.
torch::Tensor aqlm_dequant(const torch::Tensor& codes,
                           const torch::Tensor& codebooks,
                           const torch::Tensor& codebook_partition_sizes);

// Declaration of the awq_gemm op; only visible when USE_ROCM is not defined.
// The four tensors (`_in_feats`, `_kernel`, `_scaling_factors`, `_zeros`) are
// passed by value, followed by an integer `split_k_iters`. Returns a
// torch::Tensor by value.
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                       torch::Tensor _scaling_factors, torch::Tensor _zeros,
                       int64_t split_k_iters);
69
70
71

// Declaration of the awq_dequantize op; only visible when USE_ROCM is not
// defined. The three tensors (`_kernel`, `_scaling_factors`, `_zeros`) are
// passed by value, followed by integers `split_k_iters`, `thx` and `thy`.
// Returns a torch::Tensor by value.
// NOTE(review): `thx`/`thy` look like launch-shape hints -- confirm against
// the implementation.
torch::Tensor awq_dequantize(torch::Tensor _kernel,
                             torch::Tensor _scaling_factors,
                             torch::Tensor _zeros, int64_t split_k_iters,
                             int64_t thx, int64_t thy);
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95

// Declaration of the marlin_gemm op; only visible when USE_ROCM is not
// defined. Takes `a`, `b_q_weight`, `b_scales` and `workspace` by non-const
// reference plus the integer sizes `size_m`, `size_n`, `size_k`; returns a
// torch::Tensor by value.
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                          torch::Tensor& b_scales, torch::Tensor& workspace,
                          int64_t size_m, int64_t size_n, int64_t size_k);

// Declaration of the gptq_marlin_24_gemm op; only visible when USE_ROCM is
// not defined. Takes `a`, `b_q_weight`, `b_meta`, `b_scales` and `workspace`
// by non-const reference plus integers `num_bits`, `size_m`, `size_n`,
// `size_k`; returns a torch::Tensor by value.
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                  torch::Tensor& b_meta,
                                  torch::Tensor& b_scales,
                                  torch::Tensor& workspace, int64_t num_bits,
                                  int64_t size_m, int64_t size_n,
                                  int64_t size_k);

// Declaration of the gptq_marlin_gemm op; only visible when USE_ROCM is not
// defined. Takes `a`, `b_q_weight`, `b_scales`, `g_idx`, `perm` and
// `workspace` by non-const reference, integers `num_bits`, `size_m`,
// `size_n`, `size_k`, and a boolean `is_k_full`; returns a torch::Tensor by
// value.
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& g_idx,
                               torch::Tensor& perm, torch::Tensor& workspace,
                               int64_t num_bits, int64_t size_m, int64_t size_n,
                               int64_t size_k, bool is_k_full);

// Declaration of the gptq_marlin_repack op; only visible when USE_ROCM is not
// defined. Takes `b_q_weight` and `perm` by non-const reference plus integers
// `size_k`, `size_n`, `num_bits`; returns a torch::Tensor by value.
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits);

96
97
// Declaration of a boolean query taking an integer
// `cuda_device_capability`; only visible when USE_ROCM is not defined.
// NOTE(review): the name suggests it reports whether cutlass_scaled_mm can
// run FP8 on that capability; the encoding of the integer is not visible
// here -- confirm against the implementation.
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);

98
99
// Declaration of the cutlass_scaled_mm op; only visible when USE_ROCM is not
// defined. `out` is the only tensor taken by non-const reference; `a`, `b`,
// `a_scales` and `b_scales` are const references and `bias` is an optional
// tensor. Returns nothing.
// NOTE(review): presumably the product is written into `out` -- confirm
// against the implementation.
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales,
                       c10::optional<torch::Tensor> const& bias);
102

103
#endif
104

105
106
// Declaration of the static_scaled_int8_quant op: `out` is taken by
// non-const reference, `input` and `scale` by const reference; returns
// nothing.
// NOTE(review): `scale` being const suggests it is supplied by the caller
// rather than computed here -- confirm against the implementation.
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                              torch::Tensor const& scale);
107

108
109
110
// Declaration of the dynamic_scaled_int8_quant op: `out` and `scales` are
// taken by non-const reference, `input` by const reference; returns nothing.
// NOTE(review): `scales` being mutable suggests it is written by this op --
// confirm against the implementation.
void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                               torch::Tensor& scales);

111
112
113
114
115
116
// Declaration of the squeezellm_gemm op: takes `vec`, `mat`, `mul` and
// `lookup_table` tensors by value and returns nothing.
// NOTE(review): with a void return, the result is presumably written into one
// of the tensor arguments (torch::Tensor copies share storage) -- confirm
// which one against the implementation.
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                     torch::Tensor lookup_table);

// Declaration of the gptq_gemm op: takes `a`, `b_q_weight`, `b_gptq_qzeros`,
// `b_gptq_scales` and `b_g_idx` tensors by value, a boolean `use_exllama` and
// an integer `bit`; returns a torch::Tensor by value.
// NOTE(review): `bit` presumably is the quantization bit width -- confirm
// against the implementation.
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                        torch::Tensor b_gptq_qzeros,
                        torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
                        bool use_exllama, int64_t bit);
118

119
// Declaration of the gptq_shuffle op: takes `q_weight` and `q_perm` tensors
// by value and an integer `bit`; returns nothing.
// NOTE(review): presumably reorders `q_weight` in place -- confirm against
// the implementation.
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
120
121
122
123
124
125
126

// Declaration of the static_scaled_fp8_quant op: takes `out`, `input` and
// `scale` tensors by non-const reference and returns nothing.
// NOTE(review): presumably writes the quantized result into `out` using a
// caller-provided `scale` -- confirm against the implementation.
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                             torch::Tensor& scale);

// Declaration of the dynamic_scaled_fp8_quant op: takes `out`, `input` and
// `scale` tensors by non-const reference and returns nothing.
// NOTE(review): presumably computes `scale` itself and writes it back along
// with `out` -- confirm against the implementation.
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                              torch::Tensor& scale);

127
128
// Declaration of the moe_align_block_size op: takes `topk_ids`,
// `sorted_token_ids`, `experts_ids` and `num_tokens_post_pad` tensors by
// value plus integers `num_experts` and `block_size`; returns nothing.
// NOTE(review): with a void return, `sorted_token_ids`, `experts_ids` and
// `num_tokens_post_pad` are presumably outputs filled in place -- confirm
// against the implementation.
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                          int64_t block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad);
131
132

#ifndef USE_ROCM
133
// Integer handle type returned by init_custom_ar and passed as `_fa` to the
// other custom all-reduce functions declared in this USE_ROCM-excluded
// section.
// NOTE(review): presumably carries a pointer value as a 64-bit integer --
// confirm against the implementation.
using fptr_t = int64_t;
134
135
// Declaration of init_custom_ar; only visible when USE_ROCM is not defined.
// Takes `meta` and `rank_data` tensors by non-const reference, a list of
// string `handles` with a matching-typed list of integer `offsets`, an
// integer `rank` and a boolean `full_nvlink`. Returns an fptr_t handle.
// NOTE(review): `handles`/`offsets` presumably have one entry per
// participating rank -- confirm against the implementation.
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
                      const std::vector<std::string>& handles,
                      const std::vector<int64_t>& offsets, int64_t rank,
                      bool full_nvlink);
138
bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
139
                      bool full_nvlink);
140
141
142
// Declaration of all_reduce_reg: takes an fptr_t handle `_fa` and the `inp`
// and `out` tensors by non-const reference; returns nothing.
// NOTE(review): "reg" presumably means `inp` is an already-registered buffer
// -- confirm against the implementation.
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
// Declaration of all_reduce_unreg: like all_reduce_reg but with an extra
// `reg_buffer` tensor between `inp` and `out`; returns nothing.
// NOTE(review): presumably `inp` is staged through `reg_buffer` because it is
// not registered -- confirm against the implementation.
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
                      torch::Tensor& out);
143
// Declaration of dispose: takes an fptr_t handle `_fa` and returns nothing.
// NOTE(review): presumably releases the object created by init_custom_ar --
// confirm against the implementation.
void dispose(fptr_t _fa);
144
// Declaration of meta_size: takes no arguments and returns an int64_t.
// NOTE(review): presumably the size required for the `meta` tensor passed to
// init_custom_ar; the unit is not visible here -- confirm.
int64_t meta_size();
145
146
147
// Declaration of register_buffer: takes an fptr_t handle `_fa`, a tensor `t`
// by non-const reference, and const references to a list of string `handles`
// and a list of integer `offsets`; returns nothing.
void register_buffer(fptr_t _fa, torch::Tensor& t,
                     const std::vector<std::string>& handles,
                     const std::vector<int64_t>& offsets);
148
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
149
150
151
    fptr_t _fa);
// Declaration of register_graph_buffers: takes an fptr_t handle `_fa`, a list
// of string `handles` and a list of integer-offset lists (one inner vector
// per outer entry); returns nothing.
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
                            const std::vector<std::vector<int64_t>>& offsets);
152
#endif