// es_fp8_blockwise.cu
#include <torch/all.h>

#include <tuple>

#include "es_fp8_blockwise_launcher.cuh"

/**
 * @brief Host entry point for the expert-specialized FP8 blockwise-scaled grouped GEMM.
 *
 * When the build has SM90 MMA and modifiable-TMA support, this function:
 *   1. validates `problem_sizes`, `expert_offsets` and the dtypes of `a`, `b`, `output`;
 *   2. allocates per-expert scratch tensors on `a.device()` (pointer arrays, scale-factor
 *      layouts and three problem-size tables);
 *   3. calls `es_sm90_fp8_blockwise_scaled_group_mm_pre_compute` to fill that scratch;
 *   4. dispatches to the bfloat16 or half instantiation of
 *      `es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype`.
 *
 * The intended result, for each expert group `i`, is
 *     C_i = scale_a[i] * A_i * scale_b[i] * B_i
 * The arithmetic itself is carried out by the `expert_specialization` helpers called below,
 * not in this function.
 *
 * In builds without SM90 support the function always raises a "not implemented" error.
 *
 * @param output         Output tensor; must be bfloat16 or half (checked).
 * @param a              Input tensor A; must be kFloat8_e4m3fn (checked). Its device is used
 *                       for all scratch allocations.
 * @param b              Input tensor B; must be kFloat8_e4m3fn (checked).
 * @param scales_a       Scaling factors for A. Forwarded to the pre-compute step; dtype and
 *                       shape are not validated here (presumably float32 - TODO confirm).
 * @param scales_b       Scaling factors for B. Same caveats as `scales_a`.
 * @param stride_a       Stride information for A, forwarded unchanged to the dispatcher
 *                       (not validated here).
 * @param stride_b       Stride information for B, forwarded unchanged to the dispatcher
 *                       (not validated here).
 * @param stride_d       Stride information for the output, forwarded unchanged to the
 *                       dispatcher (not validated here).
 * @param problem_sizes  2D int32 tensor of shape (num_experts, 3) (checked); per the original
 *                       contract each row is (M, N, K) for one grouped GEMM problem.
 * @param expert_offsets Tensor whose first dimension must equal num_experts (checked); used by
 *                       the pre-compute step to locate each expert's data. Rank and dtype are
 *                       not validated here (presumably 1D int32 - TODO confirm).
 */
void es_fp8_blockwise_scaled_grouped_mm(
    torch::Tensor& output,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Tensor& stride_a,
    const torch::Tensor& stride_b,
    const torch::Tensor& stride_d,
    const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets) {
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
  // ---- Argument validation -------------------------------------------------
  TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
  TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");
  TORCH_CHECK(
      problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets");
  TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32");
  TORCH_CHECK(a.scalar_type() == torch::kFloat8_e4m3fn, "a must be kFloat8_e4m3fn");
  TORCH_CHECK(b.scalar_type() == torch::kFloat8_e4m3fn, "b must be kFloat8_e4m3fn");

  // Read the output dtype once; it drives both the validation and the dispatch below.
  const auto out_dtype = output.scalar_type();
  TORCH_CHECK(out_dtype == torch::kBFloat16 || out_dtype == torch::kHalf, "output must be bfloat16 or half");

  // ---- Scratch allocation (all on the device of `a`) -----------------------
  int num_experts = (int)problem_sizes.size(0);
  torch::TensorOptions options_int64 = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
  torch::TensorOptions options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(a.device());

  // One 64-bit slot per expert; filled by the pre-compute step (presumably with raw
  // device addresses of each expert's slice - TODO confirm against the launcher).
  torch::Tensor out_ptrs = torch::empty(num_experts, options_int64);
  torch::Tensor a_ptrs = torch::empty(num_experts, options_int64);
  torch::Tensor b_ptrs = torch::empty(num_experts, options_int64);
  torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int64);
  torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int64);

  // Five int32 values per expert describing the scale-factor layouts for A and B.
  torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int32);
  torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int32);

  // Three (num_experts, 3) problem-size tables produced by the pre-compute step.
  // NOTE(review): lm/mm/hm presumably stand for low/middle/high-M buckets - confirm.
  torch::Tensor lm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
  torch::Tensor mm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
  torch::Tensor hm_problem_sizes = torch::empty({num_experts, 3}, options_int32);

  // ---- Pre-compute: populate the scratch tensors ---------------------------
  expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
      out_ptrs,
      a_ptrs,
      b_ptrs,
      a_scales_ptrs,
      b_scales_ptrs,
      layout_sfa,
      layout_sfb,
      lm_problem_sizes,
      mm_problem_sizes,
      hm_problem_sizes,
      output,
      a,
      b,
      scales_a,
      scales_b,
      problem_sizes,
      expert_offsets);

  // ---- Dispatch on the output element type ---------------------------------
  if (out_dtype == torch::kBFloat16) {
    expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype<cutlass::bfloat16_t>(
        out_ptrs,
        a_ptrs,
        b_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        stride_a,
        stride_b,
        stride_d,
        layout_sfa,
        layout_sfb,
        lm_problem_sizes,
        mm_problem_sizes,
        hm_problem_sizes);
  } else if (out_dtype == torch::kFloat16) {
    expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype<cutlass::half_t>(
        out_ptrs,
        a_ptrs,
        b_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        stride_a,
        stride_b,
        stride_d,
        layout_sfa,
        layout_sfb,
        lm_problem_sizes,
        mm_problem_sizes,
        hm_problem_sizes);
  } else {
    // Defensive: already excluded by the dtype check above.
    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
  }
#else
  // Nothing from the SM90 branch is in scope here, so fail unconditionally with a
  // self-contained message instead of referencing undeclared identifiers.
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No implemented fp8_blockwise_scaled_grouped_mm for current compute capability: "
      "kernel was built without SM90 MMA / modifiable TMA support");
#endif
}