cast.cu 4.52 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
#include "pytorch_extension_utils.h"

/**
 * Converts a value of type T into raw FP8 storage.
 *
 * This primary template is only a fallback: the usable conversions live in
 * the explicit specializations. Instantiating the fallback for any other T is
 * rejected at compile time rather than silently yielding a zero byte.
 */
template <typename T>
struct ConvertToFP8 {
  static __device__ __nv_fp8_storage_t convert_to_fp8(T value) {
    // sizeof(T) == 0 is never true, but it depends on T, so the assertion
    // only fires when this unspecialized member is actually instantiated.
    static_assert(sizeof(T) == 0, "ConvertToFP8<T> has no specialization for this type");
    return 0;
  }
};

// __nv_bfloat16 -> FP8 storage through __nv_cvt_bfloat16raw_to_fp8, requesting
// the __NV_SATFINITE saturation mode and the __NV_E4M3 interpretation.
template <>
struct ConvertToFP8<__nv_bfloat16> {
  static __device__ __nv_fp8_storage_t convert_to_fp8(__nv_bfloat16 value) {
    const __nv_fp8_storage_t encoded = __nv_cvt_bfloat16raw_to_fp8(value, __NV_SATFINITE, __NV_E4M3);
    return encoded;
  }
};

// half -> FP8 storage through __nv_cvt_halfraw_to_fp8, requesting the
// __NV_SATFINITE saturation mode and the __NV_E4M3 interpretation.
template <>
struct ConvertToFP8<half> {
  static __device__ __nv_fp8_storage_t convert_to_fp8(half value) {
    const __nv_fp8_storage_t encoded = __nv_cvt_halfraw_to_fp8(value, __NV_SATFINITE, __NV_E4M3);
    return encoded;
  }
};

/**
 * Converts a float into the element type T.
 *
 * This primary template is only a fallback: the usable conversions live in
 * the explicit specializations. Instantiating the fallback for any other T is
 * rejected at compile time rather than silently yielding zero.
 */
template <typename T>
struct ConvertFromFloat {
  static __device__ T convert_from_float(float value) {
    // sizeof(T) == 0 is never true, but it depends on T, so the assertion
    // only fires when this unspecialized member is actually instantiated.
    static_assert(sizeof(T) == 0, "ConvertFromFloat<T> has no specialization for this type");
    return 0;
  }
};

// float -> __nv_bfloat16, delegating to the __float2bfloat16 intrinsic.
template <>
struct ConvertFromFloat<__nv_bfloat16> {
  static __device__ __nv_bfloat16 convert_from_float(float value) {
    const __nv_bfloat16 narrowed = __float2bfloat16(value);
    return narrowed;
  }
};

// float -> half, delegating to the __float2half intrinsic.
template <>
struct ConvertFromFloat<half> {
  static __device__ half convert_from_float(float value) {
    const half narrowed = __float2half(value);
    return narrowed;
  }
};

/**
 * Scales, clamps and converts K/V elements to FP8 storage, writing each input
 * token's row to the output row selected through `loc`.
 *
 * Launch layout: blockIdx.x is the input token index; blocks whose index is
 * >= input_sl exit immediately. The threads of a block stride over the
 * head * dim elements of their token with step blockDim.x, so any
 * blockDim.x >= 1 is valid. No shared memory or block synchronization is used.
 *
 * @param cache_k   Input K values, indexed as token * head * dim + i.
 * @param cache_v   Input V values, indexed the same way as cache_k.
 * @param k_scale   Only k_scale[0] is read; K values are multiplied by its reciprocal.
 * @param v_scale   Only v_scale[0] is read; V values are multiplied by its reciprocal.
 * @param output_k  FP8 destination for K.
 * @param output_v  FP8 destination for V.
 * @param input_sl  Number of input tokens.
 * @param head      Multiplied with dim to give the element count per token.
 * @param dim       Multiplied with head to give the element count per token.
 * @param max_fp8   Upper clamp bound applied before conversion.
 * @param min_fp8   Lower clamp bound applied before conversion.
 * @param mult      Multiplier applied to loc[token] when forming the output row.
 * @param offset    Added to loc[token] * mult when forming the output row.
 * @param loc       Per-token index; output row = loc[token] * mult + offset.
 */
template <typename T>
__global__ void fused_downcast_kernel(
    const T* cache_k,
    const T* cache_v,
    const float* k_scale,
    const float* v_scale,
    __nv_fp8_storage_t* output_k,
    __nv_fp8_storage_t* output_v,
    const int input_sl,
    const int head,
    const int dim,
    const T max_fp8,
    const T min_fp8,
    const int64_t mult,
    const int64_t offset,
    const int64_t* loc) {
  const int token_idx = blockIdx.x;
  if (token_idx >= input_sl) {
    return;
  }

  const T k_scale_val = ConvertFromFloat<T>::convert_from_float(k_scale[0]);
  const T v_scale_val = ConvertFromFloat<T>::convert_from_float(v_scale[0]);

  // NOTE: the reciprocals are formed in T, not in float, so they carry T's precision.
  const T k_scale_inv = static_cast<T>(1.f) / k_scale_val;
  const T v_scale_inv = static_cast<T>(1.f) / v_scale_val;

  auto clamp = [&](T val) { return val > max_fp8 ? max_fp8 : (min_fp8 > val ? min_fp8 : val); };

  // All offsets are 64-bit: loc holds int64 values and both
  // token * head * dim and out_row * head * dim can exceed INT_MAX.
  const int64_t row_elems = static_cast<int64_t>(head) * dim;
  const int64_t in_base = static_cast<int64_t>(token_idx) * row_elems;
  const int64_t out_base = (loc[token_idx] * mult + offset) * row_elems;
  const int64_t step = blockDim.x;

#pragma unroll
  for (int64_t i = threadIdx.x; i < row_elems; i += step) {
    const int64_t in_idx = in_base + i;
    const int64_t out_idx = out_base + i;

    T k_val = cache_k[in_idx] * k_scale_inv;
    k_val = clamp(k_val);
    output_k[out_idx] = ConvertToFP8<T>::convert_to_fp8(k_val);

    T v_val = cache_v[in_idx] * v_scale_inv;
    v_val = clamp(v_val);
    output_v[out_idx] = ConvertToFP8<T>::convert_to_fp8(v_val);
  }
}

/**
 * Host-side launcher for fused_downcast_kernel<T>.
 *
 * Reads the three leading sizes of `k` as (tokens, head, dim), validates the
 * auxiliary tensors whose buffers are reinterpreted below, and launches the
 * kernel on `stream`. Nothing is launched when there are no elements to
 * convert.
 *
 * @param k        Input K tensor; its buffer is read as T.
 * @param v        Input V tensor; its buffer is read as T.
 * @param k_out    Destination buffer for K, written as __nv_fp8_storage_t.
 * @param v_out    Destination buffer for V, written as __nv_fp8_storage_t.
 * @param k_scale  float32 tensor; only its first element is used by the kernel.
 * @param v_scale  float32 tensor; only its first element is used by the kernel.
 * @param loc      int64 tensor with at least k.size(0) entries.
 * @param mult     Forwarded to the kernel (multiplier on loc values).
 * @param offset   Forwarded to the kernel (added to loc * mult).
 * @param stream   CUDA stream the kernel is enqueued on.
 *
 * Raises through TORCH_CHECK on invalid auxiliary tensors or a failed launch.
 */
template <typename T>
void downcast_fp8_impl(
    at::Tensor& k,
    at::Tensor& v,
    at::Tensor& k_out,
    at::Tensor& v_out,
    at::Tensor& k_scale,
    at::Tensor& v_scale,
    at::Tensor& loc,
    int64_t mult,
    int64_t offset,
    cudaStream_t stream) {
  CHECK_INPUT(k);
  CHECK_INPUT(v);
  CHECK_INPUT(k_out);
  CHECK_INPUT(v_out);
  CHECK_INPUT(k_scale);
  CHECK_INPUT(v_scale);
  CHECK_INPUT(loc);

  // The raw pointers below are reinterpreted as float / int64_t, so reject
  // tensors whose dtype would make that reinterpretation read garbage.
  TORCH_CHECK(k_scale.scalar_type() == at::ScalarType::Float, "k_scale must be a float32 tensor");
  TORCH_CHECK(v_scale.scalar_type() == at::ScalarType::Float, "v_scale must be a float32 tensor");
  TORCH_CHECK(k_scale.numel() >= 1 && v_scale.numel() >= 1, "k_scale and v_scale must not be empty");
  TORCH_CHECK(loc.scalar_type() == at::ScalarType::Long, "loc must be an int64 tensor");

  const int64_t input_sl = k.size(0);
  const int64_t head = k.size(1);
  const int64_t dim = k.size(2);

  // The kernel reads loc[token] for every token in [0, input_sl).
  TORCH_CHECK(loc.numel() >= input_sl, "loc must provide one index per input token");

  // A zero-sized grid is an invalid launch configuration; with nothing to
  // convert there is nothing to launch.
  if (input_sl == 0 || head == 0 || dim == 0) {
    return;
  }

  // The kernel maps blockIdx.x to a token and ignores blocks beyond input_sl,
  // so one block per token is sufficient.
  dim3 grid(static_cast<unsigned int>(input_sl));
  const int vec_size = 8;
  // Keep at least one thread per block: dim < vec_size would otherwise yield
  // a zero-sized block, which cannot be launched.
  const int threads = std::max(1, std::min(static_cast<int>(dim) / vec_size, 1024));
  dim3 block(threads);

  const T max_fp8 = static_cast<T>(448.0f);
  const T min_fp8 = static_cast<T>(-448.0f);

  fused_downcast_kernel<T><<<grid, block, 0, stream>>>(
      static_cast<const T*>(k.data_ptr()),
      static_cast<const T*>(v.data_ptr()),
      static_cast<const float*>(k_scale.data_ptr()),
      static_cast<const float*>(v_scale.data_ptr()),
      static_cast<__nv_fp8_storage_t*>(k_out.data_ptr()),
      static_cast<__nv_fp8_storage_t*>(v_out.data_ptr()),
      static_cast<int>(input_sl),
      static_cast<int>(head),
      static_cast<int>(dim),
      max_fp8,
      min_fp8,
      mult,
      offset,
      static_cast<const int64_t*>(loc.data_ptr()));

  // Launch-configuration errors are only reported through cudaGetLastError().
  cudaError_t status = cudaGetLastError();
  TORCH_CHECK(status == cudaSuccess, "Kernel launch failed: " + std::string(cudaGetErrorString(status)));
}

/**
 * Entry point: converts K/V tensors to FP8, dispatching on the dtype of `k`.
 *
 * @param k            Input K tensor (bfloat16 or float16).
 * @param v            Input V tensor; must have the same dtype as `k` because
 *                     both buffers are read with the element type chosen from `k`.
 * @param k_out        Destination tensor for the converted K values.
 * @param v_out        Destination tensor for the converted V values.
 * @param k_scale      Scale tensor for K, forwarded to the implementation.
 * @param v_scale      Scale tensor for V, forwarded to the implementation.
 * @param loc          Per-token index tensor, forwarded to the implementation.
 * @param mult         Forwarded to the implementation.
 * @param offset       Forwarded to the implementation.
 * @param cuda_stream  A cudaStream_t handle passed as an integer.
 *
 * Raises through TORCH_CHECK when the dtypes of k and v differ or when the
 * dtype of k is neither bfloat16 nor float16.
 */
void downcast_fp8(
    at::Tensor& k,
    at::Tensor& v,
    at::Tensor& k_out,
    at::Tensor& v_out,
    at::Tensor& k_scale,
    at::Tensor& v_scale,
    at::Tensor& loc,
    int64_t mult,
    int64_t offset,
    int64_t cuda_stream) {
  CHECK_INPUT(k);
  CHECK_INPUT(v);
  CHECK_INPUT(k_out);
  CHECK_INPUT(v_out);

  // The dispatch below keys on k only, yet v's buffer is reinterpreted with
  // the same element type; a mismatching v would be read as garbage.
  TORCH_CHECK(v.scalar_type() == k.scalar_type(), "downcast_fp8 expects k and v to have the same dtype.");

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  switch (k.scalar_type()) {
    case at::ScalarType::BFloat16:
      downcast_fp8_impl<__nv_bfloat16>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream);
      break;
    case at::ScalarType::Half:
      downcast_fp8_impl<__half>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream);
      break;
    default:
      TORCH_CHECK(false, "Unsupported input type for downcast_fp8. Expected bfloat16 or float16.");
  }
}