// es_fp8_blockwise.cu
#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <string>
#include <tuple>

#include "es_fp8_blockwise_launcher.cuh"

/**
 * @brief Performs blockwise grouped matrix multiplication on FP8 quantized inputs,
 *        with per-block scaling.
 *
 * This function dispatches to hardware-specific implementations (here, SM90 FP8)
 * to compute:
 *     C_i = (scale_a[i] * A_i) * (scale_b[i] * B_i)
 * for each expert group `i`, using `problem_sizes` and `expert_offsets`
 * to describe the individual matrix dimensions and their offsets.
 *
 * Input tensors A and B must be quantized to FP8 (e4m3); the per-block scale
 * factors are applied during the multiplication. The output tensor is written
 * in bfloat16 or half precision.
 *
 * @param output         Output tensor (must be bfloat16 or half).
 * @param a              Input tensor A (must be kFloat8_e4m3fn).
 * @param b              Input tensor B (must be kFloat8_e4m3fn).
 * @param scales_a       Per-block float32 scaling factors for tensor A.
 * @param scales_b       Per-block float32 scaling factors for tensor B.
 * @param stride_a       Stride information for tensor A (int32).
 * @param stride_b       Stride information for tensor B (int32).
 * @param stride_d       Stride information for the output tensor D (int32).
 * @param problem_sizes  2D int32 tensor of shape (num_experts, 3), specifying (M, N, K)
 *                       for each grouped matrix multiplication problem.
 * @param expert_offsets 1D int32 tensor of size (num_experts), used to index into
 *                       the grouped input tensors for dispatch.
 */
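/*
 * Usage sketch (illustrative only): the shapes, the scale-factor block granularity,
 * and the way a/b are concatenated over experts below are assumptions, not
 * requirements stated in this file; a, b, scales_*, stride_*, and output are
 * assumed to have been prepared elsewhere.
 *
 *   // Two experts, each with its own (M, N, K) problem.
 *   auto opts_i32 = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
 *   torch::Tensor problem_sizes =
 *       torch::tensor({128, 256, 512, 64, 256, 512}, opts_i32).view({2, 3});
 *   torch::Tensor expert_offsets = torch::tensor({0, 128}, opts_i32);
 *
 *   // a/b hold the FP8 (e4m3) operands for all experts, scales_a/scales_b the
 *   // float32 per-block scale factors, stride_a/stride_b/stride_d the per-expert
 *   // strides; output is preallocated in bfloat16 (or half).
 *   es_fp8_blockwise_scaled_grouped_mm(
 *       output, a, b, scales_a, scales_b,
 *       stride_a, stride_b, stride_d,
 *       problem_sizes, expert_offsets);
 */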
void es_fp8_blockwise_scaled_grouped_mm(
    torch::Tensor& output,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Tensor& stride_a,
    const torch::Tensor& stride_b,
    const torch::Tensor& stride_d,
    const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets) {
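// Only compiled when the build enables CUTLASS SM90 (Hopper) support; otherwise the
// function raises a not-implemented error at runtime (see the #else branch below).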
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
  TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
  TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");
  TORCH_CHECK(
      problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets");
  TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32");
  TORCH_CHECK(a.scalar_type() == torch::kFloat8_e4m3fn, "a must be kFloat8_e4m3fn");
  TORCH_CHECK(b.scalar_type() == torch::kFloat8_e4m3fn, "b must be kFloat8_e4m3fn");
  TORCH_CHECK(
      output.scalar_type() == torch::kBFloat16 || output.scalar_type() == torch::kHalf,
      "output must be bfloat16 or half");

  int num_experts = (int)problem_sizes.size(0);
  torch::TensorOptions options_int64 = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
  torch::TensorOptions options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(a.device());
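  // Device-side scratch filled by the pre-compute kernel below: per-expert raw
  // pointers (stored as int64 addresses) into the output/a/b tensors and their
  // scale tensors, the CUTLASS scale-factor layouts, and the problem sizes
  // re-bucketed into lm/mm/hm groups (presumably low/medium/high M) so each
  // bucket can be launched with a differently tuned kernel configuration.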
  torch::Tensor out_ptrs = torch::empty(num_experts, options_int64);
  torch::Tensor a_ptrs = torch::empty(num_experts, options_int64);
  torch::Tensor b_ptrs = torch::empty(num_experts, options_int64);
  torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int64);
  torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int64);

  torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int32);
  torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int32);

  torch::Tensor lm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
  torch::Tensor mm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
  torch::Tensor hm_problem_sizes = torch::empty({num_experts, 3}, options_int32);

  const std::string H20_device_type_str("NVIDIA H20");
  bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;
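  // is_h20_device is forwarded to the launcher, which presumably uses it to pick a
  // kernel configuration tuned for NVIDIA H20; the actual specialization lives in
  // es_fp8_blockwise_launcher.cuh.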
  at::cuda::CUDAGuard device_guard{(char)a.get_device()};
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());

  if (output.dtype() == torch::kBFloat16) {
    expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute<cutlass::bfloat16_t>(
        out_ptrs,
        a_ptrs,
        b_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        layout_sfa,
        layout_sfb,
        lm_problem_sizes,
        mm_problem_sizes,
        hm_problem_sizes,
        output,
        a,
        b,
        scales_a,
        scales_b,
        problem_sizes,
        expert_offsets,
        is_h20_device,
        stream);
  } else if (output.dtype() == torch::kFloat16) {
    expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute<cutlass::half_t>(
        out_ptrs,
        a_ptrs,
        b_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        layout_sfa,
        layout_sfb,
        lm_problem_sizes,
        mm_problem_sizes,
        hm_problem_sizes,
        output,
        a,
        b,
        scales_a,
        scales_b,
        problem_sizes,
        expert_offsets,
        is_h20_device,
        stream);
  } else {
    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
  }
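
  // At this point the per-expert pointers, scale-factor layouts, and bucketed
  // problem sizes have been populated on `stream`; now dispatch the grouped GEMM
  // launch itself, again specialized on the output dtype.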

  if (output.dtype() == torch::kBFloat16) {
    expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype<cutlass::bfloat16_t>(
        out_ptrs,
        a_ptrs,
        b_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        stride_a,
        stride_b,
        stride_d,
        layout_sfa,
        layout_sfb,
        lm_problem_sizes,
        mm_problem_sizes,
        hm_problem_sizes,
        is_h20_device,
        stream);
  } else if (output.dtype() == torch::kFloat16) {
    expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype<cutlass::half_t>(
        out_ptrs,
        a_ptrs,
        b_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        stride_a,
        stride_b,
        stride_d,
        layout_sfa,
        layout_sfb,
        lm_problem_sizes,
        mm_problem_sizes,
        hm_problem_sizes,
        is_h20_device,
        stream);
  } else {
    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
  }
#else
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "No fp8_blockwise_scaled_grouped_mm implementation compiled for the current compute capability (requires SM90 support)");
#endif
}