ops.h 6.46 KB
Newer Older
1
2
#pragma once

3
4
#include <torch/extension.h>

5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
// Declaration of the paged-attention "v1" entry point; the definition is not
// in this header.
//
// Returns void. Every tensor argument is taken by non-const reference, so the
// callee may write to any of them; `out` is presumably the result buffer
// (TODO confirm against the definition).
// `alibi_slopes` is optional and may be empty.
// `kv_cache_dtype` is passed as a string; the accepted values are not visible
// from this header.
void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
                        torch::Tensor& key_cache, torch::Tensor& value_cache,
                        int num_kv_heads, float scale,
                        torch::Tensor& block_tables, torch::Tensor& seq_lens,
                        int block_size, int max_seq_len,
                        const c10::optional<torch::Tensor>& alibi_slopes,
                        const std::string& kv_cache_dtype, float kv_scale);

// Declaration of the paged-attention "v2" entry point; the definition is not
// in this header.
//
// Takes the same arguments as paged_attention_v1 plus three extra tensors
// (`exp_sums`, `max_logits`, `tmp_out`) placed right after `out`.
// NOTE(review): those three look like scratch/intermediate buffers for a
// multi-pass reduction — confirm against the kernel before relying on it.
// Returns void; `alibi_slopes` is optional and may be empty.
void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums,
                        torch::Tensor& max_logits, torch::Tensor& tmp_out,
                        torch::Tensor& query, torch::Tensor& key_cache,
                        torch::Tensor& value_cache, int num_kv_heads,
                        float scale, torch::Tensor& block_tables,
                        torch::Tensor& seq_lens, int block_size,
                        int max_seq_len,
                        const c10::optional<torch::Tensor>& alibi_slopes,
                        const std::string& kv_cache_dtype, float kv_scale);

// Declaration of the RMS-norm op; the definition is not in this header.
// Returns void, so the result is presumably written into `out`
// (TODO confirm). `epsilon` is a single-precision float.
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
              float epsilon);

// Declaration of the fused add + RMS-norm op; the definition is not in this
// header.
// Returns void and has no dedicated output parameter.
// NOTE(review): `input` and/or `residual` are presumably updated in place —
// confirm against the definition.
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
                        torch::Tensor& weight, float epsilon);

// Declaration of the rotary-embedding op; the definition is not in this
// header.
// Returns void and has no dedicated output parameter.
// NOTE(review): `query` and `key` are presumably modified in place — confirm.
// `is_neox` is a boolean switch; presumably selects the NeoX-style rotation
// layout (verify against the kernel).
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                      torch::Tensor& key, int head_size,
                      torch::Tensor& cos_sin_cache, bool is_neox);

// Declaration of the batched rotary-embedding op; the definition is not in
// this header.
// Same leading parameters as rotary_embedding, followed by `rot_dim` and a
// `cos_sin_cache_offsets` tensor.
// NOTE(review): the offsets presumably index into `cos_sin_cache` per
// token/request — confirm against the definition.
void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                              torch::Tensor& key, int head_size,
                              torch::Tensor& cos_sin_cache, bool is_neox,
                              int rot_dim,
                              torch::Tensor& cos_sin_cache_offsets);

// Declaration of the `silu_and_mul` activation op (defined elsewhere).
// Returns void; the result is presumably written into `out` (TODO confirm).
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

// Declaration of the `gelu_and_mul` activation op (defined elsewhere).
// Returns void; the result is presumably written into `out` (TODO confirm).
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

// Declaration of the `gelu_tanh_and_mul` activation op (defined elsewhere).
// Returns void; the result is presumably written into `out` (TODO confirm).
void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

// Declaration of the `gelu_new` activation op (defined elsewhere).
// Returns void; the result is presumably written into `out` (TODO confirm).
void gelu_new(torch::Tensor& out, torch::Tensor& input);

// Declaration of the `gelu_fast` activation op (defined elsewhere).
// Returns void; the result is presumably written into `out` (TODO confirm).
void gelu_fast(torch::Tensor& out, torch::Tensor& input);
48

49
#ifndef USE_ROCM
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
// Declaration of the AQLM GEMM op; only declared when USE_ROCM is not
// defined. The definition is not in this header.
// All tensor inputs are const references and the result is returned by value.
// `bias` is optional and may be empty.
// NOTE(review): `bias` is spelled std::optional while the paged_attention_*
// declarations in this header use c10::optional. Whether those are the same
// type depends on the PyTorch version in use — confirm before unifying, and
// keep this in sync with the definition.
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                        const torch::Tensor& codebooks,
                        const torch::Tensor& scales,
                        const torch::Tensor& codebook_partition_sizes,
                        const std::optional<torch::Tensor>& bias);

// Declaration of the AQLM dequantization op; only declared when USE_ROCM is
// not defined. The definition is not in this header.
// All inputs are const references; the result tensor is returned by value.
torch::Tensor aqlm_dequant(const torch::Tensor& codes,
                           const torch::Tensor& codebooks,
                           const torch::Tensor& codebook_partition_sizes);

// Declaration of the AWQ GEMM op; only declared when USE_ROCM is not defined.
// The definition is not in this header.
// Tensor arguments are passed by value and the result is returned by value.
// NOTE(review): the meaning and valid range of `split_k_iters` are not visible
// here — check the kernel.
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                       torch::Tensor _scaling_factors, torch::Tensor _zeros,
                       int split_k_iters);

// Declaration of the AWQ dequantization op; only declared when USE_ROCM is
// not defined. The definition is not in this header.
// Tensor arguments are passed by value and the result is returned by value.
// NOTE(review): `thx` / `thy` are presumably thread-block dimensions — confirm
// against the kernel launch code.
torch::Tensor awq_dequantize(torch::Tensor _kernel,
                             torch::Tensor _scaling_factors,
                             torch::Tensor _zeros, int split_k_iters, int thx,
                             int thy);

// Declaration of the Marlin GEMM op; only declared when USE_ROCM is not
// defined. The definition is not in this header.
// Tensors are non-const references, so the callee may modify them
// (`workspace` in particular is presumably scratch space — TODO confirm).
// Problem sizes are passed as 64-bit integers; the result is returned by
// value.
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                          torch::Tensor& b_scales, torch::Tensor& workspace,
                          int64_t size_m, int64_t size_n, int64_t size_k);

// Declaration of the GPTQ Marlin "24" GEMM op; only declared when USE_ROCM is
// not defined. The definition is not in this header.
// Compared with marlin_gemm it additionally takes a `b_meta` tensor and a
// `num_bits` integer.
// NOTE(review): "24" presumably refers to 2:4 structured sparsity with
// `b_meta` carrying the sparsity metadata — confirm against the kernel.
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                  torch::Tensor& b_meta,
                                  torch::Tensor& b_scales,
                                  torch::Tensor& workspace, int64_t num_bits,
                                  int64_t size_m, int64_t size_n,
                                  int64_t size_k);

// Declaration of the GPTQ Marlin GEMM op; only declared when USE_ROCM is not
// defined. The definition is not in this header.
// Compared with marlin_gemm it additionally takes `g_idx` and `perm` tensors,
// a `num_bits` integer and an `is_k_full` flag.
// NOTE(review): the exact semantics of `is_k_full` are not visible here —
// check the definition.
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& g_idx,
                               torch::Tensor& perm, torch::Tensor& workspace,
                               int64_t num_bits, int64_t size_m, int64_t size_n,
                               int64_t size_k, bool is_k_full);

// Declaration of the GPTQ-to-Marlin weight repacking op; only declared when
// USE_ROCM is not defined. The definition is not in this header.
// Returns a tensor by value. Note the size arguments are ordered
// (size_k, size_n), unlike the (size_m, size_n, size_k) order of the GEMM
// declarations above.
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits);

// Declaration of the CUTLASS scaled matrix-multiply op; only declared when
// USE_ROCM is not defined. The definition is not in this header.
// `out` is the only mutable tensor; `a`, `b` and both scale tensors are const.
// NOTE(review): the meaning of the `int` return value (status code?) is not
// visible from this header — confirm against the definition.
int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
                         torch::Tensor const& b, torch::Tensor const& a_scales,
                         torch::Tensor const& b_scales);
93

94
#endif
95

96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
// Declaration of the SqueezeLLM GEMM op; the definition is not in this header.
// Returns void and takes every tensor by value.
// NOTE(review): with no return value, the result is presumably written into
// one of the arguments (likely `mul`) — confirm against the definition.
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                     torch::Tensor lookup_table);

// Declaration of the GPTQ GEMM op; the definition is not in this header.
// Tensors are passed by value and the result is returned by value.
// `use_exllama` is a boolean switch and `bit` an integer; their exact
// semantics and valid values are not visible here.
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                        torch::Tensor b_gptq_qzeros,
                        torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
                        bool use_exllama, int bit);

// Declaration of the GPTQ weight shuffle op (defined elsewhere).
// Returns void. NOTE(review): `q_weight` is presumably rearranged in place
// according to `q_perm` — confirm against the definition.
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit);

// Declaration of the static-scale FP8 quantization op (defined elsewhere).
// Returns void; all three tensors are non-const references.
// NOTE(review): "static" presumably means `scale` is supplied by the caller
// and only read — confirm against the definition.
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                             torch::Tensor& scale);

// Declaration of the dynamic-scale FP8 quantization op (defined elsewhere).
// Same signature as static_scaled_fp8_quant.
// NOTE(review): "dynamic" presumably means `scale` is computed by the callee
// and written back — confirm against the definition.
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                              torch::Tensor& scale);

// Declaration of the mixture-of-experts block alignment op (defined
// elsewhere).
// Returns void and takes every tensor by value.
// NOTE(review): `sorted_token_ids`, `experts_ids` and `num_tokens_post_pad`
// are presumably outputs filled by the callee — confirm against the
// definition.
void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
                          int block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad);
116
117
118

#ifndef USE_ROCM
// 64-bit unsigned integer alias. It is the return type of init_custom_ar and
// the type of the `_fa` parameter on the custom all-reduce declarations below.
// NOTE(review): the name suggests it carries a pointer value as an integer —
// confirm against the definitions.
using fptr_t = uint64_t;
119
120
121
122
123
// Declaration of the custom all-reduce initializer; only declared when
// USE_ROCM is not defined. The definition is not in this header.
// Returns an fptr_t; presumably this is the value later passed as `_fa` to the
// other custom all-reduce functions (TODO confirm).
// `handles` and `offsets` are read-only vectors; whether they must have the
// same length is not visible here.
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
                      const std::vector<std::string>& handles,
                      const std::vector<int64_t>& offsets, int rank,
                      bool full_nvlink);
bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size,
124
                      bool full_nvlink);
125
126
127
// Declaration of the custom all-reduce entry point taking only `inp` and
// `out` (defined elsewhere). `_fa` is an fptr_t handle.
// NOTE(review): "reg" presumably means `inp` is an already-registered buffer
// — confirm against the definition.
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
// Declaration of the custom all-reduce entry point that additionally takes a
// `reg_buffer` tensor (defined elsewhere). `_fa` is an fptr_t handle.
// NOTE(review): presumably `inp` is staged through `reg_buffer` because it is
// not itself registered — confirm against the definition.
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
                      torch::Tensor& out);
128
129
// Declaration of the teardown function for a custom all-reduce handle
// (defined elsewhere). NOTE(review): `_fa` presumably must not be used after
// this call — confirm against the definition.
void dispose(fptr_t _fa);
// Declaration of a no-argument function returning an int (defined elsewhere).
// NOTE(review): presumably the required size of the `meta` tensor passed to
// init_custom_ar; units are not visible here — confirm.
int meta_size();
130
131
132
133
134
135
136
// Declaration of the buffer registration function for a custom all-reduce
// handle (defined elsewhere). Returns void.
// `handles` and `offsets` are read-only vectors; NOTE(review): presumably one
// entry per rank — confirm against the definition.
void register_buffer(fptr_t _fa, torch::Tensor& t,
                     const std::vector<std::string>& handles,
                     const std::vector<int64_t>& offsets);
// Declaration of the accessor for graph-buffer IPC metadata (defined
// elsewhere). Returns a pair of (byte vector, int64 vector) by value.
// NOTE(review): presumably serialized IPC handles and their offsets, in that
// order — confirm against the definition.
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
    fptr_t _fa);
// Declaration of the graph-buffer registration function (defined elsewhere).
// Returns void. `handles` is a vector of strings and `offsets` a vector of
// int64 vectors; NOTE(review): presumably one outer entry per rank — confirm
// against the definition.
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
                            const std::vector<std::vector<int64_t>>& offsets);
137
#endif