common.cu 10.8 KB
Newer Older
1
#include <ATen/cuda/CUDAContext.h>
2
#include <torch/all.h>
3
4
5
6
7
8
9
#include <c10/cuda/CUDAGuard.h>

#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"

10
11
#include "../../reduction_utils.cuh"

12
13
14
15
16
17
18
19
20
21
22
23
// Select the 8-bit float element type and its clamp limit per platform.
// FP8_TYPE is the element type the quantization kernels in this file write
// to; FP8_E4M3_MAX is used both as the saturation bound when converting to
// FP8 and as the divisor when a scale is derived from an absolute maximum.
#ifndef USE_ROCM
// Non-ROCm build: take the limit from the type's own numeric_limits.
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
    std::numeric_limits<FP8_TYPE>::max();
#else
  #include "amd/hip_float8.h"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using the default max value from pytorch (240.0) will cause accuracy
// issue when running dynamic quantization. Here use 224.0f for rocm.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif

24
25
26
namespace vllm {

// Atomic "max" on a float stored at `addr`, implemented on the raw bit
// pattern of the value. Returns the float that was stored at `addr`
// immediately before the atomic operation.
//
// For value >= 0 the bits are compared as signed integers with atomicMax;
// otherwise they are compared as unsigned integers with atomicMin.
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  if (value >= 0) {
    int const previous_bits =
        atomicMax(reinterpret_cast<int*>(addr), __float_as_int(value));
    return __int_as_float(previous_bits);
  }

  unsigned int const previous_bits = atomicMin(
      reinterpret_cast<unsigned int*>(addr), __float_as_uint(value));
  return __uint_as_float(previous_bits);
}

36
// Scales a single float and converts it to FP8_TYPE with saturation.
//
// Template parameter:
//   is_scale_inverted - when true `scale` is multiplied into `val`; when
//                       false `val` is divided by `scale`.
// The scaled value is clamped to [-FP8_E4M3_MAX, FP8_E4M3_MAX] before the
// conversion.
template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
                                                          float const scale) {
  // The condition is a compile-time constant, so only one arm survives.
  float const scaled = is_scale_inverted ? val * scale : val / scale;

  // Saturate: cap from above first, then from below.
  float const capped = fmin(scaled, FP8_E4M3_MAX);
  float const clamped = fmax(-FP8_E4M3_MAX, capped);
#ifndef USE_ROCM
  return static_cast<c10::Float8_e4m3fn>(clamped);
#else
  // Use hardware cvt instruction for fp8 on rocm
  return c10::Float8_e4m3fnuz(hip_fp8(clamped).data,
                              c10::Float8_e4m3fnuz::from_bits());
#endif
}

56
57
58
59
60
61
// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
//
// Launch requirements: 1-D grid and 1-D block. blockDim.x must be a power of
// two no larger than 1024, because the shared buffer holds 1024 floats and
// the reduction tree halves the active range on every step.
//
// Parameters:
//   scale     - single float that receives the result via atomicMaxFloat.
//   input     - values to scan; each element is read as float.
//   num_elems - number of elements in `input`.
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
                                        const scalar_t* __restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[1024];

  // Widen to 64 bits *before* multiplying. The built-in index variables are
  // 32-bit unsigned, so blockDim.x * blockIdx.x (and blockDim.x * gridDim.x)
  // would otherwise wrap around for very large launches.
  int64_t const stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  int64_t i = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;

  // First store maximum for all values processed by
  // the current thread in cache[threadIdx.x]. The running maximum is kept in
  // float: every operand is already float and so is cache[].
  float tmp = 0.0f;
  while (i < num_elems) {
    float x = static_cast<float>(input[i]);
    tmp = max(tmp, fabs(x));
    i += stride;
  }
  cache[threadIdx.x] = tmp;

  __syncthreads();

  // Now perform parallel reduction within the thread block. The barrier is
  // outside the conditional so that every thread reaches it on every step.
  for (int ib = blockDim.x / 2; ib != 0; ib /= 2) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    __syncthreads();
  }

  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
  }
}

97
98
99
100
101
102
103
104
105
// Four consecutive input elements grouped into one 8-byte-aligned unit.
// Used by the vectorized loops in this file to read four values per access.
template <typename scalar_t>
struct __align__(8) vec4_t {
  scalar_t x, y, z, w;
};

// Four FP8 values grouped into one 4-byte-aligned unit, so a group of four
// quantized outputs can be assigned to the output buffer as a whole.
struct __align__(4) float8x4_t {
  FP8_TYPE x, y, z, w;
};

113
// Returns the largest absolute value among the elements of `input` that this
// thread visits. The thread starts at index `tid` and advances by `step`.
//
// Elements are read four at a time through vec4_t; any elements left over
// when num_elems is not a multiple of four are read one by one afterwards.
// The result is never below 0.0f, the initial value of the running maximum.
template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                int64_t const num_elems, int const tid,
                                int const step) {
  using packed_t = vec4_t<scalar_t>;

  // Reinterpret the input as groups of four to better utilize memory
  // bandwidth.
  packed_t const* packed_in = reinterpret_cast<packed_t const*>(input);

  // Number of complete four-element groups.
  int64_t const packed_count = num_elems >> 2;

  float running_max = 0.0f;

#pragma unroll 4
  for (int64_t idx = tid; idx < packed_count; idx += step) {
    packed_t quad = packed_in[idx];
    running_max = max(running_max, fabs(quad.x));
    running_max = max(running_max, fabs(quad.y));
    running_max = max(running_max, fabs(quad.z));
    running_max = max(running_max, fabs(quad.w));
  }

  // Scalar tail: elements past the last complete group.
  int64_t const tail_begin = packed_count * 4;
  for (int64_t idx = tail_begin + tid; idx < num_elems; idx += step) {
    running_max = max(running_max, fabs(input[idx]));
  }

  return running_max;
}

141
// Quantizes the elements of `input` visited by this thread into `out`.
// The thread starts at index `tid` and advances by `step`.
//
// Elements are processed four at a time through vec4_t / float8x4_t; any
// elements left over when num_elems is not a multiple of four are converted
// one by one afterwards. The vector accesses go through types declared with
// 8-byte (input) and 4-byte (output) alignment, so the pointers are expected
// to satisfy those alignments.
//
// `tid` and `step` are 64-bit so that callers whose global thread index or
// total thread count exceeds the range of int do not get truncated values.
//
// Template parameters:
//   scalar_t          - input element type.
//   is_scale_inverted - forwarded to scaled_fp8_conversion; selects
//                       multiplication by `scale` (true) or division (false).
template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
                                          scalar_t const* __restrict__ input,
                                          float const scale,
                                          int64_t const num_elems,
                                          int64_t const tid,
                                          int64_t const step) {
  // Vectorized input/output to better utilize memory bandwidth.
  vec4_t<scalar_t> const* vectorized_in =
      reinterpret_cast<vec4_t<scalar_t> const*>(input);
  float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);

  // Number of complete four-element groups.
  int64_t const num_vec_elems = num_elems >> 2;

#pragma unroll 4
  for (int64_t i = tid; i < num_vec_elems; i += step) {
    vec4_t<scalar_t> in_vec = vectorized_in[i];
    float8x4_t out_vec;

    out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.x), scale);
    out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.y), scale);
    out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.z), scale);
    out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.w), scale);
    vectorized_out[i] = out_vec;
  }

  // Handle the remaining elements if num_elems is not divisible by 4
  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
    out[i] = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(input[i]), scale);
  }
}

// Quantizes `num_elems` values of `input` into `out` using the single scale
// stored at `*scale`. Work is distributed over all threads of a 1-D launch:
// each thread starts at its global index and advances by the total number of
// launched threads.
//
// Parameters:
//   out       - destination buffer of FP8 values.
//   input     - source values.
//   scale     - pointer to one float; its reciprocal is applied to every
//               element.
//   num_elems - number of elements in `input` / `out`.
template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
                                        const scalar_t* __restrict__ input,
                                        const float* __restrict__ scale,
                                        int64_t num_elems) {
  // Widen to 64 bits *before* multiplying. The built-in index variables are
  // 32-bit unsigned; computing these products in 32 bits and storing them in
  // an int overflows once gridDim.x * blockDim.x reaches 2^31.
  int64_t const tid =
      static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  int64_t const step = static_cast<int64_t>(blockDim.x) * gridDim.x;

  // Invert the scale so that we can use multiplications to avoid expensive
  // division.
  const float inverted_scale = 1.0f / (*scale);
  scaled_fp8_conversion_vec<scalar_t, true>(out, input, inverted_scale,
                                            num_elems, tid, step);
}

// Per-token dynamic FP8 quantization. Each thread block handles one token
// (row) selected by blockIdx.x: it finds the row's absolute maximum, derives
// a scale from it, stores that scale in scale[blockIdx.x], and writes the
// quantized row to `out`.
//
// Parameters:
//   out         - destination buffer; row r starts at out[r * hidden_size].
//   scale       - receives one scale per token.
//   input       - source buffer with the same row layout as `out`.
//   scale_ub    - optional (may be null) pointer to an upper bound that is
//                 applied to the row's absolute maximum before scaling.
//   hidden_size - number of elements per token.
template <typename scalar_t>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
    FP8_TYPE* __restrict__ out, float* __restrict__ scale,
    scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
    const int hidden_size) {
  // Lower bound applied to the computed token scale.
  float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);

  int const tid = threadIdx.x;
  int const token_idx = blockIdx.x;

  // Compute the row offset in 64 bits: token_idx * hidden_size as an
  // int * int product overflows once the tensor has more than 2^31 elements.
  int64_t const token_offset = static_cast<int64_t>(token_idx) * hidden_size;
  scalar_t const* __restrict__ token_input = &input[token_offset];
  FP8_TYPE* __restrict__ token_output = &out[token_offset];

  // For vectorization, token_input and token_output pointers need to be
  // aligned at 8-byte and 4-byte addresses respectively.
  bool const can_vectorize = hidden_size % 4 == 0;

  // Each thread first reduces the elements it owns.
  float absmax_val = 0.0f;
  if (can_vectorize) {
    absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
  } else {
    for (int i = tid; i < hidden_size; i += blockDim.x) {
      float const x = static_cast<float>(token_input[i]);
      absmax_val = max(absmax_val, fabs(x));
    }
  }

  // NOTE(review): blockReduceMax is defined outside this file; the variable
  // name suggests its result is only guaranteed on some threads, and only
  // thread 0 consumes it below - confirm against reduction_utils.cuh.
  float const block_absmax_val_maybe = blockReduceMax(absmax_val);
  __shared__ float token_scale;
  if (tid == 0) {
    if (scale_ub) {
      token_scale = min(block_absmax_val_maybe, *scale_ub);
    } else {
      token_scale = block_absmax_val_maybe;
    }
    // token scale computation
    token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
    scale[token_idx] = token_scale;
  }
  // Make thread 0's token_scale visible to the whole block.
  __syncthreads();

  // Note that we don't use inverted scales so we can match FBGemm impl.
  if (can_vectorize) {
    scaled_fp8_conversion_vec<scalar_t, false>(
        token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
  } else {
    for (int i = tid; i < hidden_size; i += blockDim.x) {
      token_output[i] = scaled_fp8_conversion<false>(
          static_cast<float>(token_input[i]), token_scale);
    }
  }
}

244
}  // namespace vllm
245

246
247
248
// Quantizes `input` to FP8 into `out` using the caller-provided scale.
// Launches scaled_fp8_quant_kernel on the current CUDA stream of input's
// device with one 1024-thread block per token (row of the last dimension).
//
// `input` and `out` must be contiguous: the kernel addresses both as flat
// arrays of input.numel() elements. An empty input is a no-op.
void static_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                             torch::Tensor const& input,  // [..., d]
                             torch::Tensor const& scale)  // [1]
{
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());

  int64_t num_elems = input.numel();
  // Nothing to do for empty tensors. Returning here also avoids dividing by
  // a zero-sized last dimension and launching a grid with zero blocks.
  if (num_elems == 0) {
    return;
  }

  int64_t num_tokens = num_elems / input.size(-1);
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
            scale.data_ptr<float>(), num_elems);
      });
}

264
265
266
// Quantizes `input` to FP8 into `out` with a single scale computed from the
// data. Two kernels are launched back to back on the current CUDA stream of
// input's device: segmented_max_reduction writes absmax / FP8_E4M3_MAX into
// `scale`, then scaled_fp8_quant_kernel quantizes with that scale.
//
// `scale` is updated with an atomic max, so the caller must initialize it to
// a value <= 0.0 before calling. `input` and `out` must be contiguous: the
// kernels address both as flat arrays of input.numel() elements. An empty
// input is a no-op and leaves `scale` untouched.
void dynamic_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                              torch::Tensor const& input,  // [..., d]
                              torch::Tensor& scale)        // [1]
{
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());

  int64_t num_elems = input.numel();
  // Nothing to do for empty tensors. Returning here also avoids dividing by
  // a zero-sized last dimension and launching a grid with zero blocks.
  if (num_elems == 0) {
    return;
  }

  int64_t num_tokens = num_elems / input.size(-1);
  dim3 grid(num_tokens);
  // 1024 threads satisfies segmented_max_reduction's shared buffer size and
  // its power-of-two reduction tree.
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
            scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
            scale.data_ptr<float>(), num_elems);
      });
}
283

284
285
286
287
// Quantizes `input` to FP8 into `out` with one dynamically computed scale
// per token (row of the last dimension); the scales are written to `scales`.
// Launches dynamic_per_token_scaled_fp8_quant_kernel on the current CUDA
// stream of input's device with one block per token and
// min(hidden_size, 1024) threads per block.
//
// `scale_ub`, when present, is passed to the kernel as an upper bound on the
// per-token absolute maximum. `input` and `out` must be contiguous. An empty
// input is a no-op.
void dynamic_per_token_scaled_fp8_quant(
    torch::Tensor& out,          // [..., d]
    torch::Tensor const& input,  // [..., d]
    torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());

  // Nothing to do for empty tensors. Returning here also avoids dividing by
  // a zero hidden_size and launching with a zero-sized grid or block.
  if (input.numel() == 0) {
    return;
  }

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
        vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(
                out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
                input.data_ptr<scalar_t>(),
                scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
                hidden_size);
      });
}