// per_token_group_quant_fp8.cu
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>

#include <cmath>
#include <flashinfer/vec_dtypes.cuh>

#include "utils.h"

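// c10::Float8_e4m3fn has a finite range of roughly [-448, 448]; the caller is
// expected to pass matching fp8_min/fp8_max bounds along with the tensors.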
using FP8_TYPE = c10::Float8_e4m3fn;

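// Butterfly max-reduction across the 16 lanes of one quantization group.
// After the xor-shuffles with offsets 8/4/2/1, every lane holds the group max.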
__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
  // Two groups share one warp (16 threads each); restrict the shuffle mask to
  // this thread's half of the warp so the two groups reduce independently.
  unsigned mask = 0xffffu << (threadIdx.x & 16);

  val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
  return val;
}

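// One thread block handles GROUPS_PER_BLOCK quantization groups; each group is
// processed by 16 threads. Every group first finds its absolute maximum, then
// derives a single FP8 scale and quantizes group_size input elements with it.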
template <typename T, int GROUPS_PER_BLOCK = 16>
__global__ void per_token_group_quant_fp8_kernel(
    const T* __restrict__ input,
    void* __restrict__ output_q,
    float* __restrict__ output_s,
    const int group_size,
    const int num_groups,
    const float eps,
    const float fp8_min,
    const float fp8_max) {
  const int threads_per_group = 16;
  const int local_group_id = threadIdx.x / threads_per_group;
  const int lane_id = threadIdx.x % threads_per_group;

  const int block_group_id = blockIdx.x * GROUPS_PER_BLOCK;
  const int block_group_offset = (block_group_id + local_group_id) * group_size;

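  // Seed with eps so the scale derived below can never be zero.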
  float local_absmax = eps;

  const T* group_input = input + block_group_offset;
  FP8_TYPE* group_output = static_cast<FP8_TYPE*>(output_q) + block_group_offset;
  float* scale_output = output_s + (block_group_id + local_group_id);

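  // 16-byte vectorized loads: vec_size is 8 for half/bfloat16, 4 for float.
  // group_size is assumed to be a multiple of vec_size (e.g. 128).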
  constexpr uint32_t vec_size = 16 / sizeof(T);
  using vec_t = flashinfer::vec_t<T, vec_size>;

  const int32_t num_vec_elems = group_size / vec_size;

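  // Pass 1: each lane scans a strided slice of the group for its local absmax.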
  for (int32_t i = lane_id; i < num_vec_elems; i += 16) {
    vec_t input_vec;
    input_vec.cast_load(group_input + i * vec_size);

#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      float val = static_cast<float>(input_vec[j]);
      float abs_val = fabsf(val);
      local_absmax = fmaxf(local_absmax, abs_val);
    }
  }

  local_absmax = GroupReduceMax(local_absmax, lane_id);

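  // Map the group's absmax onto the full FP8 range; every lane computes the
  // same scale because the butterfly reduction left the max in all 16 lanes.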
  const float y_s = local_absmax / fp8_max;

  if (lane_id == 0) {
    *scale_output = y_s;
  }

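  // Pass 2: reload the group and quantize each element with the shared scale.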
  for (int32_t i = lane_id; i < num_vec_elems; i += 16) {
    vec_t input_vec;
    input_vec.cast_load(group_input + i * vec_size);

#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      float val = static_cast<float>(input_vec[j]);
      float q_val = fminf(fmaxf(val / y_s, fp8_min), fp8_max);
      group_output[i * vec_size + j] = FP8_TYPE(q_val);
    }
  }
}

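// Host-side entry point: quantizes `input` in contiguous groups of
// `group_size` elements, writing FP8 values to `output_q` and one float scale
// per group to `output_s`.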
void sgl_per_token_group_quant_fp8(
    torch::Tensor input,
    torch::Tensor output_q,
    torch::Tensor output_s,
    int64_t group_size,
    double eps,
    double fp8_min,
    double fp8_max) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);
  CHECK_INPUT(output_s);

  CHECK_EQ(input.numel() % group_size, 0);
  const int num_groups = input.numel() / group_size;

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  constexpr int THREADS_PER_GROUP = 16;

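  // Pick the largest power-of-two group count (up to 16) that divides
  // num_groups, so every block launches fully populated with 16 * GPB threads.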
  int groups_per_block = 1;

  if (num_groups % 16 == 0) {
    groups_per_block = 16;
  } else if (num_groups % 8 == 0) {
    groups_per_block = 8;
  } else if (num_groups % 4 == 0) {
    groups_per_block = 4;
  } else if (num_groups % 2 == 0) {
    groups_per_block = 2;
  }

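// GROUPS_PER_BLOCK must be a compile-time constant to instantiate the kernel
// template, hence the macro plus the dispatch chain below.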
#define LAUNCH_KERNEL(T, GPB)                                                          \
  do {                                                                                 \
    constexpr int GROUPS_PER_BLOCK = GPB;                                              \
    dim3 grid((num_groups + GROUPS_PER_BLOCK - 1) / GROUPS_PER_BLOCK);                 \
    dim3 block(GROUPS_PER_BLOCK * THREADS_PER_GROUP);                                  \
    per_token_group_quant_fp8_kernel<T, GROUPS_PER_BLOCK><<<grid, block, 0, stream>>>( \
        static_cast<T*>(input.data_ptr()),                                             \
        output_q.data_ptr(),                                                           \
        static_cast<float*>(output_s.data_ptr()),                                      \
        group_size,                                                                    \
        num_groups,                                                                    \
        (float)eps,                                                                    \
        (float)fp8_min,                                                                \
        (float)fp8_max);                                                               \
  } while (0)

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    if (groups_per_block == 16) {
      LAUNCH_KERNEL(scalar_t, 16);
    } else if (groups_per_block == 8) {
      LAUNCH_KERNEL(scalar_t, 8);
    } else if (groups_per_block == 4) {
      LAUNCH_KERNEL(scalar_t, 4);
    } else if (groups_per_block == 2) {
      LAUNCH_KERNEL(scalar_t, 2);
    } else {
      LAUNCH_KERNEL(scalar_t, 1);
    }
    return true;
  });

#undef LAUNCH_KERNEL
}
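
// Usage sketch from Python (hypothetical binding name and shapes; the real
// entry point depends on how this op is registered with torch):
//
//   x = torch.randn(num_tokens, hidden, dtype=torch.float16, device="cuda")
//   q = torch.empty_like(x, dtype=torch.float8_e4m3fn)
//   s = torch.empty(x.numel() // group_size, dtype=torch.float32, device="cuda")
//   sgl_per_token_group_quant_fp8(x, q, s, group_size, 1e-10, -448.0, 448.0)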