es_fp8_blockwise_launcher.cuh 12.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include <iostream>
#include <string>

#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "es_fp8_blockwise_functor.cuh"

namespace expert_specialization {

using namespace cute;

// Builds the Low/Middle/High-M problem-size filter functors for the given perf configs and launches
// groupedGemmPreComputeKernel as a single block with one thread per expert. Launch-configuration errors are
// reported right here instead of surfacing at some later, unrelated CUDA call.
template <
    typename LowMConfig,
    typename MiddleMConfig,
    typename HighMConfig,
    typename OffsetFunctor,
    typename SFLayoutFunctor>
void es_sm90_fp8_blockwise_launch_pre_compute_kernel(
    cudaStream_t stream,
    int num_experts,
    torch::Tensor const& problem_sizes,
    OffsetFunctor offset_functor,
    SFLayoutFunctor sf_layout,
    torch::Tensor& lm_problem_sizes,
    torch::Tensor& mm_problem_sizes,
    torch::Tensor& hm_problem_sizes) {
  Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<LowMConfig> lm_psf(static_cast<int*>(lm_problem_sizes.data_ptr()));
  Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<MiddleMConfig> mm_psf(
      static_cast<int*>(mm_problem_sizes.data_ptr()));
  Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<HighMConfig> hm_psf(static_cast<int*>(hm_problem_sizes.data_ptr()));
  groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
      static_cast<int*>(problem_sizes.data_ptr()), offset_functor, sf_layout, lm_psf, mm_psf, hm_psf);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Builds the pointer/offset functor for output element type OutType (FP8 e4m3 operands, float scales) and
// launches the pre-compute kernel with the H20 or Hx00 problem-size filters.
template <typename OutType, typename SFLayoutFunctor>
void es_sm90_fp8_blockwise_pre_compute_for_out_dtype(
    cudaStream_t stream,
    bool is_h20_device,
    int num_experts,
    SFLayoutFunctor sf_layout,
    torch::Tensor& out_ptrs,
    torch::Tensor& a_ptrs,
    torch::Tensor& b_ptrs,
    torch::Tensor& a_scales_ptrs,
    torch::Tensor& b_scales_ptrs,
    torch::Tensor& lm_problem_sizes,
    torch::Tensor& mm_problem_sizes,
    torch::Tensor& hm_problem_sizes,
    torch::Tensor& out_tensors,
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    torch::Tensor const& problem_sizes,
    torch::Tensor const& expert_offsets) {
  using ElementAB = cutlass::float_e4m3_t;
  Fp8BlockwiseGroupedGemmOffsetFunctor<ElementAB, float, OutType> offset_functor(
      static_cast<int*>(expert_offsets.data_ptr()),
      static_cast<ElementAB*>(a_tensors.data_ptr()),
      static_cast<ElementAB*>(b_tensors.data_ptr()),
      static_cast<OutType*>(out_tensors.data_ptr()),
      static_cast<float*>(a_scales.data_ptr()),
      static_cast<float*>(b_scales.data_ptr()),
      static_cast<ElementAB**>(a_ptrs.data_ptr()),
      static_cast<ElementAB**>(b_ptrs.data_ptr()),
      static_cast<float**>(a_scales_ptrs.data_ptr()),
      static_cast<float**>(b_scales_ptrs.data_ptr()),
      static_cast<OutType**>(out_ptrs.data_ptr()));

  if (is_h20_device) {
    es_sm90_fp8_blockwise_launch_pre_compute_kernel<PerfConfigLowMH20, PerfConfigMiddleMH20, PerfConfigHighMH20>(
        stream,
        num_experts,
        problem_sizes,
        offset_functor,
        sf_layout,
        lm_problem_sizes,
        mm_problem_sizes,
        hm_problem_sizes);
  } else {
    es_sm90_fp8_blockwise_launch_pre_compute_kernel<PerfConfigLowMHx00, PerfConfigMiddleMHx00, PerfConfigHighMHx00>(
        stream,
        num_experts,
        problem_sizes,
        offset_functor,
        sf_layout,
        lm_problem_sizes,
        mm_problem_sizes,
        hm_problem_sizes);
  }
}

/// Prepares the device-side inputs for the expert-specialized FP8 blockwise grouped GEMM.
///
/// Launches groupedGemmPreComputeKernel (a single block, one thread per expert, expert count = expert_offsets.size(0))
/// with functors that point at the output tensors out_ptrs / a_ptrs / b_ptrs / a_scales_ptrs / b_scales_ptrs,
/// layout_sfa / layout_sfb and lm_ / mm_ / hm_problem_sizes. These are presumably filled in by the kernel; confirm
/// against es_fp8_blockwise_functor.cuh.
///
/// Preconditions (checked): A/B are float8_e4m3fn, scales are float32, expert_offsets and problem_sizes are int32,
/// out_tensors is bfloat16 or float16, and the expert count does not exceed the device's max threads per block.
/// With zero experts nothing is launched.
///
/// Declared `inline` because this non-template function is defined in a header; otherwise every translation unit
/// that includes the header would emit its own definition.
inline void es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
    // Output
    torch::Tensor& out_ptrs,
    torch::Tensor& a_ptrs,
    torch::Tensor& b_ptrs,
    torch::Tensor& a_scales_ptrs,
    torch::Tensor& b_scales_ptrs,
    torch::Tensor& layout_sfa,
    torch::Tensor& layout_sfb,
    torch::Tensor& lm_problem_sizes,
    torch::Tensor& mm_problem_sizes,
    torch::Tensor& hm_problem_sizes,
    // Input
    torch::Tensor& out_tensors,
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    torch::Tensor const& problem_sizes,
    torch::Tensor const& expert_offsets) {
  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  // Both tensors are reinterpreted as int* below; any other dtype would be silently misread.
  TORCH_CHECK(expert_offsets.dtype() == torch::kInt32, "expert_offsets must be int32");
  TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32");
  const bool out_is_bf16 = out_tensors.dtype() == torch::kBFloat16;
  TORCH_CHECK(
      out_is_bf16 || out_tensors.dtype() == torch::kFloat16, "Invalid output type (must be float16 or bfloat16)");

  // Run on the device that owns the inputs. The stream below is taken for a_tensors' device, and the H20 check and
  // thread limit must describe that same GPU rather than whichever device happens to be current.
  const at::cuda::CUDAGuard device_guard(a_tensors.device());

  const int num_experts = static_cast<int>(expert_offsets.size(0));
  if (num_experts == 0) {
    // Nothing to pre-compute; a <<<1, 0>>> launch would be an invalid configuration.
    return;
  }
  const cudaDeviceProp* device_props = at::cuda::getCurrentDeviceProperties();
  // The kernel runs one thread per expert in a single block, so the expert count is bounded by the block size limit.
  TORCH_CHECK(
      num_experts <= device_props->maxThreadsPerBlock,
      "Number of experts (",
      num_experts,
      ") exceeds the max threads per block (",
      device_props->maxThreadsPerBlock,
      ") of the single-block pre-compute launch");

  const std::string H20_device_type_str("NVIDIA H20");
  const bool is_h20_device = std::string(device_props->name) == H20_device_type_str;

  // Create the scale-factor layout functor. Its SFA/SFB layout types come from PerfConfigMiddleMH20 for every device
  // and M-range. NOTE(review): this presumably relies on all perf configs sharing the same scale layout; confirm.
  using LayoutSFA = typename PerfConfigMiddleMH20::LayoutSFA;
  using LayoutSFB = typename PerfConfigMiddleMH20::LayoutSFB;
  Fp8BlockwiseGroupedGemmSFLayoutFunctor<PerfConfigMiddleMH20> sf_layout(
      reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()));

  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());

  // Dispatch on the output element type.
  if (out_is_bf16) {
    es_sm90_fp8_blockwise_pre_compute_for_out_dtype<cutlass::bfloat16_t>(
        stream,
        is_h20_device,
        num_experts,
        sf_layout,
        out_ptrs,
        a_ptrs,
        b_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        lm_problem_sizes,
        mm_problem_sizes,
        hm_problem_sizes,
        out_tensors,
        a_tensors,
        b_tensors,
        a_scales,
        b_scales,
        problem_sizes,
        expert_offsets);
  } else {
    es_sm90_fp8_blockwise_pre_compute_for_out_dtype<cutlass::half_t>(
        stream,
        is_h20_device,
        num_experts,
        sf_layout,
        out_ptrs,
        a_ptrs,
        b_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        lm_problem_sizes,
        mm_problem_sizes,
        hm_problem_sizes,
        out_tensors,
        a_tensors,
        b_tensors,
        a_scales,
        b_scales,
        problem_sizes,
        expert_offsets);
  }
}

/// Runs one CUTLASS grouped-mode (GemmUniversalMode::kGrouped) FP8 blockwise-scaled GEMM described by GemmTraits.
///
/// The pointer, stride and scale-factor-layout tensors are reinterpreted as device arrays of the GemmTraits types.
/// They presumably hold one entry per group, as CUTLASS's pointer-array grouped GEMM expects. problem_sizes is
/// reinterpreted as an array of UnderlyingProblemShape, and its size(0) is used as the group count.
///
/// Throws (TORCH_CHECK) if CUTLASS reports that it cannot implement, initialize, or run the GEMM.
template <typename GemmTraits>
void launch_sm90_fp8_blockwise_scaled_group_mm(
    torch::Tensor& out_ptrs,
    const torch::Tensor& a_ptrs,
    const torch::Tensor& b_ptrs,
    const torch::Tensor& a_scales_ptrs,
    const torch::Tensor& b_scales_ptrs,
    const torch::Tensor& stride_a,
    const torch::Tensor& stride_b,
    const torch::Tensor& stride_d,
    const torch::Tensor& layout_sfa,
    const torch::Tensor& layout_sfb,
    const torch::Tensor& problem_sizes) {
  using ElementA = typename GemmTraits::ElementA;
  using StrideA = typename GemmTraits::StrideA;
  using ElementB = typename GemmTraits::ElementB;
  using StrideB = typename GemmTraits::StrideB;
  using ElementAccumulator = typename GemmTraits::ElementAccumulator;
  using LayoutSFA = typename GemmTraits::LayoutSFA;
  using LayoutSFB = typename GemmTraits::LayoutSFB;
  using ElementD = typename GemmTraits::ElementD;
  using StrideD = typename GemmTraits::StrideD;
  using UnderlyingProblemShape = typename GemmTraits::ProblemShape::UnderlyingProblemShape;
  using Gemm = typename GemmTraits::Gemm;
  using GemmKernel = typename GemmTraits::GemmKernel;

  // Make a_ptrs' device current before anything below queries "the current device". hw_info's device id and SM
  // count, the stream, and the workspace must all describe the GPU the GEMM actually runs on.
  const at::cuda::CUDAGuard device_guard(a_ptrs.device());
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a_ptrs.get_device());

  const int num_experts = static_cast<int>(problem_sizes.size(0));

  typename GemmKernel::MainloopArguments mainloop_args{
      static_cast<const ElementA**>(a_ptrs.data_ptr()),
      static_cast<StrideA*>(stride_a.data_ptr()),
      static_cast<const ElementB**>(b_ptrs.data_ptr()),
      static_cast<StrideB*>(stride_b.data_ptr()),
      static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()),
      reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
      static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()),
      reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())};

  // Epilogue arguments: default-constructed fusion args, no C operand (nullptr pointer/stride), then D pointers and
  // strides.
  typename GemmKernel::EpilogueArguments epilogue_args{
      {}, nullptr, nullptr, static_cast<ElementD**>(out_ptrs.data_ptr()), static_cast<StrideD*>(stride_d.data_ptr())};

  // Queried after the device guard above, so both values refer to a_ptrs' device.
  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = c10::cuda::current_device();
  hw_info.sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

  auto* problem_sizes_as_shapes = static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
  typename GemmKernel::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGrouped,
      {num_experts, problem_sizes_as_shapes, nullptr},
      mainloop_args,
      epilogue_args,
      hw_info};

  Gemm gemm_op;
  const cutlass::Status can_implement_status = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");

  const size_t workspace_size = gemm_op.get_workspace_size(args);
  torch::Tensor workspace = torch::empty(
      {static_cast<int64_t>(workspace_size)}, torch::TensorOptions().dtype(torch::kUInt8).device(out_ptrs.device()));

  cutlass::Status status = gemm_op.initialize(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");

  status = gemm_op.run(stream, nullptr, /*launch_with_pdl=*/true);  // Enable PDL
  TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}

/// Runs the expert-specialized FP8 blockwise grouped GEMM for output element type OutType.
///
/// One grouped GEMM is launched per M-range, in the order low, middle, high. Each range uses the problem-size tensor
/// produced for it by es_sm90_fp8_blockwise_scaled_group_mm_pre_compute (lm_/mm_/hm_problem_sizes). The perf
/// config is chosen per device: "NVIDIA H20" vs. other Hopper parts (Hx00).
///
/// Traits built with cutlass::layout::ColumnMajor are launched with the A/B operands swapped: B's pointers, scales,
/// strides and scale-factor layouts are passed in A's slots and vice versa. RowMajor traits get them in order.
template <typename OutType>
void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
    torch::Tensor& out_ptrs,
    const torch::Tensor& a_ptrs,
    const torch::Tensor& b_ptrs,
    const torch::Tensor& a_scales_ptrs,
    const torch::Tensor& b_scales_ptrs,
    const torch::Tensor& stride_a,
    const torch::Tensor& stride_b,
    const torch::Tensor& stride_d,
    const torch::Tensor& layout_sfa,
    const torch::Tensor& layout_sfb,
    const torch::Tensor& lm_problem_sizes,
    const torch::Tensor& mm_problem_sizes,
    const torch::Tensor& hm_problem_sizes) {
  using LowMGemmH20Traits =
      ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits<OutType, cutlass::layout::ColumnMajor, PerfConfigLowMH20>;
  using LowMGemmHx00Traits =
      ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits<OutType, cutlass::layout::ColumnMajor, PerfConfigLowMHx00>;
  using MiddleMGemmH20Traits =
      ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits<OutType, cutlass::layout::RowMajor, PerfConfigMiddleMH20>;
  using MiddleMGemmHx00Traits = ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits<
      OutType,
      cutlass::layout::ColumnMajor,
      PerfConfigMiddleMHx00>;
  using HighMGemmH20Traits =
      ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits<OutType, cutlass::layout::RowMajor, PerfConfigHighMH20>;
  using HighMGemmHx00Traits =
      ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits<OutType, cutlass::layout::RowMajor, PerfConfigHighMHx00>;

  // Identify the GPU that owns the operands, not whichever device happens to be current.
  const at::cuda::CUDAGuard device_guard(a_ptrs.device());
  const std::string H20_device_type_str("NVIDIA H20");
  const bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;

  if (is_h20_device) {
    // Low M: ColumnMajor traits, operands swapped.
    launch_sm90_fp8_blockwise_scaled_group_mm<LowMGemmH20Traits>(
        out_ptrs,
        b_ptrs,
        a_ptrs,
        b_scales_ptrs,
        a_scales_ptrs,
        stride_b,
        stride_a,
        stride_d,
        layout_sfb,
        layout_sfa,
        lm_problem_sizes);
    // Middle M: the H20 middle-M config, matching the PerfConfigMiddleMH20 filter the pre-compute step uses to fill
    // mm_problem_sizes on H20. RowMajor traits, so operands are passed in order.
    launch_sm90_fp8_blockwise_scaled_group_mm<MiddleMGemmH20Traits>(
        out_ptrs,
        a_ptrs,
        b_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        stride_a,
        stride_b,
        stride_d,
        layout_sfa,
        layout_sfb,
        mm_problem_sizes);
    // High M: RowMajor traits, operands in order.
    launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmH20Traits>(
        out_ptrs,
        a_ptrs,
        b_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        stride_a,
        stride_b,
        stride_d,
        layout_sfa,
        layout_sfb,
        hm_problem_sizes);
  } else {
    // Low M: ColumnMajor traits, operands swapped.
    launch_sm90_fp8_blockwise_scaled_group_mm<LowMGemmHx00Traits>(
        out_ptrs,
        b_ptrs,
        a_ptrs,
        b_scales_ptrs,
        a_scales_ptrs,
        stride_b,
        stride_a,
        stride_d,
        layout_sfb,
        layout_sfa,
        lm_problem_sizes);
    // Middle M: ColumnMajor traits, operands swapped.
    launch_sm90_fp8_blockwise_scaled_group_mm<MiddleMGemmHx00Traits>(
        out_ptrs,
        b_ptrs,
        a_ptrs,
        b_scales_ptrs,
        a_scales_ptrs,
        stride_b,
        stride_a,
        stride_d,
        layout_sfb,
        layout_sfa,
        mm_problem_sizes);
    // High M: RowMajor traits, operands in order.
    launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmHx00Traits>(
        out_ptrs,
        a_ptrs,
        b_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        stride_a,
        stride_b,
        stride_d,
        layout_sfa,
        layout_sfb,
        hm_problem_sizes);
  }
}

}  // namespace expert_specialization