mhc.h 3.32 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
// SPDX-License-Identifier: MIT


#pragma once

#include <torch/extension.h>

namespace aiter {

// Shape annotations on each parameter are (dim0, dim1, ...). Alternatives
// separated by " / " are the two accepted layouts.
//
// Float default arguments use the `f` suffix so they are genuine float
// literals instead of implicit double->float conversions (which
// -Wconversion / -Wfloat-conversion flags for 1e-6). The resulting default
// values are unchanged.

/// Pre-stage GEMM plus a per-row sum of squares.
///
/// NOTE(review): the math is inferred from the name and the shape annotations
/// (presumably out = x @ fn^T and sqrsum = sum of x^2 over hc_hidden_size).
/// Confirm against the kernel implementation. The leading split_k dimension
/// of `out`/`sqrsum` presumably holds split-K partials that
/// mhc_pre_reduce_splitk can collapse.
///
/// @param tile_k   Tile size, presumably along the hc_hidden_size (K) dimension.
/// @param use_tf32 Presumably allows TF32 math for the GEMM.
void mhc_pre_gemm_sqrsum(torch::Tensor& out,    // (split_k, m, hc_mult3) / (m, hc_mult3)
                         torch::Tensor& sqrsum, // (split_k, m) / (m)
                         torch::Tensor& x,      // (m, hc_hidden_size)
                         torch::Tensor& fn,     // (hc_mult3, hc_hidden_size)
                         int tile_k    = 128,
                         bool use_tf32 = false);

/// Variant of mhc_pre_gemm_sqrsum with the same tensor arguments but no
/// `tile_k` parameter.
/// NOTE(review): the `_stage1_m128` suffix suggests a first-stage kernel
/// specialized for 128-row tiles of m. Confirm and document the constraints
/// on m.
void mhc_pre_gemm_sqrsum_stage1_m128(torch::Tensor& out,    // (split_k, m, hc_mult3) / (m, hc_mult3)
                                     torch::Tensor& sqrsum, // (split_k, m) / (m)
                                     torch::Tensor& x,      // (m, hc_hidden_size)
                                     torch::Tensor& fn,     // (hc_mult3, hc_hidden_size)
                                     bool use_tf32 = false);

/// Collapses the split_k dimension of `out`/`sqrsum` into `out_red`/
/// `sqrsum_red` (leading dimension 1). The reduction operator is presumably
/// a sum. Confirm.
void mhc_pre_reduce_splitk(torch::Tensor& out_red,    // (1, m, hc_mult3)
                           torch::Tensor& sqrsum_red, // (1, m)
                           torch::Tensor& out,        // (split_k, m, hc_mult3)
                           torch::Tensor& sqrsum);    // (split_k, m)

/// Fused pre-stage epilogue. It takes the GEMM outputs (`gemm_out_mul`,
/// `gemm_out_sqrsum`, whose shapes match `out`/`sqrsum` of
/// mhc_pre_gemm_sqrsum) and writes `post_mix`, `comb_mix` and `layer_input`.
///
/// NOTE(review): the roles of the scalars below are inferred from their names.
/// Confirm against the kernel.
/// @param rms_eps            Epsilon, presumably for RMS normalization.
/// @param hc_pre_eps         Epsilon, presumably for the pre-mix computation.
/// @param hc_sinkhorn_eps    Epsilon, presumably for the Sinkhorn normalization.
/// @param hc_post_mult_value Scale, presumably applied when forming post_mix.
/// @param sinkhorn_repeat    Presumably the number of Sinkhorn iterations.
void mhc_pre_big_fuse(torch::Tensor& post_mix,        // (m, hc_mult)
                      torch::Tensor& comb_mix,        // (m, hc_mult * hc_mult)
                      torch::Tensor& layer_input,     // (m, hidden_size)
                      torch::Tensor& gemm_out_mul,    // (split_k, m, hc_mult3)
                      torch::Tensor& gemm_out_sqrsum, // (split_k, m)
                      torch::Tensor& hc_scale,        // (3)
                      torch::Tensor& hc_base,         // (hc_mult3)
                      torch::Tensor& residual,        // (m, hc_mult, hidden_size)
                      float rms_eps            = 1e-6f,
                      float hc_pre_eps         = 1e-6f,
                      float hc_sinkhorn_eps    = 1e-6f,
                      float hc_post_mult_value = 1.0f,
                      int sinkhorn_repeat      = 20);

/// Alternative implementation with the same arguments and defaults as
/// mhc_pre_big_fuse.
/// NOTE(review): the meaning of the `_tlstyle` suffix is not visible in this
/// header. Document how it differs.
void mhc_pre_big_fuse_tlstyle(torch::Tensor& post_mix,        // (m, hc_mult)
                              torch::Tensor& comb_mix,        // (m, hc_mult * hc_mult)
                              torch::Tensor& layer_input,     // (m, hidden_size)
                              torch::Tensor& gemm_out_mul,    // (split_k, m, hc_mult3)
                              torch::Tensor& gemm_out_sqrsum, // (split_k, m)
                              torch::Tensor& hc_scale,        // (3)
                              torch::Tensor& hc_base,         // (hc_mult3)
                              torch::Tensor& residual,        // (m, hc_mult, hidden_size)
                              float rms_eps            = 1e-6f,
                              float hc_pre_eps         = 1e-6f,
                              float hc_sinkhorn_eps    = 1e-6f,
                              float hc_post_mult_value = 1.0f,
                              int sinkhorn_repeat      = 20);

/// Post stage: writes `out` (m, hc_mult, hidden_size) from `x`, `residual`
/// and the two mixing tensors.
/// NOTE(review): `comb_res_mix` is annotated (m, hc_mult, hc_mult) here,
/// while `comb_mix` in mhc_pre_big_fuse is (m, hc_mult * hc_mult). The two
/// are presumably the same data viewed with different ranks. Confirm.
void mhc_post(torch::Tensor& out,            // (m, hc_mult, hidden_size)
              torch::Tensor& x,              // (m, hidden_size)
              torch::Tensor& residual,       // (m, hc_mult, hidden_size)
              torch::Tensor& post_layer_mix, // (m, hc_mult)
              torch::Tensor& comb_res_mix);  // (m, hc_mult, hc_mult)

} // namespace aiter