torch_bindings.cpp 586 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
#include "registration.h"
#include "punica_ops.h"

// Operator registrations for this extension. Each op is declared with a
// schema (m.def) and then bound to its implementation for the CUDA dispatch
// key only (m.impl with torch::kCUDA).
// NOTE(review): TORCH_LIBRARY_EXPAND comes from registration.h; presumably it
// forwards to TORCH_LIBRARY after macro-expanding TORCH_EXTENSION_NAME, so the
// ops live under that library namespace — confirm against the header.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
  // In both schemas, `Tensor! y` marks `y` as written in place, and `-> ()`
  // means the op returns no value. The argument name "indicies" is
  // misspelled. Schema argument names are part of the op's public interface
  // (keyword arguments), so it is intentionally left unchanged.
  constexpr const char* kBgmvSchema =
      "dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, "
      "int layer_idx, float scale) -> ()";

  // Same operands as dispatch_bgmv, plus three extra ints: h_in, h_out and
  // y_offset. The pieces below concatenate to exactly the originally
  // registered schema text.
  constexpr const char* kBgmvLowLevelSchema =
      "dispatch_bgmv_low_level(Tensor! y, Tensor x, "
      "Tensor w,Tensor indicies, int layer_idx,float scale, "
      "int h_in, int h_out,int y_offset) -> ()";

  // Declare each op, then bind its CUDA kernel.
  m.def(kBgmvSchema);
  m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv);

  m.def(kBgmvLowLevelSchema);
  m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level);
}

// Module-level glue for this extension; REGISTER_EXTENSION is presumably a
// macro supplied by registration.h.
// NOTE(review): likely defines the Python module entry point named after
// TORCH_EXTENSION_NAME, so that importing the shared library loads the
// operator registrations above — confirm against registration.h.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)