common.cu 10.4 KB
Newer Older
1
#include <ATen/cuda/CUDAContext.h>
2
#include <torch/all.h>
3
4
5
6
7
8
9
#include <c10/cuda/CUDAGuard.h>

#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"

10
11
#include "../../reduction_utils.cuh"

12
13
14
namespace vllm {

// Atomically raises the float stored at `addr` to `value` if `value` is
// larger, and returns the float that was stored at `addr` beforehand.
//
// The update operates on the raw bit pattern of the float: non-negative
// values are compared through a signed-integer atomicMax, negative values
// through an unsigned-integer atomicMin on the same bits.
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  if (value >= 0) {
    int const previous_bits =
        atomicMax(reinterpret_cast<int*>(addr), __float_as_int(value));
    return __int_as_float(previous_bits);
  }

  unsigned int const previous_bits = atomicMin(
      reinterpret_cast<unsigned int*>(addr), __float_as_uint(value));
  return __uint_as_float(previous_bits);
}

24
25
// Largest finite value representable in c10::Float8_e4m3fn; used as the
// saturation bound and as the divisor when deriving quantization scales.
#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()

26
// Applies `scale` to `val` and converts the result to FP8 (e4m3fn),
// saturating to [-FP8_E4M3_MAX, FP8_E4M3_MAX] before the cast.
//
// Template parameter:
//   is_scale_inverted - true:  `scale` is a reciprocal and is multiplied in.
//                       false: `val` is divided by `scale`.
template <bool is_scale_inverted>
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
    float const val, float const scale) {
  float const scaled = is_scale_inverted ? val * scale : val / scale;

  // Clamp first so out-of-range magnitudes saturate at +/-FP8_E4M3_MAX.
  float const bound = static_cast<float>(FP8_E4M3_MAX);
  float const clamped = fmax(-bound, fmin(scaled, bound));
  return static_cast<c10::Float8_e4m3fn>(clamped);
}

40
41
42
43
44
45
// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
//
// Launch requirements: 1-D grid and block. blockDim.x must be a power of two
// no larger than 1024 (the size of the shared scratch array); otherwise the
// halving reduction below would skip entries or index past the array.
//
//   scale     - single float that receives the running maximum (atomically).
//   input     - elements to scan.
//   num_elems - number of elements in `input`.
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
                                        const scalar_t* __restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[1024];

  // Widen before multiplying: blockDim/blockIdx/gridDim are 32-bit unsigned,
  // so their product can wrap for large grids if computed in 32 bits.
  int64_t const stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  int64_t i = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;

  // First store the maximum of all values processed by the current thread
  // in cache[threadIdx.x]. Accumulate in float so the running value is not
  // converted back and forth through scalar_t on every iteration.
  float thread_max = 0.0f;
  for (; i < num_elems; i += stride) {
    float const x = static_cast<float>(input[i]);
    thread_max = fmaxf(thread_max, fabsf(x));
  }
  cache[threadIdx.x] = thread_max;

  __syncthreads();

  // Now perform parallel reduction within the thread block. The barrier sits
  // outside the conditional so every thread of the block reaches it.
  for (int ib = blockDim.x / 2; ib != 0; ib /= 2) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    __syncthreads();
  }

  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location.
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale,
                   cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
  }
}

82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
// Four consecutive `scalar_t` values declared with 8-byte alignment, so that
// four elements can be moved with a single vector-typed access.
template <typename scalar_t>
struct __align__(8) vec4_t {
  scalar_t x, y, z, w;
};

// Four consecutive FP8 (e4m3fn) values declared with 4-byte alignment; the
// output-side counterpart of vec4_t.
struct __align__(4) float8x4_t {
  c10::Float8_e4m3fn x, y, z, w;
};

98
// Returns the largest absolute value among the elements of `input` visited by
// this thread. The first num_elems / 4 groups of four are read through a
// vec4_t view at indices tid, tid + step, ...; any leftover elements are then
// read one at a time with the same tid/step pattern.
//
//   input     - start of the element range; it is reinterpreted as
//               vec4_t<scalar_t>, so the caller must supply a pointer that
//               satisfies that type's alignment.
//   num_elems - number of scalar elements in the range.
//   tid       - first index this thread visits.
//   step      - distance between successive indices visited by this thread.
template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                int64_t const num_elems, int const tid,
                                int const step) {
  auto const* packed_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
  int64_t const num_packed = num_elems >> 2;

  float running_max = 0.0f;

#pragma unroll 4
  for (int64_t v = tid; v < num_packed; v += step) {
    vec4_t<scalar_t> const quad = packed_in[v];
    running_max = max(running_max, fabs(quad.x));
    running_max = max(running_max, fabs(quad.y));
    running_max = max(running_max, fabs(quad.z));
    running_max = max(running_max, fabs(quad.w));
  }

  // Scalar tail: elements left over when num_elems is not a multiple of 4.
  int64_t const tail_begin = num_packed * 4;
  for (int64_t k = tail_begin + tid; k < num_elems; k += step) {
    running_max = max(running_max, fabs(input[k]));
  }

  return running_max;
}

126
// Converts the elements of `input` visited by this thread to FP8 and writes
// them to the same positions in `out`. Groups of four are processed through
// vec4_t / float8x4_t views; leftover elements are converted one at a time.
//
//   out, input - element ranges of length num_elems. Both are reinterpreted
//                as 4-wide vector types, so the caller must supply pointers
//                that satisfy the alignment of those types.
//   scale      - forwarded to scaled_fp8_conversion<is_scale_inverted>.
//   tid, step  - first index this thread visits and the stride between visits.
template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
                                          scalar_t const* __restrict__ input,
                                          float const scale,
                                          int64_t const num_elems,
                                          int const tid, int const step) {
  auto const* packed_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
  auto* packed_out = reinterpret_cast<float8x4_t*>(out);

  int64_t const num_packed = num_elems >> 2;

#pragma unroll 4
  for (int64_t v = tid; v < num_packed; v += step) {
    vec4_t<scalar_t> const quad = packed_in[v];

    float8x4_t converted;
    converted.x = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(quad.x), scale);
    converted.y = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(quad.y), scale);
    converted.z = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(quad.z), scale);
    converted.w = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(quad.w), scale);

    packed_out[v] = converted;
  }

  // Scalar tail: elements left over when num_elems is not a multiple of 4.
  int64_t const tail_begin = num_packed * 4;
  for (int64_t k = tail_begin + tid; k < num_elems; k += step) {
    out[k] = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(input[k]), scale);
  }
}

162
163
164
165
166
167
168
169
170
171
// Quantizes `num_elems` elements of `input` to FP8 in `out` using the single
// scale value stored at `*scale`. All threads of a 1-D launch cooperate: each
// starts at its global thread index and advances by the total thread count.
//
// NOTE(review): the global thread index and the stride are held in `int`
// because that is what scaled_fp8_conversion_vec accepts; confirm the launch
// size keeps blockDim.x * gridDim.x within int range.
// NOTE(review): the vectorized helper reinterprets `input`/`out` as 4-wide
// vectors; presumably the tensor base pointers are suitably aligned - verify.
template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
                                        const scalar_t* __restrict__ input,
                                        const float* __restrict__ scale,
                                        int64_t num_elems) {
  int const global_tid = blockDim.x * blockIdx.x + threadIdx.x;
  int const total_threads = blockDim.x * gridDim.x;

  // Multiply by the reciprocal once computed here instead of dividing every
  // element by the scale.
  float const scale_reciprocal = 1.0f / (*scale);

  scaled_fp8_conversion_vec<scalar_t, true>(out, input, scale_reciprocal,
                                            num_elems, global_tid,
                                            total_threads);
}

// Per-token dynamic FP8 quantization. Block `blockIdx.x` handles one row of
// `hidden_size` elements: it finds the row's absolute maximum, derives a
// scale from it, stores that scale in scale[blockIdx.x] and writes the
// quantized row to `out`.
//
//   out         - quantized output, same layout as `input`.
//   scale       - one float per row; written by thread 0 of each block.
//   input       - rows of `hidden_size` elements laid out back to back.
//   scale_ub    - optional (may be null) upper bound applied to the row's
//                 absolute maximum before the scale is derived.
//   hidden_size - number of elements per row.
template <typename scalar_t>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
    c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
    scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
    const int hidden_size) {
  float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);

  int const tid = threadIdx.x;
  int const token_idx = blockIdx.x;

  // Compute the row offset in 64 bits: token_idx * hidden_size overflows a
  // 32-bit int once the tensor holds more than 2^31 - 1 elements.
  int64_t const token_offset = static_cast<int64_t>(token_idx) * hidden_size;
  scalar_t const* __restrict__ token_input = &input[token_offset];
  c10::Float8_e4m3fn* __restrict__ token_output = &out[token_offset];

  // For vectorization, token_input and token_output pointers need to be
  // aligned at 8-byte and 4-byte addresses respectively. With hidden_size a
  // multiple of 4 every row starts a whole number of 4-element groups after
  // the base pointer (assumes the base pointers themselves are aligned).
  bool const can_vectorize = hidden_size % 4 == 0;

  float absmax_val = 0.0f;
  if (can_vectorize) {
    absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
  } else {
    for (int i = tid; i < hidden_size; i += blockDim.x) {
      float const x = static_cast<float>(token_input[i]);
      absmax_val = max(absmax_val, fabs(x));
    }
  }

  // NOTE(review): blockReduceMax is defined outside this file; the "_maybe"
  // suffix suggests the reduced value is only meaningful on some threads.
  // Only thread 0's copy is consumed below - confirm against the helper.
  float const block_absmax_val_maybe = blockReduceMax(absmax_val);

  __shared__ float token_scale;
  if (tid == 0) {
    float row_scale = block_absmax_val_maybe;
    if (scale_ub) {
      row_scale = min(row_scale, *scale_ub);
    }
    // token scale computation; the lower bound keeps the scale non-zero so
    // the division during conversion stays finite.
    row_scale = max(row_scale / FP8_E4M3_MAX, min_scaling_factor);
    token_scale = row_scale;
    scale[token_idx] = row_scale;
  }
  // Publish token_scale to every thread of the block before it is read.
  __syncthreads();

  // Note that we don't use inverted scales so we can match FBGemm impl.
  if (can_vectorize) {
    scaled_fp8_conversion_vec<scalar_t, false>(
        token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
  } else {
    for (int i = tid; i < hidden_size; i += blockDim.x) {
      token_output[i] = scaled_fp8_conversion<false>(
          static_cast<float>(token_input[i]), token_scale);
    }
  }
}

229
}  // namespace vllm
230

231
232
233
// Quantizes `input` to FP8 into `out` using the caller-supplied scale held in
// `scale`. Launches one block of 1024 threads per row of the last dimension
// on the current CUDA stream of `input`'s device; the kernel strides over the
// flattened tensor, so every element is covered.
//
// Both tensors must be contiguous because the kernel indexes their storage
// linearly. An empty input is a no-op.
void static_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                             torch::Tensor const& input,  // [..., d]
                             torch::Tensor const& scale)  // [1]
{
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());

  int64_t const num_elems = input.numel();
  // Nothing to do for empty input. Returning here also avoids dividing by a
  // zero-sized last dimension and launching a kernel with a zero-sized grid.
  if (num_elems == 0) {
    return;
  }

  int64_t const num_tokens = num_elems / input.size(-1);
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
            out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
            scale.data_ptr<float>(), num_elems);
      });
}

249
250
251
// Quantizes `input` to FP8 into `out` with a single scale computed from the
// data. Two kernels are queued on the same stream: the first reduces the
// absolute maximum of `input` into `scale`, the second quantizes with it.
//
// `scale` must hold a value <= 0.0 on entry, because the reduction only ever
// raises it via an atomic max. Both tensors must be contiguous because the
// kernels index their storage linearly. An empty input is a no-op and leaves
// `scale` untouched.
void dynamic_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                              torch::Tensor const& input,  // [..., d]
                              torch::Tensor& scale)        // [1]
{
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());

  int64_t const num_elems = input.numel();
  // Nothing to do for empty input. Returning here also avoids dividing by a
  // zero-sized last dimension and launching kernels with a zero-sized grid.
  if (num_elems == 0) {
    return;
  }

  int64_t const num_tokens = num_elems / input.size(-1);
  dim3 grid(num_tokens);
  // 1024 threads per block: segmented_max_reduction needs a power-of-two
  // block size no larger than its 1024-entry shared scratch array.
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
            scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
            out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
            scale.data_ptr<float>(), num_elems);
      });
}
268

269
270
271
272
// Quantizes `input` to FP8 into `out` with one dynamically computed scale per
// row of the last dimension; the scales are written to `scales`. One block
// is launched per row with min(hidden_size, 1024) threads, on the current
// CUDA stream of `input`'s device.
//
//   scale_ub - optional tensor whose first float caps each row's absolute
//              maximum before the scale is derived.
//
// `input` and `out` must be contiguous. An empty input is a no-op.
void dynamic_per_token_scaled_fp8_quant(
    torch::Tensor& out,          // [..., d]
    torch::Tensor const& input,  // [..., d]
    torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());

  // Nothing to do for empty input. Returning here also avoids dividing by a
  // zero hidden_size and launching with a zero-sized grid or block.
  if (input.numel() == 0) {
    return;
  }

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
        vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(
                out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
                input.data_ptr<scalar_t>(),
                scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
                hidden_size);
      });
}