// common_extension.cc
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/all.h>
#include <torch/library.h>

#include "lightx2v_kernel_ops.h"

// Registers the SM120 kernel operators in the `lightx2v_kernel` torch library.
//
// Each operator is first declared with m.def (its schema string) and then bound
// to a C++ function with m.impl under the torch::kCUDA dispatch key. Every
// schema returns `()`, so nothing is handed back through a return value; any
// output has to be written into the arguments annotated `Tensor!`.
//
// NOTE(review): the bound functions are not declared in this file; presumably
// they come from "lightx2v_kernel_ops.h" — confirm against that header.
TORCH_LIBRARY_FRAGMENT(lightx2v_kernel, m) {
  // Scaled FP4 matmul. `bias` is the only optional argument (`Tensor?`).
  m.def(
      "cutlass_scaled_fp4_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor "
      "alpha, Tensor? bias) -> ()");
  m.impl("cutlass_scaled_fp4_mm_sm120", torch::kCUDA, &cutlass_scaled_fp4_mm_sm120);

  // FP4 quantization.
  // NOTE(review): `input` and `input_scale` are annotated `Tensor!` as well as
  // the outputs — confirm whether the kernel really mutates them.
  m.def(
      "scaled_fp4_quant_sm120(Tensor! output, Tensor! input,"
      "                 Tensor! output_scale, Tensor! input_scale) -> ()");
  m.impl("scaled_fp4_quant_sm120", torch::kCUDA, &scaled_fp4_quant_sm120);

  // FP8 quantization.
  // NOTE(review): `input` is annotated `Tensor!` here too — confirm it is
  // intentionally declared mutable.
  m.def(
      "scaled_fp8_quant_sm120(Tensor! output, Tensor! input,"
      "                 Tensor! output_scale) -> ()");
  m.impl("scaled_fp8_quant_sm120", torch::kCUDA, &scaled_fp8_quant_sm120);

  // Scaled MXFP6 x MXFP8 matmul; same argument list as the FP4 matmul above.
  m.def(
      "cutlass_scaled_mxfp6_mxfp8_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor "
      "alpha, Tensor? bias) -> ()");
  m.impl("cutlass_scaled_mxfp6_mxfp8_mm_sm120", torch::kCUDA, &cutlass_scaled_mxfp6_mxfp8_mm_sm120);

  // Scaled MXFP8 matmul; same argument list as the FP4 matmul above.
  m.def(
      "cutlass_scaled_mxfp8_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor "
      "alpha, Tensor? bias) -> ()");
  m.impl("cutlass_scaled_mxfp8_mm_sm120", torch::kCUDA, &cutlass_scaled_mxfp8_mm_sm120);
}

// NOTE(review): REGISTER_EXTENSION is not defined in this file; presumably it
// comes from "lightx2v_kernel_ops.h" and emits the Python module entry point
// for `common_ops` — confirm against that header.
REGISTER_EXTENSION(common_ops)