// per_tensor_quant_fp8.cu
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>

#include <cmath>
#include <cub/block/block_reduce.cuh>
#include <flashinfer/vec_dtypes.cuh>

#include "utils.h"

// Grid-stride kernel: each thread accumulates the absolute maximum of its
// elements, the block reduces that value, and thread 0 folds the per-block
// result into the global per-tensor scale (absmax / FP8_E4M3_MAX) with an
// atomic max.
template <typename T>
__global__ void
per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s, const int64_t num_elements) {
  float max_value = 0.0f;
  unsigned int tid = threadIdx.x;
  unsigned int gid = blockIdx.x * blockDim.x + threadIdx.x;
  const int grid_size = blockDim.x * gridDim.x;

  constexpr uint32_t vec_size = 16 / sizeof(T);
  using vec_t = flashinfer::vec_t<T, vec_size>;

  const int32_t num_vec_elems = num_elements / vec_size;

  for (int32_t i = gid; i < num_vec_elems; i += grid_size) {
    vec_t input_vec;
    input_vec.cast_load(input + i * vec_size);

#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      float val = static_cast<float>(input_vec[j]);
      max_value = fmaxf(max_value, fabsf(val));
    }
  }

  const int32_t remaining_start = num_vec_elems * vec_size;
  for (int32_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) {
    float val = static_cast<float>(input[idx]);
    max_value = fmaxf(max_value, fabsf(val));
  }

  // Reduce the per-thread maxima to a single per-block maximum.
  max_value = blockReduceMax(max_value);

  if (tid == 0) {
    atomicMaxFloat(output_s, max_value / FP8_E4M3_MAX);
  }
}
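
// blockReduceMax and atomicMaxFloat are provided by utils.h and are not shown
// in this file. The helpers below are hypothetical reference sketches (note the
// distinct names) of how such utilities are commonly written; the real
// implementations may differ. blockReduceMax typically runs a warp reduction
// like this one, stages one value per warp in shared memory, and reduces again
// in the first warp.
__device__ __forceinline__ float warp_reduce_max_sketch(float v) {
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    // After 5 xor-shuffle steps every lane holds the warp-wide maximum.
    v = fmaxf(v, __shfl_xor_sync(0xffffffffu, v, offset));
  }
  return v;
}

// Atomic float max for non-negative values: the IEEE-754 bit pattern of
// non-negative floats is monotonically increasing, so an integer atomicMax on
// the reinterpreted bits behaves like a float max. This also implies the scale
// buffer must hold a non-negative value (e.g. zero) before launch.
__device__ __forceinline__ float atomic_max_float_sketch(float* addr, float value) {
  return __int_as_float(atomicMax(reinterpret_cast<int*>(addr), __float_as_int(value)));
}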

// Grid-stride kernel: scale each element by 1 / (*scale), clamp it to the FP8
// E4M3 representable range, and write the converted values with 128-bit
// vectorized stores where possible.
template <typename T>
__global__ void per_tensor_quant_fp8_kernel(
    const T* __restrict__ input,
    FP8_TYPE* __restrict__ output,
    const float* __restrict__ scale,
    const int64_t num_elements) {
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  const int grid_size = blockDim.x * gridDim.x;
  const float scale_val = 1.0f / (*scale);

  // Store 128 bits (16 FP8 values, one byte each) per vectorized write.
  // flashinfer's vec_t handles the (possibly multi-part) vectorized load of
  // 16 elements of T.
  const uint32_t VEC_SIZE = 16;
  using vec_t = flashinfer::vec_t<T, VEC_SIZE>;

  const int32_t num_vec_elems = num_elements / VEC_SIZE;

  for (int32_t i = gid; i < num_vec_elems; i += grid_size) {
    vec_t input_vec;
    input_vec.cast_load(input + i * VEC_SIZE);

    FP8_TYPE output_arr[VEC_SIZE];
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      float val = fmax(fmin(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
#ifndef USE_ROCM
      output_arr[j] = static_cast<FP8_TYPE>(val);
#else
      output_arr[j] = c10::Float8_e4m3fnuz(
          __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
          c10::Float8_e4m3fnuz::from_bits());
#endif
    }
    // Write all 16 FP8 bytes with a single 128-bit store.
    *(uint4*)(output + i * VEC_SIZE) = *(uint4*)output_arr;
  }

  // Handle the tail elements that do not fill a complete 16-wide vector.
  const int32_t remaining_start = num_vec_elems * VEC_SIZE;
  for (int32_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) {
    float val = fmax(-FP8_E4M3_MAX, fmin(static_cast<float>(input[idx]) * scale_val, FP8_E4M3_MAX));
#ifndef USE_ROCM
    output[idx] = static_cast<FP8_TYPE>(val);
#else
    output[idx] = c10::Float8_e4m3fnuz(
        __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
        c10::Float8_e4m3fnuz::from_bits());
#endif
  }
}
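
// Worked example of the scheme above (dynamic case, CUDA e4m3fn where
// FP8_E4M3_MAX = 448): if the tensor's absolute maximum is 8.96, the absmax
// kernel writes scale = 8.96 / 448 = 0.02, and an element x = 1.12 is stored as
// clamp(1.12 / 0.02, -448, 448) = 56, which is exactly representable in E4M3.
// Dequantization later multiplies by the same scale: 56 * 0.02 = 1.12.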

// Quantize `input` to FP8 in `output_q` using a single per-tensor scale stored
// in `output_s`. When `is_static` is false, the scale is first computed from
// the tensor's absolute maximum; when true, the caller-provided scale in
// `output_s` is used as-is.
void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s, bool is_static) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);
  CHECK_INPUT(output_s);

  const int block_size = 256;
  const int num_elements = input.numel();
  const int num_blocks = min((num_elements + block_size - 1) / block_size, 1024);

  dim3 grid(num_blocks);
  dim3 block(block_size);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    // Dynamic quantization: derive the per-tensor scale from the absmax first.
    if (is_static == false) {
      per_tensor_absmax_kernel<scalar_t><<<grid, block, 0, stream>>>(
          static_cast<scalar_t*>(input.data_ptr()), static_cast<float*>(output_s.data_ptr()), num_elements);
    }

    per_tensor_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        static_cast<FP8_TYPE*>(output_q.data_ptr()),
        static_cast<float*>(output_s.data_ptr()),
        num_elements);
    return true;
  });
}
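
// Usage sketch (not part of this file; the calling convention is an assumption
// inferred from the code above): `input` is a contiguous CUDA tensor of a dtype
// accepted by DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16 (e.g. half), `output_q`
// has the same shape with an FP8 E4M3 dtype, and `output_s` is a single-element
// float32 tensor. For dynamic quantization (is_static == false), `output_s`
// should start at zero so the atomic max in per_tensor_absmax_kernel yields the
// correct scale.
//
//   torch::Tensor x = torch::randn({1024, 1024}, torch::kHalf).cuda();
//   torch::Tensor q = torch::empty_like(x, torch::dtype(at::ScalarType::Float8_e4m3fn));
//   torch::Tensor s = torch::zeros({1}, torch::dtype(torch::kFloat32).device(x.device()));
//   sgl_per_tensor_quant_fp8(x, q, s, /*is_static=*/false);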