// dsv3_router_gemm_float_out.cu
/*
 * Adapted from
 * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu
 * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp
 *
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <stdexcept>
#include <string>
#include <type_traits>

#include "cuda_bf16.h"
#include "cuda_runtime.h"
#include "utils.h"

// Packed two-lane fused multiply-add: d = a * b + c, computed independently on
// the .x and .y components with round-to-nearest-even (.rn).
//
// On SM100+ this lowers to a single packed `fma.rn.f32x2` PTX instruction. The
// packed f32x2 form is not available on older targets (NOTE(review): PTX ISA
// lists it as sm_100+; confirm for every arch this file is built for). Elsewhere
// (older device archs and the host pass) it falls back to two scalar fmaf()
// calls, which give the same correctly-rounded per-lane result. This lets the
// helper compile when the file is built for e.g. sm_90.
__device__ __forceinline__ void fma(float2& d, float2 const& a, float2 const& b, float2 const& c) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  // Pure arithmetic with no side effects, so no `volatile`: the compiler may
  // freely schedule or eliminate it like any other FMA.
  asm("fma.rn.f32x2 %0, %1, %2, %3;\n"
      : "=l"(reinterpret_cast<uint64_t&>(d))
      : "l"(reinterpret_cast<uint64_t const&>(a)),
        "l"(reinterpret_cast<uint64_t const&>(b)),
        "l"(reinterpret_cast<uint64_t const&>(c)));
#else
  // Each lane only reads its own components, so this is safe even when d
  // aliases a, b or c.
  d.x = fmaf(a.x, b.x, c.x);
  d.y = fmaf(a.y, b.y, c.y);
#endif
}

// Convert the first VPT bfloat16 values packed in a 16-byte uint4 to float.
//
// VPT: number of values to convert. A uint4 holds at most 8 bfloat16 values,
//      and the static_assert enforces that bound so the loop can never read
//      past the end of `vec`.
// vec: source vector, reinterpreted as an array of __nv_bfloat16.
// dst: destination array; must have room for VPT floats.
template <int VPT>
__device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, float* dst) {
  static_assert(VPT > 0 && VPT * sizeof(__nv_bfloat16) <= sizeof(uint4),
                "bf16_uint4_to_float8: VPT must be in [1, 8] (a uint4 holds 8 bfloat16 values)");

  // Read-only view of the packed values. `vec` is never written, so there is
  // no need to cast away const.
  __nv_bfloat16 const* bf16_ptr = reinterpret_cast<__nv_bfloat16 const*>(&vec);

#pragma unroll
  for (int i = 0; i < VPT; i++) {
    dst[i] = __bfloat162float(bf16_ptr[i]);
  }
}

// Router GEMM with fp32 output: out[m, n] = sum_k mat_a[m, k] * mat_b[n, k].
//
// Layout, as indexed below:
//   mat_a: kNumTokens x kHiddenDim (token m starts at m * kHiddenDim).
//   mat_b: kNumExperts x kHiddenDim, i.e. a [K, N] weight matrix stored
//          column-major, so expert n's column is contiguous.
//   out:   kNumTokens x kNumExperts fp32 (out[m * kNumExperts + n]).
//
// Launch contract: gridDim.x == kNumExperts (one block per expert column) and
// blockDim.x == kBlockSize. Only static shared memory is used.
//
// Preconditions:
//   - Every thread issues 16-byte uint4 loads, so mat_a and mat_b must be
//     16-byte aligned. NOTE(review): assumed to be guaranteed by the caller.
//   - The static_asserts below enforce the shape requirements.
//
// On SM90+ the kernel participates in programmatic dependent launch. It issues
// griddepcontrol.wait before reading its inputs and
// griddepcontrol.launch_dependents after writing its output.
template <typename T, int kBlockSize, int VPT, int kNumTokens, int kNumExperts, int kHiddenDim>
__global__ __launch_bounds__(kBlockSize, 1) void router_gemm_kernel_float_output(
    float* out, T const* mat_a, T const* mat_b) {
  constexpr int kWarpSize = 32;
  constexpr int kNumWarps = kBlockSize / kWarpSize;
  // In each K iteration the whole block covers VPT * kBlockSize consecutive K elements.
  constexpr int kElemsPerIteration = VPT * kBlockSize;
  constexpr int kIterations = kHiddenDim / kElemsPerIteration;

  static_assert(std::is_same<T, __nv_bfloat16>::value,
                "router_gemm_kernel_float_output decodes its inputs as bfloat16");
  static_assert(VPT * sizeof(T) == sizeof(uint4),
                "each thread must load exactly one uint4 per row per K iteration");
  static_assert(kBlockSize % kWarpSize == 0,
                "full-mask warp shuffles require the block to consist of whole warps");
  static_assert(kHiddenDim % kElemsPerIteration == 0,
                "kHiddenDim must be a multiple of VPT * kBlockSize (the K loop has no tail handling)");
  static_assert(kNumTokens >= 1 && kNumTokens <= kBlockSize,
                "the final reduction assigns one thread per token");

  int const n_idx = blockIdx.x;  // expert (output column) handled by this block
  int const tid = threadIdx.x;
  int const warp_id = tid / kWarpSize;
  int const lane_id = tid % kWarpSize;

  // Per-thread partial dot products, one per token.
  float acc[kNumTokens] = {};

  // Per-warp partial sums, combined across warps after the barrier.
  __shared__ float sm_reduction[kNumTokens][kNumWarps];

  // Contiguous K-vector of weights for this block's expert.
  T const* b_col = mat_b + n_idx * kHiddenDim;

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // PDL: do not read inputs until the prerequisite grid has made them visible.
  asm volatile("griddepcontrol.wait;");
#endif

  for (int ki = 0; ki < kIterations; ki++) {
    int const k_base = ki * kElemsPerIteration + tid * VPT;

    // Copy the packed vector into a local first, so the read is intended as a
    // single 16-byte vector load rather than per-element accesses.
    uint4 const b_vec = *reinterpret_cast<uint4 const*>(b_col + k_base);
    float b_float[VPT];
    bf16_uint4_to_float8<VPT>(b_vec, b_float);

#pragma unroll
    for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
      uint4 const a_vec = *reinterpret_cast<uint4 const*>(mat_a + m_idx * kHiddenDim + k_base);
      float a_float[VPT];
      bf16_uint4_to_float8<VPT>(a_vec, a_float);

#pragma unroll
      for (int k = 0; k < VPT; k++) {
        acc[m_idx] += a_float[k] * b_float[k];
      }
    }
  }

  // Intra-warp butterfly reduction. After the xor-shuffles every lane holds the
  // warp total, and lane 0 publishes it to shared memory.
#pragma unroll
  for (int m = 0; m < kNumTokens; m++) {
    float sum = acc[m];
#pragma unroll
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
      sum += __shfl_xor_sync(0xffffffff, sum, offset);
    }
    if (lane_id == 0) {
      sm_reduction[m][warp_id] = sum;
    }
  }

  __syncthreads();

  // Cross-warp reduction: thread `tid` finalizes token `tid`. Warps are summed
  // in the same order (0 .. kNumWarps-1) as a serial loop would use, so the
  // results are bit-identical to reducing everything in a single thread.
  if (tid < kNumTokens) {
    float final_sum = 0.0f;
#pragma unroll
    for (int w = 0; w < kNumWarps; w++) {
      final_sum += sm_reduction[tid][w];
    }
    out[tid * kNumExperts + n_idx] = final_sum;
  }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // PDL: allow the dependent grid to begin launching.
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}

// Host launcher for router_gemm_kernel_float_output.
//
// output: device pointer to kNumTokens x kNumExperts fp32 results.
// mat_a:  device pointer to kNumTokens x kHiddenDim activations.
// mat_b:  device pointer to kNumExperts x kHiddenDim router weights.
// stream: stream the kernel is enqueued on. The call does not synchronize.
//
// Launches one kBlockSize-thread block per expert. Programmatic stream
// serialization (PDL) is allowed when getEnvEnablePDL() returns true.
// Throws std::runtime_error if cudaLaunchKernelEx reports an error, rather
// than leaving `output` unwritten without notice.
template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream) {
  static_assert(sizeof(T) <= 16 && 16 % sizeof(T) == 0, "T must evenly divide a 16-byte vector load");
  constexpr int VPT = 16 / sizeof(T);  // elements per 16-byte (uint4) vector load
  constexpr int kBlockSize = 128;

  cudaLaunchAttribute attrs[1] = {};
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL();

  cudaLaunchConfig_t config = {};
  config.gridDim = kNumExperts;
  config.blockDim = kBlockSize;
  config.dynamicSmemBytes = 0;
  config.stream = stream;
  config.numAttrs = 1;
  config.attrs = attrs;

  cudaError_t const err = cudaLaunchKernelEx(
      &config,
      router_gemm_kernel_float_output<T, kBlockSize, VPT, kNumTokens, kNumExperts, kHiddenDim>,
      output,
      mat_a,
      mat_b);
  if (err != cudaSuccess) {
    throw std::runtime_error(std::string("router_gemm_kernel_float_output launch failed: ") +
                             cudaGetErrorString(err));
  }
}

// Explicit instantiations of the launcher for bfloat16 inputs with 256 experts,
// hidden size 7168, and 1 to 16 tokens. Only these shapes are emitted from this
// translation unit.
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 1, 256, 7168>(
    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 2, 256, 7168>(
    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 3, 256, 7168>(
    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 4, 256, 7168>(
    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 5, 256, 7168>(
    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 6, 256, 7168>(
    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 7, 256, 7168>(
    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 8, 256, 7168>(
    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 9, 256, 7168>(
    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 10, 256, 7168>(
    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 11, 256, 7168>(
    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 12, 256, 7168>(
    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 13, 256, 7168>(
    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 14, 256, 7168>(
    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 15, 256, 7168>(
    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 16, 256, 7168>(
    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);