// common.cuh: FP8 quantization device helpers (scaled conversion, abs-max
// reduction).
#pragma once

#include "quantization/vectorization.cuh"
#include "quantization/utils.cuh"

#include <cmath>
#include <cstdint>
#include <string>

#ifdef USE_ROCM
  #include "amd/quant_utils.cuh"
#endif

// Determines the preferred FP8 type for the current platform.
//
// Returns true when the OCP FP8 formats should be used:
//  - On CUDA builds (USE_ROCM undefined) this always returns true.
//  - On ROCm builds it reads the current device's gcnArchName and returns
//    false iff the name contains "gfx94".
//    NOTE(review): presumably gfx94* uses a non-OCP FP8 variant. Confirm
//    with the callers that branch on this result.
//
// Declared `static inline` rather than plain `static` so that translation
// units including this header without calling it do not emit
// -Wunused-function warnings.
static inline bool is_fp8_ocp() {
#ifndef USE_ROCM
  return true;
#else
  auto dprops = at::cuda::getCurrentDeviceProperties();
  std::string const device_arch = dprops->gcnArchName;
  return device_arch.find("gfx94") == std::string::npos;
#endif
}
namespace vllm {

// Atomically sets *addr = max(*addr, value) for floats and returns the value
// *addr held before the update.
//
// This works by reinterpreting float bit patterns as integers:
//  - value >= 0: non-negative IEEE-754 floats order the same way as their
//    bits read as signed ints, and every negative float reads as a negative
//    int. A signed atomicMax therefore yields the float maximum.
//  - value < 0: negative floats read as unsigned ints order in reverse
//    (larger magnitude => larger unsigned), and every non-negative float
//    reads as a smaller unsigned int than any negative one. An unsigned
//    atomicMin therefore yields the float maximum.
// NaN inputs are not specially handled.
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  float old;
  old = (value >= 0)
            ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
            : __uint_as_float(
                  atomicMin((unsigned int*)addr, __float_as_uint(value)));

  return old;
}

// Scales `val` and converts it to fp8_type with saturation.
//
// If is_scale_inverted is true, `scale` is treated as a reciprocal and `val`
// is multiplied by it. Otherwise `val` is divided by `scale`.
//
// The scaled value is clamped to [-quant_type_max_v<fp8_type>,
// quant_type_max_v<fp8_type>] before conversion, so out-of-range values
// saturate. fminf/fmaxf return the non-NaN operand, so a NaN input ends up
// clamped to +quant_type_max_v<fp8_type>.
template <bool is_scale_inverted, typename fp8_type>
__device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
                                                          float const scale) {
  float x = 0.0f;
  if constexpr (is_scale_inverted) {
    x = val * scale;
  } else {
    x = val / scale;
  }

  float r =
      fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
#ifndef USE_ROCM
  return static_cast<fp8_type>(r);
#else
  // Use hardware cvt instruction for fp8 on rocm
  return fp8::cvt_c10<fp8_type>(r);
#endif
}

// Computes the absolute maximum m of `input` (num_elems elements) and
// atomically folds m / quant_type_max_v<fp8_type> into *scale.
//
// Each thread strides over the input by the total number of launched
// threads. Each block then tree-reduces its per-thread maxima in shared
// memory, and thread 0 issues a single atomicMaxFloat. Because blocks combine
// through an atomic max:
//  - *scale must be initialized to a value <= 0.0 before launch, and
//  - *scale is only final once all thread blocks have finished.
//
// Launch preconditions (not checked here):
//  - blockDim.x <= 256, the size of the shared `cache` array.
//  - blockDim.x is a power of two; the reduction halves the active range
//    each step and would skip elements otherwise.
template <typename scalar_t, typename fp8_type>
__global__ void segmented_max_reduction(float* __restrict__ scale,
                                        const scalar_t* __restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[256];

  // Widen to 64 bits BEFORE multiplying. blockDim.x * blockIdx.x is a 32-bit
  // unsigned product and would wrap for inputs beyond 2^32 elements.
  int64_t const stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  int64_t i = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;

  // First store the maximum |x| over all values processed by the current
  // thread in cache[threadIdx.x]. Accumulate in float rather than scalar_t
  // to avoid a narrowing round-trip on every iteration.
  float tmp = 0.0f;
  for (; i < num_elems; i += stride) {
    float const x = static_cast<float>(input[i]);
    tmp = fmaxf(tmp, fabsf(x));
  }
  cache[threadIdx.x] = tmp;

  __syncthreads();

  // Now perform parallel (tree) reduction within the thread block.
  for (int ib = blockDim.x / 2; ib != 0; ib /= 2) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    __syncthreads();
  }

  // cache[0] now holds the maximum for this thread block; publish it
  // atomically to the target location.
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale, cache[0] / quant_type_max_v<fp8_type>);
  }
}

// Returns the maximum |x| over the elements of `input` assigned to this
// thread, or 0.0f if none are assigned.
//
// The thread first visits vector chunks tid, tid + step, tid + 2*step, ...
// of VEC_SIZE elements each. It then visits the scalar tail that does not
// fill a whole vector, using the same tid/step pattern.
//
// NOTE(review): `input` is reinterpreted as vec_n_t<scalar_t, VEC_SIZE>, so
// it presumably must be aligned for that type. Confirm against
// vectorization.cuh.
template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                int64_t const num_elems, int const tid,
                                int const step) {
  constexpr size_t VEC_SIZE = 16;
  using scalarxN_t = vec_n_t<scalar_t, VEC_SIZE>;
  // Vectorized input to better utilize memory bandwidth.
  auto const* vectorized_in = reinterpret_cast<scalarxN_t const*>(input);

  // Number of complete VEC_SIZE-element vectors. It is derived from VEC_SIZE
  // (instead of a hard-coded shift) so the two cannot drift apart.
  int64_t const vec_size = static_cast<int64_t>(VEC_SIZE);
  int64_t const num_vec_elems = num_elems / vec_size;

  float absmax_val = 0.0f;

#pragma unroll
  for (int64_t i = tid; i < num_vec_elems; i += step) {
    scalarxN_t in_vec = vectorized_in[i];
#pragma unroll
    for (size_t j = 0; j < VEC_SIZE; ++j) {
      absmax_val = fmaxf(absmax_val, fabsf(static_cast<float>(in_vec.val[j])));
    }
  }

  // Handle the remaining elements if num_elems is not divisible by VEC_SIZE.
  for (int64_t i = num_vec_elems * vec_size + tid; i < num_elems; i += step) {
    absmax_val = fmaxf(absmax_val, fabsf(static_cast<float>(input[i])));
  }

  return absmax_val;
}

// Converts `input` (num_elems elements) to fp8_type using
// scaled_fp8_conversion<is_scale_inverted> and writes the results to `out`.
//
// This thread handles vector chunks tid, tid + step, ... of VEC_SIZE
// elements. It then handles the scalar tail that does not fill a whole
// vector, using the same tid/step pattern.
//
// NOTE(review): `input` and `out` are reinterpreted as VEC_SIZE-wide vector
// types, so both presumably must be aligned for those types. Confirm against
// vectorization.cuh.
template <typename scalar_t, bool is_scale_inverted, typename fp8_type>
__device__ void scaled_fp8_conversion_vec(fp8_type* __restrict__ out,
                                          scalar_t const* __restrict__ input,
                                          float const scale,
                                          int64_t const num_elems,
                                          int const tid, int const step) {
  constexpr size_t VEC_SIZE = 16;
  using scalarxN_t = vec_n_t<scalar_t, VEC_SIZE>;
  using float8xN_t = q8_n_t<fp8_type, VEC_SIZE>;
  // Vectorized input/output to better utilize memory bandwidth.
  auto const* vectorized_in = reinterpret_cast<scalarxN_t const*>(input);
  auto* vectorized_out = reinterpret_cast<float8xN_t*>(out);

  // Number of complete VEC_SIZE-element vectors, derived from VEC_SIZE.
  int64_t const vec_size = static_cast<int64_t>(VEC_SIZE);
  int64_t const num_vec_elems = num_elems / vec_size;

#pragma unroll
  for (int64_t i = tid; i < num_vec_elems; i += step) {
    scalarxN_t in_vec = vectorized_in[i];
    float8xN_t out_vec;

#pragma unroll
    for (size_t j = 0; j < VEC_SIZE; ++j) {
      out_vec.val[j] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
          static_cast<float>(in_vec.val[j]), scale);
    }
    vectorized_out[i] = out_vec;
  }

  // Handle the remaining elements if num_elems is not divisible by VEC_SIZE.
  for (int64_t i = num_vec_elems * vec_size + tid; i < num_elems; i += step) {
    out[i] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(input[i]), scale);
  }
}

}  // namespace vllm