common.cu 6.47 KB
Newer Older
1
#include "common.cuh"
2
3
#include "dispatch_utils.h"

4
5
#include <c10/cuda/CUDAGuard.h>

6
7
8
9
10
#ifndef USE_ROCM
  #include <cub/cub.cuh>
#else
  #include <hipcub/hipcub.hpp>
#endif
11

12
13
namespace vllm {

// Static (per-tensor) FP8 quantization kernel.
//
// The scale is inverted once up front so the per-element conversion can use
// a multiplication instead of a division; the `true` template argument passed
// to scaled_fp8_conversion_vec signals that the scale it receives is already
// inverted. Each thread passes its global index and the total number of
// launched threads to the helper, which presumably strides over all
// num_elems elements with that step (TODO confirm in common.cuh).
//
// out:       FP8 destination, num_elems elements.
// input:     source values, num_elems elements.
// scale:     device pointer; only scale[0] is read.
// num_elems: number of elements to convert.
template <typename scalar_t, typename fp8_type>
__global__ void scaled_fp8_quant_kernel(fp8_type* __restrict__ out,
                                        const scalar_t* __restrict__ input,
                                        const float* __restrict__ scale,
                                        int64_t num_elems) {
  float const reciprocal_scale = 1.0f / scale[0];

  int const global_thread = threadIdx.x + blockIdx.x * blockDim.x;
  auto const thread_count = gridDim.x * blockDim.x;

  scaled_fp8_conversion_vec<scalar_t, true>(out, input, reciprocal_scale,
                                            num_elems, global_thread,
                                            thread_count);
}

// Dynamic per-token FP8 quantization kernel: one thread block per token row.
//
// Launch contract:
//   * gridDim.x is the number of tokens (blockIdx.x selects the row).
//   * blockDim.x must not exceed 256, the thread count cub::BlockReduce is
//     instantiated with below. The reduction is told that blockDim.x threads
//     hold valid values.
//   * Static shared memory: BlockReduce scratch storage plus one float used
//     to broadcast the row's scale to every thread.
//
// For row t the kernel:
//   1. reduces absmax = max_i |input[t * hidden_size + i]| across the block,
//   2. clamps absmax to *scale_ub when scale_ub is non-null,
//   3. writes scale[t] = max(absmax / quant_type_max_v<fp8_type>,
//                            min_scaling_factor<fp8_type>::val()),
//   4. converts the row to FP8 with that scale. The helpers are instantiated
//      with `false` (non-inverted scale), which the original authors chose
//      to match the FBGemm implementation.
//
// When hidden_size is a multiple of 16, the vectorized helpers
// (thread_max_vec / scaled_fp8_conversion_vec) are used. They expect the row
// pointers to be 32-byte (input) and 16-byte (output) aligned. That holds
// whenever the base pointers are suitably aligned (assumed; not checked here).
//
// out:         FP8 output, hidden_size elements per token.
// scale:       per-token scales, one float per token.
// input:       source rows, hidden_size elements per token, rows back to back.
// scale_ub:    optional device pointer to an absmax upper bound; may be null.
// hidden_size: elements per token.
template <typename scalar_t, typename fp8_type>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
    fp8_type* __restrict__ out, float* __restrict__ scale,
    scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
    const int hidden_size) {
  int const row = blockIdx.x;
  int const thread = threadIdx.x;

  // 64-bit row start: row * hidden_size can overflow int32 for large inputs.
  int64_t const row_start = static_cast<int64_t>(row) * hidden_size;
  scalar_t const* __restrict__ row_in = input + row_start;
  fp8_type* __restrict__ row_out = out + row_start;

  bool const use_vectorized = (hidden_size % 16) == 0;

  // ---- Phase 1: this thread's partial |x| maximum over its slice of the row.
  float partial_absmax = 0.0f;
  if (use_vectorized) {
    partial_absmax = thread_max_vec(row_in, hidden_size, thread, blockDim.x);
  } else {
    for (int col = thread; col < hidden_size; col += blockDim.x) {
      partial_absmax =
          fmaxf(partial_absmax, fabsf(static_cast<float>(row_in[col])));
    }
  }

  // ---- Phase 2: block-wide max, then thread 0 derives and publishes the scale.
  using BlockReduce = cub::BlockReduce<float, 256>;
  __shared__ typename BlockReduce::TempStorage reduce_scratch;
  __shared__ float shared_row_scale;

  float const row_absmax =
      BlockReduce(reduce_scratch).Reduce(partial_absmax, cub::Max{}, blockDim.x);

  if (thread == 0) {
    float const bounded_absmax =
        (scale_ub != nullptr) ? fminf(row_absmax, *scale_ub) : row_absmax;
    float const row_scale_value =
        fmaxf(bounded_absmax / quant_type_max_v<fp8_type>,
              min_scaling_factor<fp8_type>::val());
    shared_row_scale = row_scale_value;
    scale[row] = row_scale_value;
  }
  __syncthreads();

  // ---- Phase 3: quantize the row with the broadcast (non-inverted) scale.
  float const row_scale = shared_row_scale;
  if (use_vectorized) {
    scaled_fp8_conversion_vec<scalar_t, false>(row_out, row_in, row_scale,
                                               hidden_size, thread, blockDim.x);
  } else {
    for (int col = thread; col < hidden_size; col += blockDim.x) {
      row_out[col] = scaled_fp8_conversion<false, fp8_type>(
          static_cast<float>(row_in[col]), row_scale);
    }
  }
}

}  // namespace vllm
86

87
88
89
// Quantizes `input` into the FP8 tensor `out` using one precomputed scale.
//
// out:   FP8 destination; the kernel writes input.numel() elements into it,
//        indexed as a flat buffer.
// input: floating-point source [..., d], read as a flat buffer.
// scale: float32 tensor; only its first element is read by the kernel.
//
// Both tensors must be contiguous because the kernel indexes them as flat
// arrays. An empty input is a no-op: launching a zero-sized grid is a CUDA
// configuration error, and input.size(-1) == 0 would otherwise divide by zero.
// The kernel is launched asynchronously on the current CUDA stream.
void static_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                             torch::Tensor const& input,  // [..., d]
                             torch::Tensor const& scale)  // [1]
{
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());

  // Keep the element count in 64 bits: numel() can exceed INT32_MAX and the
  // kernel's num_elems parameter is int64_t.
  int64_t const num_elems = input.numel();
  if (num_elems == 0) {
    return;
  }

  int const block_size = 256;
  int const num_tokens = num_elems / input.size(-1);
  dim3 const grid(num_tokens);
  dim3 const block(block_size);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
              vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
                  <<<grid, block, 0, stream>>>(
                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
                      scale.data_ptr<float>(), num_elems);
            });
      });
}

110
111
112
// Quantizes `input` into the FP8 tensor `out` with a single scale computed
// from the data.
//
// Two kernels are launched back to back on the current CUDA stream, so the
// second observes the first's result:
//   1. vllm::segmented_max_reduction writes the tensor-wide scale into
//      `scale`. NOTE(review): it is defined in common.cuh. It presumably
//      combines per-block maxima with an atomic max, which would require the
//      caller to zero-initialise `scale`. TODO confirm.
//   2. vllm::scaled_fp8_quant_kernel quantizes every element with that scale.
//
// out:   FP8 destination; input.numel() elements are written as a flat buffer.
// input: floating-point source [..., d], read as a flat buffer.
// scale: float32 output tensor holding the computed scale.
//
// Both tensors must be contiguous because they are indexed as flat arrays. An
// empty input is a no-op: `scale` is left untouched, and no zero-sized grid is
// launched and no division by input.size(-1) == 0 happens.
void dynamic_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                              torch::Tensor const& input,  // [..., d]
                              torch::Tensor& scale)        // [1]
{
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());

  // 64-bit element count: numel() can exceed INT32_MAX, and both kernels
  // take their element count as int64_t.
  int64_t const num_elems = input.numel();
  if (num_elems == 0) {
    return;
  }

  int const block_size = 256;
  int const num_tokens = num_elems / input.size(-1);
  dim3 const grid(num_tokens);
  dim3 const block(block_size);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
              vllm::segmented_max_reduction<scalar_t, fp8_t>
                  <<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
                                               input.data_ptr<scalar_t>(),
                                               num_elems);
              vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
                  <<<grid, block, 0, stream>>>(
                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
                      scale.data_ptr<float>(), num_elems);
            });
      });
}
136

137
138
139
140
// Quantizes each token row of `input` to FP8 with its own data-derived scale.
//
// out:      FP8 destination [..., d], written row by row.
// input:    floating-point source [..., d]; its last dimension is the row.
// scales:   float32 output; scales[t] receives the scale of token row t.
// scale_ub: optional float32 tensor whose first element upper-bounds each
//           row's absmax before the scale is derived.
//
// Launch layout: one thread block per token row, with
// min(hidden_size, 256) threads. The block size must never exceed 256,
// because the kernel's cub::BlockReduce is instantiated for 256 threads.
// An empty input is a no-op and `scales` is left untouched. This also avoids
// dividing by hidden_size == 0 and launching a zero-sized grid/block, both of
// which would otherwise fail.
void dynamic_per_token_scaled_fp8_quant(
    torch::Tensor& out,          // [..., d]
    torch::Tensor const& input,  // [..., d]
    torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());

  if (input.numel() == 0) {
    return;
  }

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  int const block_size = 256;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, block_size));

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(),
      "dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] {
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(),
            "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
              vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
                  <<<grid, block, 0, stream>>>(
                      out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
                      input.data_ptr<scalar_t>(),
                      scale_ub.has_value() ? scale_ub->data_ptr<float>()
                                           : nullptr,
                      hidden_size);
            });
      });
}