// per_token_group_quant_fp8.cu
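// Per-token-group FP8 quantization: every contiguous span of `group_size`
// input values is scaled by its own factor y_s = absmax / fp8_max and cast to
// e4m3 FP8, with the per-group scales written to a separate tensor.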
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>

#include <cmath>

#include "utils.h"

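// e4m3fn: 4 exponent bits, 3 mantissa bits, finite-only (max ±448, no inf).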
using FP8_TYPE = c10::Float8_e4m3fn;

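// Tree-reduces the 16 absmax candidates of one group down to smem[0].
// The caller invokes this from the lower 8 threads of the group, so the first
// step needs no guard. All 16 threads of a group sit in the same warp, and
// `volatile` keeps the intermediate stores visible between the
// warp-synchronous steps.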
__device__ __forceinline__ float GroupReduceMax(volatile float* smem, const int tid) {
  smem[tid] = fmaxf(smem[tid], smem[tid + 8]);
  if (tid < 4) smem[tid] = fmaxf(smem[tid], smem[tid + 4]);
  if (tid < 2) smem[tid] = fmaxf(smem[tid], smem[tid + 2]);
  if (tid < 1) smem[tid] = fmaxf(smem[tid], smem[tid + 1]);
  return smem[0];
}

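// Quantization kernel: each block handles 16 consecutive groups, with one
// 16-thread sub-group of the 256-thread block assigned to each group.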
template <typename T>
__global__ void per_token_group_quant_fp8_kernel(
    const T* __restrict__ input,
    void* __restrict__ output_q,
    float* __restrict__ output_s,
    const int group_size,
    const int num_groups,
    const float eps,
    const float fp8_min,
    const float fp8_max) {
  const int groups_per_block = 16;
  const int block_group_id = blockIdx.x * groups_per_block;

  const int tid = threadIdx.x;
  const int local_group_id = tid / 16;
  const int local_tid = tid % 16;

  // Per-group absmax staging; the extra column pads the row stride to 17
  // floats so the 16 rows fall into different shared-memory banks.
  __shared__ float s_absmax[16][17];

  // Seed with eps so the group scale can never be zero.
  float local_absmax = eps;

  const int group_id = block_group_id + local_group_id;
  const T* group_input = input + group_id * group_size;
  FP8_TYPE* group_output = static_cast<FP8_TYPE*>(output_q) + group_id * group_size;
  float* scale_output = output_s + group_id;

  // The last block may contain out-of-range groups. The range check guards
  // only the global-memory accesses, so every thread in the block still
  // reaches the __syncthreads() barriers below (a barrier inside a divergent
  // branch is undefined behavior).
  const bool group_in_range = group_id < num_groups;

  // Pass 1: strided scan over the group for the local absolute maximum.
  if (group_in_range) {
    for (int i = local_tid; i < group_size; i += 16) {
      float val = static_cast<float>(group_input[i]);
      local_absmax = fmaxf(local_absmax, fabsf(val));
    }
  }

  s_absmax[local_group_id][local_tid] = local_absmax;
  __syncthreads();

  // Fold the 16 partial maxima down to s_absmax[local_group_id][0].
  if (local_tid < 8) {
    GroupReduceMax(&s_absmax[local_group_id][0], local_tid);
  }
  __syncthreads();

  // The scale maps the group's absmax onto the largest representable FP8
  // value; a dequantizer multiplies by y_s to recover the original range.
  const float group_absmax = s_absmax[local_group_id][0];
  const float y_s = group_absmax / fp8_max;

  if (group_in_range) {
    if (local_tid == 0) {
      *scale_output = y_s;
    }

    // Pass 2: scale, clamp to the representable FP8 range, and store as e4m3.
    for (int i = local_tid; i < group_size; i += 16) {
      float val = static_cast<float>(group_input[i]);
      float q_val = fminf(fmaxf(val / y_s, fp8_min), fp8_max);
      group_output[i] = FP8_TYPE(q_val);
    }
  }
}

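// Host-side launcher: validates the tensors, derives the launch
// configuration, and dispatches the kernel on the input dtype.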
void sgl_per_token_group_quant_fp8(
    torch::Tensor input,
    torch::Tensor output_q,
    torch::Tensor output_s,
    int64_t group_size,
    double eps,
    double fp8_min,
    double fp8_max) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);
  CHECK_INPUT(output_s);

  CHECK_EQ(input.numel() % group_size, 0);

  const int num_groups = input.numel() / group_size;

  // One 256-thread block covers 16 groups (16 threads per group).
  dim3 grid((num_groups + 15) / 16);
  dim3 block(256);

  // Launch on the current PyTorch CUDA stream so the kernel stays ordered
  // with surrounding torch ops.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    per_token_group_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        output_q.data_ptr(),
        static_cast<float*>(output_s.data_ptr()),
        group_size,
        num_groups,
        static_cast<float>(eps),
        static_cast<float>(fp8_min),
        static_cast<float>(fp8_max));
    return true;
  });
}
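
// Illustrative usage (a sketch, not part of this file): quantizing a
// [num_tokens, hidden_dim] half-precision activation with group_size 128.
// The shapes, eps value, and dtype names below are assumptions.
//
//   torch::Tensor input = torch::randn(
//       {16, 1024}, torch::dtype(torch::kHalf).device(torch::kCUDA));
//   torch::Tensor output_q = torch::empty(
//       {16, 1024}, input.options().dtype(torch::kFloat8_e4m3fn));
//   torch::Tensor output_s = torch::empty(
//       {16, 1024 / 128}, input.options().dtype(torch::kFloat32));
//   // 448 is the largest finite value of e4m3fn.
//   sgl_per_token_group_quant_fp8(input, output_q, output_s,
//                                 /*group_size=*/128, /*eps=*/1e-10,
//                                 /*fp8_min=*/-448.0, /*fp8_max=*/448.0);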