scaled_mm_c3x.cu 3.25 KB
Newer Older
1
#include <cudaTypedefs.h>
2
#include "c3x/scaled_mm_kernels.hpp"
3

4
#include "core/math.hpp"
5
6

/*
   This file defines quantized GEMM operations using the CUTLASS 3.x API, for
   NVIDIA GPUs with sm90a (Hopper) or later.
*/

11
12
13
14
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
15
                            std::optional<torch::Tensor> const& bias) {
16
17
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59

  using GroupShape = std::array<int64_t, 2>;

  int M = a.size(0), N = b.size(1), K = a.size(1);

  GroupShape a_scale_group_shape = [&, &s = a_scales]() -> GroupShape {
    if (s.numel() == 1) return {M, K};  // tensor-wise
    if (s.dim() == 2)
      return {ceil_div(a.size(0), s.size(0)), ceil_div(a.size(1), s.size(1))};
    TORCH_CHECK(false, "Unsupported scale shape for scale_a");
  }();

  GroupShape b_scale_group_shape = [&, &s = b_scales]() -> GroupShape {
    if (s.numel() == 1) return {K, N};  // tensor-wise
    if (s.dim() == 2)
      return {ceil_div(b.size(0), s.size(0)), ceil_div(b.size(1), s.size(1))};
    TORCH_CHECK(false, "Unsupported scale shape for scale_b");
  }();

  if ((a_scale_group_shape == GroupShape{M, K} ||
       a_scale_group_shape == GroupShape{1, K}) &&
      (b_scale_group_shape == GroupShape{K, N} ||
       b_scale_group_shape == GroupShape{K, 1})) {
    // "standard per-tensor/per-token/per-channel" scaling
    TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
    if (a.dtype() == torch::kFloat8_e4m3fn) {
      vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
    } else {
      TORCH_CHECK(a.dtype() == torch::kInt8);
      vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias);
    }
  } else if (a_scale_group_shape == GroupShape{1, 128} &&
             b_scale_group_shape == GroupShape{128, 128}) {
    // 1x128 per-token group scales for activations
    // 128x128 blockwise scales for weights
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn &&
                    b.dtype() == torch::kFloat8_e4m3fn,
                "Currently only FP8 is supported for A group shape 1x128 and "
                "B group shape 128x128");
    TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");

    vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
60
  } else {
61
    TORCH_CHECK(false, "Unsupported scale group shapes for CUTLASS 3.x GEMM");
62
63
64
  }
}

65
66
67
68
69
// Hopper (sm90) entry point for int8 scaled GEMM with asymmetric zero-point
// (azp) correction. Validates the inputs, then forwards every argument
// unchanged to vllm::cutlass_scaled_mm_azp_sm90_int8.
//
// Parameters:
//   out       - output tensor, forwarded to the kernel.
//   a, b      - GEMM operands; `a` must be int8 because only the int8
//               kernel is dispatched here.
//   a_scales  - float32 scales for `a`.
//   b_scales  - float32 scales for `b`.
//   azp_adj   - azp adjustment tensor, forwarded as-is.
//   azp       - optional zero points, forwarded as-is.
//   bias      - optional bias, forwarded as-is.
// Throws c10::Error (via TORCH_CHECK) on a dtype mismatch.
void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32,
              "a_scales must be float32, got ", a_scales.dtype());
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32,
              "b_scales must be float32, got ", b_scales.dtype());
  // Fail fast with a clear message: only an int8 kernel is dispatched below.
  TORCH_CHECK(a.dtype() == torch::kInt8,
              "cutlass_scaled_mm_azp_sm90 expects int8 inputs, got ",
              a.dtype());

  vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
                                        azp, bias);
}