// per_token_group_quant_8bit.cu
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>

#include <cmath>
#include <flashinfer/vec_dtypes.cuh>

#include "utils.h"

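// Max-reduction across the 16 lanes that cooperate on one quantization group,
// implemented as a warp-shuffle butterfly over offsets 8, 4, 2, 1.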
__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
  // Full-warp mask: all 32 lanes execute the shuffles, and the XOR offsets (<= 8)
  // never cross a 16-lane half, so the two groups sharing a warp reduce independently.
  unsigned mask = 0xffffffff;

  val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
  return val;
}

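// Per-token-group 8-bit quantization. Each group of `group_size` input elements
// gets one float scale (absmax / max_8bit); values are divided by the scale,
// clamped to [min_8bit, max_8bit], and stored as DST_DTYPE (int8 or fp8 e4m3).
// Each block handles `groups_per_block` groups, with 16 threads per group.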
template <typename T, typename DST_DTYPE>
__global__ void per_token_group_quant_8bit_kernel(
    const T* __restrict__ input,
    void* __restrict__ output_q,
    float* __restrict__ output_s,
    const int group_size,
    const int num_groups,
    const int groups_per_block,
    const float eps,
    const float min_8bit,
    const float max_8bit) {
  const int threads_per_group = 16;
  const int local_group_id = threadIdx.x / threads_per_group;
  const int lane_id = threadIdx.x % threads_per_group;

  const int block_group_id = blockIdx.x * groups_per_block;
  const int block_group_offset = (block_group_id + local_group_id) * group_size;

  float local_absmax = eps;

  const T* group_input = input + block_group_offset;
  DST_DTYPE* group_output = static_cast<DST_DTYPE*>(output_q) + block_group_offset;
  float* scale_output = output_s + (block_group_id + local_group_id);

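  // 16-byte vectorized loads (vec_size elements of T each); the loops below assume
  // group_size is a multiple of vec_size.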
  constexpr uint32_t vec_size = 16 / sizeof(T);
  using vec_t = flashinfer::vec_t<T, vec_size>;

  const int32_t num_vec_elems = group_size / vec_size;

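  // First pass: each lane accumulates the absolute maximum over its slice of the group.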
  for (int32_t i = lane_id; i < num_vec_elems; i += 16) {
    vec_t input_vec;
    input_vec.cast_load(group_input + i * vec_size);

#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      float val = static_cast<float>(input_vec[j]);
      float abs_val = fabsf(val);
      local_absmax = fmaxf(local_absmax, abs_val);
    }
  }

  local_absmax = GroupReduceMax(local_absmax, lane_id);

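  // Per-group scale: map the group's absolute maximum onto the largest representable value.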
  const float y_s = local_absmax / max_8bit;

  if (lane_id == 0) {
    *scale_output = y_s;
  }

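  // Second pass: reload the group, scale, clamp to the 8-bit range, and store quantized values.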
  for (int32_t i = lane_id; i < num_vec_elems; i += 16) {
    vec_t input_vec;
    input_vec.cast_load(group_input + i * vec_size);

#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      float val = static_cast<float>(input_vec[j]);
      float q_val = fminf(fmaxf(val / y_s, min_8bit), max_8bit);
      group_output[i * vec_size + j] = DST_DTYPE(q_val);
    }
  }
}

void sgl_per_token_group_quant_8bit(
    torch::Tensor input,
    torch::Tensor output_q,
    torch::Tensor output_s,
    int64_t group_size,
    double eps,
    double min_8bit,
    double max_8bit) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);
  CHECK_INPUT(output_s);

  const int num_groups = input.numel() / group_size;

  CHECK_EQ(input.numel() % group_size, 0);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  constexpr int THREADS_PER_GROUP = 16;

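  // Pack as many groups into one block as evenly divide num_groups, capped at 16
  // (i.e. at most 16 * 16 = 256 threads per block).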
  int groups_per_block = 1;

  if (num_groups % 16 == 0) {
    groups_per_block = 16;
  } else if (num_groups % 8 == 0) {
    groups_per_block = 8;
  } else if (num_groups % 4 == 0) {
    groups_per_block = 4;
  } else if (num_groups % 2 == 0) {
    groups_per_block = 2;
  }

  auto dst_type = output_q.scalar_type();
  const int num_blocks = num_groups / groups_per_block;
  const int num_threads = groups_per_block * THREADS_PER_GROUP;

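// Grid: num_blocks blocks of num_threads threads. The input dtype is dispatched via
// DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16; the output dtype selects int8 or fp8 e4m3.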
#define LAUNCH_KERNEL(T, DST_DTYPE)                                              \
  do {                                                                           \
    dim3 grid(num_blocks);                                                       \
    dim3 block(num_threads);                                                     \
    per_token_group_quant_8bit_kernel<T, DST_DTYPE><<<grid, block, 0, stream>>>( \
        static_cast<T*>(input.data_ptr()),                                       \
        output_q.data_ptr(),                                                     \
        static_cast<float*>(output_s.data_ptr()),                                \
        group_size,                                                              \
        num_groups,                                                              \
        groups_per_block,                                                        \
        (float)eps,                                                              \
        (float)min_8bit,                                                         \
        (float)max_8bit);                                                        \
  } while (0)

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    if (dst_type == at::ScalarType::Char) {
      LAUNCH_KERNEL(scalar_t, int8_t);
      return true;
    } else if (dst_type == at::ScalarType::Float8_e4m3fn) {
      LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
      return true;
    }
    return false;
  });

#undef LAUNCH_KERNEL
}

void sgl_per_token_group_quant_int8(
    torch::Tensor input,
    torch::Tensor output_q,
    torch::Tensor output_s,
    int64_t group_size,
    double eps,
    double int8_min,
    double int8_max) {
  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, int8_min, int8_max);
}

void sgl_per_token_group_quant_fp8(
    torch::Tensor input,
    torch::Tensor output_q,
    torch::Tensor output_s,
    int64_t group_size,
    double eps,
    double fp8_min,
    double fp8_max) {
  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max);
}
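
// Rough usage sketch from Python, assuming these entry points are registered as a
// torch extension elsewhere in the build (the binding name and tensor shapes below
// are an assumption for illustration, not defined in this file):
//
//   x = torch.randn(num_tokens, hidden, dtype=torch.float16, device="cuda")
//   q = torch.empty(x.shape, dtype=torch.float8_e4m3fn, device="cuda")
//   s = torch.empty(num_tokens, hidden // group_size, dtype=torch.float32, device="cuda")
//   sgl_per_token_group_quant_fp8(x, q, s, group_size, eps, fp8_min, fp8_max)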