// common_extension.cc — registers the lightx2v_kernel operator schemas and their CUDA implementations with PyTorch.
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/all.h>
#include <torch/library.h>

#include "lightx2v_kernel_ops.h"

// Operator registrations for the `lightx2v_kernel` library fragment.
//
// Each operator is registered in two steps:
//   * m.def(schema)  declares the operator's schema string;
//   * m.impl(name, torch::kCUDA, fn)  binds the CUDA dispatch key to `fn`.
//
// Schema notation: `Tensor!` marks an argument the operator may write to,
// `Tensor?` marks an optional argument, and `-> ()` means no value is
// returned (results are written into the mutable arguments).
//
// The implementation functions are not declared in this file; presumably they
// come from "lightx2v_kernel_ops.h" — confirm against that header.
TORCH_LIBRARY_FRAGMENT(lightx2v_kernel, m) {
  m.def(
      "cutlass_scaled_nvfp4_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor "
      "alpha, Tensor? bias) -> ()");
  m.impl("cutlass_scaled_nvfp4_mm_sm120", torch::kCUDA, &cutlass_scaled_nvfp4_mm_sm120);

  // NOTE(review): `input` is annotated as mutable (`Tensor!`) in the quant
  // schemas below; kept as-is because the kernels are not visible here —
  // confirm whether the input is really written to.
  m.def(
      "scaled_nvfp4_quant_sm120(Tensor! output, Tensor! input,"
      "                 Tensor! output_scale, Tensor! input_scale) -> ()");
  m.impl("scaled_nvfp4_quant_sm120", torch::kCUDA, &scaled_nvfp4_quant_sm120);

  m.def(
      "scaled_mxfp4_quant_sm120(Tensor! output, Tensor! input,"
      "                 Tensor! output_scale) -> ()");
  m.impl("scaled_mxfp4_quant_sm120", torch::kCUDA, &scaled_mxfp4_quant_sm120);

  m.def(
      "scaled_mxfp8_quant_sm120(Tensor! output, Tensor! input,"
      "                 Tensor! output_scale) -> ()");
  m.impl("scaled_mxfp8_quant_sm120", torch::kCUDA, &scaled_mxfp8_quant_sm120);

  m.def(
      "scaled_mxfp6_quant_sm120(Tensor! output, Tensor! input,"
      "                 Tensor! output_scale) -> ()");
  m.impl("scaled_mxfp6_quant_sm120", torch::kCUDA, &scaled_mxfp6_quant_sm120);

  m.def(
      "cutlass_scaled_mxfp4_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor "
      "alpha, Tensor? bias) -> ()");
  m.impl("cutlass_scaled_mxfp4_mm_sm120", torch::kCUDA, &cutlass_scaled_mxfp4_mm_sm120);

  m.def(
      "cutlass_scaled_mxfp6_mxfp8_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor "
      "alpha, Tensor? bias) -> ()");
  m.impl("cutlass_scaled_mxfp6_mxfp8_mm_sm120", torch::kCUDA, &cutlass_scaled_mxfp6_mxfp8_mm_sm120);

  m.def(
      "cutlass_scaled_mxfp8_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor "
      "alpha, Tensor? bias) -> ()");
  m.impl("cutlass_scaled_mxfp8_mm_sm120", torch::kCUDA, &cutlass_scaled_mxfp8_mm_sm120);
}

// NOTE(review): REGISTER_EXTENSION is not defined in this file (presumably it
// comes from "lightx2v_kernel_ops.h"). It looks like it exposes this
// translation unit as an extension module named `common_ops` — confirm
// against the macro's definition.
REGISTER_EXTENSION(common_ops)