"tests/vscode:/vscode.git/clone" did not exist on "5458eb835d66323a11d4a252ad551d001ce00ac8"
common.cuh 6.5 KB
Newer Older
1
2
#pragma once

#include "quantization/vectorization.cuh"

#include <cmath>
#include <limits>  // std::numeric_limits (fp8_e4m3_adjusted_max<Float8_e4m3fn>)
#include <string>  // std::string (is_fp8_ocp on ROCm)
#include <c10/core/ScalarType.h>

#ifndef USE_ROCM
  #include <c10/util/Float8_e4m3fn.h>
  // fp8_e4m3_adjusted_max<c10::Float8_e4m3fnuz> is specialized on every
  // platform, so the fnuz type must be visible on CUDA as well.
  #include <c10/util/Float8_e4m3fnuz.h>
  // Applied to the fp8_e4m3_adjusted_max_v variable template, which is read
  // from device code.
  #define MAYBE_HOST_DEVICE C10_HOST_DEVICE
#else
  #include <ATen/hip/HIPContext.h>
  #include <c10/util/Float8_e4m3fn.h>
  #include <c10/util/Float8_e4m3fnuz.h>
  #include "amd/quant_utils.cuh"
  // ROCm doesn't seem to need C10_HOST_DEVICE for static constexpr
  #define MAYBE_HOST_DEVICE
#endif

// Determines the preferred FP8 type for the current platform.
// Returns true when the OCP e4m3fn format should be used.
//   - CUDA: always returns true.
//   - ROCm: queries the current device's properties and returns false when
//     its gcnArchName contains "gfx94" (presumably because those parts use
//     the e4m3fnuz variant -- NOTE(review): confirm), true otherwise.
// Host-side helper: it calls at::cuda::getCurrentDeviceProperties().
static bool is_fp8_ocp() {
#ifndef USE_ROCM
  return true;
#else
  auto dprops = at::cuda::getCurrentDeviceProperties();
  std::string device_arch = dprops->gcnArchName;
  size_t substring = device_arch.find("gfx94");
  return substring == std::string::npos;
#endif
}

// Clamp bound used when scaling values into an FP8 e4m3 type.
// Only the specializations below are defined; the primary template is merely
// declared, so instantiating it with any other type is a compile error.
template <typename T>
struct fp8_e4m3_adjusted_max;

// OCP e4m3fn: use the type's std::numeric_limits max unchanged.
template <>
struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fn> {
  static constexpr c10::Float8_e4m3fn val() {
    return std::numeric_limits<c10::Float8_e4m3fn>::max();
  }
};

// Using the default max value from PyTorch (240.0, bits 0x7F) will cause
// accuracy issues when running dynamic quantization, so use 224.0
// (bits 0x7E) for e4m3fnuz on ROCm instead.
template <>
struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fnuz> {
  static constexpr c10::Float8_e4m3fnuz val() {
    return c10::Float8_e4m3fnuz(0x7E, c10::Float8_e4m3fnuz::from_bits());
  }
};

// Shorthand: fp8_e4m3_adjusted_max_v<T> is fp8_e4m3_adjusted_max<T>::val().
// MAYBE_HOST_DEVICE lets the kernels in this header read it on the device.
template <typename T>
MAYBE_HOST_DEVICE static constexpr T fp8_e4m3_adjusted_max_v =
    fp8_e4m3_adjusted_max<T>::val();

namespace vllm {

// Atomically performs *addr = max(*addr, value) for floats and returns the
// value previously stored at addr. Uses integer atomics on the bit pattern:
//   - value >= 0: non-negative floats order like their bits read as signed
//     ints (and any negative stored value is a negative int), so a signed
//     integer atomicMax does the job.
//   - value < 0: negative floats have the sign bit set, so read as unsigned
//     ints a larger magnitude is a larger integer, and any non-negative
//     stored value is a smaller integer; an unsigned atomicMin therefore
//     keeps the larger float.
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  if (value >= 0) {
    int const prev_bits =
        atomicMax(reinterpret_cast<int*>(addr), __float_as_int(value));
    return __int_as_float(prev_bits);
  }

  unsigned int const prev_bits = atomicMin(
      reinterpret_cast<unsigned int*>(addr), __float_as_uint(value));
  return __uint_as_float(prev_bits);
}

// Scales a single float and converts it to `fp8_type`.
//   - is_scale_inverted == true:  x = val * scale (scale is treated as the
//     reciprocal of the quantization scale).
//   - is_scale_inverted == false: x = val / scale.
// x is clamped to [-fp8_e4m3_adjusted_max_v<fp8_type>,
// fp8_e4m3_adjusted_max_v<fp8_type>] before conversion. On ROCm the
// conversion goes through fp8::cvt_c10; elsewhere it is a static_cast.
template <bool is_scale_inverted, typename fp8_type>
__device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
                                                          float const scale) {
  float x = 0.0f;
  if constexpr (is_scale_inverted) {
    x = val * scale;
  } else {
    x = val / scale;
  }

  float r = fmax(-fp8_e4m3_adjusted_max_v<fp8_type>,
                 fmin(x, fp8_e4m3_adjusted_max_v<fp8_type>));
#ifndef USE_ROCM
  return static_cast<fp8_type>(r);
#else
  // Use hardware cvt instruction for fp8 on rocm
  return fp8::cvt_c10<fp8_type>(r);
#endif
}

// Compute the absolute maximum m of the input tensor and store
// m / fp8_e4m3_adjusted_max_v<fp8_type> in *scale. Each thread block performs
// a reduction tree in shared memory and *scale is then updated atomically
// (via atomicMaxFloat). So to get the right answer, *scale needs to be
// initialized to a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
//
// Launch requirements (derived from the in-block reduction below):
//   - 1D launch with blockDim.x <= 1024 (the size of `cache`).
//   - blockDim.x must be a power of two; the halving reduction otherwise
//     skips some entries of `cache`.
// Any gridDim.x is valid: each thread walks the input with a stride equal to
// the total number of launched threads.
template <typename scalar_t, typename fp8_type>
__global__ void segmented_max_reduction(float* __restrict__ scale,
                                        const scalar_t* __restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[1024];
  // Widen before multiplying: blockDim/blockIdx/gridDim are 32-bit unsigned,
  // so their products would otherwise wrap for very large grids.
  int64_t i = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  int64_t const stride = static_cast<int64_t>(blockDim.x) * gridDim.x;

  // First store maximum for all values processed by
  // the current thread in cache[threadIdx.x]. Accumulate in float, the type
  // of `cache`, instead of round-tripping through scalar_t.
  float tmp = 0.0f;
  while (i < num_elems) {
    float x = static_cast<float>(input[i]);
    tmp = max(tmp, fabs(x));
    i += stride;
  }
  cache[threadIdx.x] = tmp;

  __syncthreads();

  // Now perform parallel reduction within the thread block
  int ib = blockDim.x / 2;
  while (ib != 0) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    __syncthreads();
    ib /= 2;
  }
  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale, cache[0] / fp8_e4m3_adjusted_max_v<fp8_type>);
  }
}

// Returns this thread's running maximum of |input[k]| over its share of the
// first num_elems elements.
//   - Main pass: the input is read as groups of 4 through vec4_t, starting at
//     group `tid` and advancing by `step` groups.
//   - Tail pass: the up-to-3 elements past the last full group are visited
//     one at a time, starting at element (4 * groups + tid) and advancing by
//     `step` elements.
// NOTE(review): reading through vec4_t presumably requires `input` to be
// aligned to sizeof(vec4_t<scalar_t>) -- confirm with callers.
template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                int64_t const num_elems, int const tid,
                                int const step) {
  auto const* in_vec4 = reinterpret_cast<vec4_t<scalar_t> const*>(input);
  int64_t const vec4_count = num_elems >> 2;
  int64_t const tail_begin = vec4_count * 4;

  float running_max = 0.0f;

#pragma unroll 4
  for (int64_t v = tid; v < vec4_count; v += step) {
    vec4_t<scalar_t> const chunk = in_vec4[v];
    running_max = max(running_max, fabs(chunk.x));
    running_max = max(running_max, fabs(chunk.y));
    running_max = max(running_max, fabs(chunk.z));
    running_max = max(running_max, fabs(chunk.w));
  }

  // Scalar tail for num_elems that is not a multiple of 4.
  for (int64_t k = tail_begin + tid; k < num_elems; k += step) {
    running_max = max(running_max, fabs(input[k]));
  }

  return running_max;
}

// Converts the first num_elems values of `input` into `out` with
// scaled_fp8_conversion<is_scale_inverted, fp8_type>(value, scale).
//   - Main pass: groups of 4 are read as vec4_t<scalar_t> and written as
//     q8x4_t<fp8_type>, starting at group `tid` and advancing by `step`
//     groups.
//   - Tail pass: the elements past the last full group (num_elems % 4) are
//     converted one at a time, starting at element (4 * groups + tid) and
//     advancing by `step` elements.
// NOTE(review): the vectorized accesses presumably require `input` and `out`
// to be aligned to their 4-wide vector types -- confirm with callers.
template <typename scalar_t, bool is_scale_inverted, typename fp8_type>
__device__ void scaled_fp8_conversion_vec(fp8_type* __restrict__ out,
                                          scalar_t const* __restrict__ input,
                                          float const scale,
                                          int64_t const num_elems,
                                          int const tid, int const step) {
  using float8x4_t = q8x4_t<fp8_type>;
  // Vectorized input/output to better utilize memory bandwidth.
  auto const* vectorized_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
  auto* vectorized_out = reinterpret_cast<float8x4_t*>(out);

  int64_t const num_vec_elems = num_elems >> 2;

#pragma unroll 4
  for (int64_t i = tid; i < num_vec_elems; i += step) {
    vec4_t<scalar_t> in_vec = vectorized_in[i];
    float8x4_t out_vec;

    out_vec.x = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(in_vec.x), scale);
    out_vec.y = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(in_vec.y), scale);
    out_vec.z = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(in_vec.z), scale);
    out_vec.w = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(in_vec.w), scale);
    vectorized_out[i] = out_vec;
  }

  // Handle the remaining elements if num_elems is not divisible by 4
  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
    out[i] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(input[i]), scale);
  }
}

}  // namespace vllm