#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>

#include <cmath>
#include <flashinfer/vec_dtypes.cuh>

#include "utils.h"

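// Butterfly max-reduction over a 16-lane group: XOR shuffles at offsets
// 8/4/2/1 leave every lane holding the group maximum. The XOR pattern never
// crosses a 16-lane boundary, so each half-warp reduces independently.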
__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
  unsigned mask = 0xffff;

  val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
  return val;
}

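// One 16-thread group quantizes one contiguous `group_size`-element slice of
// `input`: pass 1 computes the group absmax, lane 0 writes the scale
// (absmax / max_8bit) to `output_s`, and pass 2 stores the scaled, clamped
// values to `output_q`.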
template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false>
__global__ void per_token_group_quant_8bit_kernel(
    const T* __restrict__ input,
    void* __restrict__ output_q,
    float* __restrict__ output_s,
    const int group_size,
    const int num_groups,
    const int groups_per_block,
    const float eps,
    const float min_8bit,
    const float max_8bit,
    const int scale_num_rows = 0,
    const int scale_stride = 0) {
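  // Each quantization group is handled by a fixed team of 16 threads; one
  // block covers `groups_per_block` groups.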
  const int threads_per_group = 16;
  const int local_group_id = threadIdx.x / threads_per_group;
  const int lane_id = threadIdx.x % threads_per_group;

  const int block_group_id = blockIdx.x * groups_per_block;
  const int global_group_id = block_group_id + local_group_id;
  const int block_group_offset = global_group_id * group_size;

  float local_absmax = eps;

  const T* group_input = input + block_group_offset;
  DST_DTYPE* group_output = static_cast<DST_DTYPE*>(output_q) + block_group_offset;
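  // Resolve where this group's scale is written; a column-major `output_s`
  // swaps the row/column roles and strides by `scale_stride`.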
  float* scale_output;

  if constexpr (IS_COLUMN_MAJOR) {
    const int row_idx = global_group_id / scale_num_rows;
    const int col_idx = global_group_id % scale_num_rows;
    scale_output = output_s + (col_idx * scale_stride + row_idx);
  } else {
    scale_output = output_s + global_group_id;
  }

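  // Pass 1: vectorized absmax scan (16-byte loads via flashinfer::vec_t).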
  constexpr uint32_t vec_size = 16 / sizeof(T);
  using vec_t = flashinfer::vec_t<T, vec_size>;

  const int32_t num_vec_elems = group_size / vec_size;

  for (int32_t i = lane_id; i < num_vec_elems; i += 16) {
    vec_t input_vec;
    input_vec.cast_load(group_input + i * vec_size);
#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      float val = static_cast<float>(input_vec[j]);
      float abs_val = fabsf(val);
      local_absmax = fmaxf(local_absmax, abs_val);
    }
  }

  local_absmax = GroupReduceMax(local_absmax, lane_id);

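  // Scale maps the group's absmax onto the full representable 8-bit range.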
  const float y_s = local_absmax / max_8bit;

  if (lane_id == 0) {
    *scale_output = y_s;
  }

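  // Pass 2: reload, scale, clamp to [min_8bit, max_8bit], and store.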
  for (int32_t i = lane_id; i < num_vec_elems; i += 16) {
    vec_t input_vec;
    input_vec.cast_load(group_input + i * vec_size);
#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      float val = static_cast<float>(input_vec[j]);
      float q_val = fminf(fmaxf(val / y_s, min_8bit), max_8bit);
      group_output[i * vec_size + j] = DST_DTYPE(q_val);
    }
  }
}

void sgl_per_token_group_quant_8bit(
    torch::Tensor input,
    torch::Tensor output_q,
    torch::Tensor output_s,
    int64_t group_size,
    double eps,
    double min_8bit,
    double max_8bit) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);

  const int num_groups = input.numel() / group_size;

  CHECK_EQ(input.numel() % group_size, 0);
  CHECK_EQ(output_s.dim(), 2);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

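  // Launch shape: 16 threads per group, packing up to 16 groups (256 threads)
  // into a block when num_groups divides evenly.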
  constexpr int THREADS_PER_GROUP = 16;

  int groups_per_block = 1;

  if (num_groups % 16 == 0) {
    groups_per_block = 16;
  } else if (num_groups % 8 == 0) {
    groups_per_block = 8;
  } else if (num_groups % 4 == 0) {
    groups_per_block = 4;
  } else if (num_groups % 2 == 0) {
    groups_per_block = 2;
  }

  auto dst_type = output_q.scalar_type();
  const int num_blocks = num_groups / groups_per_block;
  const int num_threads = groups_per_block * THREADS_PER_GROUP;

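  // A column-major scale layout is detected from the strides, so the kernel
  // can write scales transposed when the caller's output_s is laid out that
  // way.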
  const bool is_column_major = output_s.stride(0) < output_s.stride(1);
  const int scale_num_rows = output_s.size(1);
  const int scale_stride = output_s.stride(1);

#define LAUNCH_KERNEL(T, DST_DTYPE)                                                       \
  do {                                                                                    \
    dim3 grid(num_blocks);                                                                \
    dim3 block(num_threads);                                                              \
    if (is_column_major) {                                                                \
      per_token_group_quant_8bit_kernel<T, DST_DTYPE, true><<<grid, block, 0, stream>>>(  \
          static_cast<T*>(input.data_ptr()),                                              \
          output_q.data_ptr(),                                                            \
          static_cast<float*>(output_s.data_ptr()),                                       \
          group_size,                                                                     \
          num_groups,                                                                     \
          groups_per_block,                                                               \
          (float)eps,                                                                     \
          (float)min_8bit,                                                                \
          (float)max_8bit,                                                                \
          scale_num_rows,                                                                 \
          scale_stride);                                                                  \
    } else {                                                                              \
      per_token_group_quant_8bit_kernel<T, DST_DTYPE, false><<<grid, block, 0, stream>>>( \
          static_cast<T*>(input.data_ptr()),                                              \
          output_q.data_ptr(),                                                            \
          static_cast<float*>(output_s.data_ptr()),                                       \
          group_size,                                                                     \
          num_groups,                                                                     \
          groups_per_block,                                                               \
          (float)eps,                                                                     \
          (float)min_8bit,                                                                \
          (float)max_8bit);                                                               \
    }                                                                                     \
  } while (0)

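  // Dispatch on the floating-point input dtype, then on the requested 8-bit
  // output dtype: int8 (Char) or fp8 e4m3.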
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    if (dst_type == at::ScalarType::Char) {
      LAUNCH_KERNEL(scalar_t, int8_t);
      return true;
    } else if (dst_type == at::ScalarType::Float8_e4m3fn) {
      LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
      return true;
    }
    return false;
  });

#undef LAUNCH_KERNEL
}

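// Thin wrappers that forward the numeric range of the target 8-bit format.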
void sgl_per_token_group_quant_int8(
    torch::Tensor input,
    torch::Tensor output_q,
    torch::Tensor output_s,
    int64_t group_size,
    double eps,
    double int8_min,
    double int8_max) {
  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, int8_min, int8_max);
}

void sgl_per_token_group_quant_fp8(
    torch::Tensor input,
    torch::Tensor output_q,
    torch::Tensor output_s,
    int64_t group_size,
    double eps,
    double fp8_min,
    double fp8_max) {
  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max);
}
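
// Example call (a sketch; eps and the allocation pattern are the caller's
// choice, and the fp8 e4m3 range of +/-448 is shown for illustration): for
// `input` of shape [num_tokens, hidden] with group_size dividing hidden,
//   auto output_q = torch::empty_like(input, torch::kFloat8_e4m3fn);
//   auto output_s = torch::empty({num_tokens, hidden / group_size},
//                                input.options().dtype(torch::kFloat));
//   sgl_per_token_group_quant_fp8(input, output_q, output_s, group_size,
//                                 /*eps=*/1e-10, /*fp8_min=*/-448.0, /*fp8_max=*/448.0);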