// SPDX-License-Identifier: MIT
//
// Quantization kernels (int8 / FP8 / packed FP4 outputs) with their PyTorch host
// launchers: static and dynamic per-tensor, dynamic per-token, dynamic per-group,
// smooth per-token, plus a partial-transpose helper.
#include "aiter_hip_common.h"
#include "dispatch_utils.h"
#include "hip_reduce.h"
#include "quant_common.cuh"
#include "rocprim/rocprim.hpp"
#include "vec_convert.h"
#include
#include

// Thread-block size used by the one-block-per-row kernels launched below.
const int32_t BlockSize = 256;
// Thread-block size used by dynamic_per_group_scaled_quant_kernel.
const int32_t groupQuantBlockSize = 64;

namespace aiter {

// Dynamic per-group quantization.
//
// Launch layout (see the host launchers): 1-D grid of groupQuantBlockSize-thread
// blocks. Each thread owns thread_data_size consecutive input elements and
// num_thread_per_group = group_size / thread_data_size consecutive threads form one
// group. Group `groupId` maps to (row x, group-in-row y) of a logical
// [ori_rows, ori_cols] input whose rows are ori_row_stride elements apart.
//
// Per group: abs-max -> scale -> one scale entry -> quantized group in `out`,
// which is addressed contiguously by the unpadded group index x * scaleN + y.
//
//   shuffle_scale : FP4 path - the per-row group count is padded to a multiple
//                   of 8 and exponent bytes are stored at fp4_scale_shuffle_id();
//                   other paths - float scales are stored at y * ori_rows + x.
//   num_rows      : optional device scalar; when set, ori_rows becomes
//                   *num_rows * num_cols_factor.
//   scale_ub      : accepted but not read by this kernel.
//
// NOTE(review): input is read with a plain vector load (no bounds check); callers
// must guarantee every launched, non-padding group lies fully in-bounds.
template
__global__ void dynamic_per_group_scaled_quant_kernel(DTYPE_O* __restrict__ out,
                                                      float* __restrict__ scale,
                                                      DTYPE_I const* __restrict__ input,
                                                      float const* __restrict__ scale_ub,
                                                      const int32_t group_size,
                                                      int64_t ori_rows,
                                                      int32_t ori_cols,
                                                      int32_t ori_row_stride,
                                                      bool shuffle_scale = true,
                                                      int32_t const* __restrict__ num_rows = nullptr,
                                                      const int32_t num_cols_factor = 1)
{
    // Index of the exponent byte for (row x, group y) in the shuffled FP4 layout.
    auto fp4_scale_shuffle_id = [](int32_t scaleN_pad, int32_t x, int32_t y) {
        return (x / 32 * scaleN_pad) * 32 + (y / 8) * 256 + (y % 4) * 64 + (x % 16) * 4 +
               (y % 8) / 4 * 2 + (x % 32) / 16;
    };

    if(num_rows != nullptr)
    {
        ori_rows = *num_rows * num_cols_factor;
    }

    int num_thread_per_group = group_size / thread_data_size;
    int64_t row_offset       = static_cast<int64_t>(blockIdx.x) * groupQuantBlockSize;
    int64_t groupId          = (row_offset + threadIdx.x) / num_thread_per_group;
    int32_t scaleN           = ori_cols / group_size;
    // Groups per row as launched: padded to a multiple of 8 for shuffled FP4 scales.
    int32_t scaleN_pad =
        (std::is_same_v && shuffle_scale) ? (((scaleN + 7) / 8) * 8) : scaleN;
    int64_t x = groupId / scaleN_pad; // row
    int32_t y = groupId % scaleN_pad; // group within the row
    if constexpr(std::is_same_v)
    {
        // Rows past ori_rows and padding groups (y >= scaleN) do no work.
        if(x >= ori_rows || y >= scaleN)
        {
            // if (shuffle_scale && threadIdx.x % num_thread_per_group == 0)
            // {
            //     auto *tmp = reinterpret_cast(scale);
            //     groupId = fp4_scale_shuffle_id(scaleN_pad, x, y);
            //     tmp[groupId] = 0x7f;
            // }
            return;
        }
    }
    else
    {
        if(x >= ori_rows)
            return;
    }

    // Input offset of this group.
    row_offset  = x * ori_row_stride + y * group_size;
    using vec_i = ck_tile::vec_t;
    static constexpr int32_t vec_size_o =
        std::is_same_v ? thread_data_size / 2 : thread_data_size;
    using vec_o = ck_tile::vec_t;
    // scale = absMax * inverted_DTYPE_MAX (FP4: power-of-two-rounded absMax * 0.25,
    // otherwise absMax / numeric max of DTYPE_O).
    const float inverted_DTYPE_MAX =
        std::is_same_v ? 0.25
                       : (1. / ck_tile::type_convert(ck_tile::numeric::max()));
    static constexpr int32_t ooba_o = 4 / sizeof(DTYPE_O);
    // Output extent rounded up to a multiple of ooba_o (= 4 bytes of DTYPE_O).
    const int64_t oob_o = (ori_rows * ori_cols + ooba_o - 1) / ooba_o * ooba_o;
    auto const* input_vecs = reinterpret_cast(input + row_offset);
    vec_i thread_data      = input_vecs[threadIdx.x % num_thread_per_group];

    // abs-max over this thread's elements, then across the group's threads. The
    // 1e-10f floor keeps the scale non-zero for an all-zero group.
    float absMax = 1e-10f;
    for(size_t j = 0; j < thread_data_size; j++)
    {
        absMax = max(absMax, abs(ck_tile::type_convert(thread_data[j])));
    }
    absMax = multithread_reduce(absMax, hipcub::Max(), num_thread_per_group);

    // Rounds a non-negative float to a power of two: keeps the exponent field,
    // increments it when mantissa bit 22 is set and (any lower mantissa bit is set
    // or the exponent is non-zero), and clears sign and mantissa. An all-ones
    // exponent (Inf/NaN) is returned with a cleared mantissa.
    auto fp4_scale = [](float tmp) {
        uint32_t u32      = ck_tile::bit_cast(tmp);
        uint32_t exponent = (u32 >> 23) & 0b11111111;
        if(exponent == 0b11111111)
        {
            return ck_tile::bit_cast(exponent << 23);
        }
        if(((u32 & 0x400000)) && (((u32 & 0x200000)) || ((u32 & 0x1FFFFF)) || (exponent)))
            exponent += 1;
        return ck_tile::bit_cast(exponent << 23);
    };
    float inverted_scale = std::is_same_v ? fp4_scale(absMax) * inverted_DTYPE_MAX
                                          : absMax * inverted_DTYPE_MAX;

    // Output offset of this thread's slice. Use the UNPADDED group index: groupId
    // counts scaleN_pad groups per row, which would shift every row after the
    // first whenever shuffled FP4 scales pad scaleN up to a multiple of 8.
    // FP4 packs two values per output element, hence the "/ 2".
    const int64_t out_group = x * scaleN + y;
    row_offset              = std::is_same_v
                                  ? out_group * group_size / 2 +
                                        (threadIdx.x % num_thread_per_group) * vec_size_o
                                  : out_group * group_size +
                                        (threadIdx.x % num_thread_per_group) * vec_size_o;

    // The first thread of each group publishes the group's scale.
    if(threadIdx.x % num_thread_per_group == 0)
    {
        if constexpr(std::is_same_v)
        {
            // FP4: store only the 8-bit biased exponent of the power-of-two scale.
            auto* tmp        = reinterpret_cast(scale);
            uint8_t exponent = (ck_tile::bit_cast(inverted_scale) >> 23) & 0b11111111;
            if(shuffle_scale)
            {
                groupId = fp4_scale_shuffle_id(scaleN_pad, x, y);
            }
            tmp[groupId] = exponent;
        }
        else
        {
            if(shuffle_scale)
            {
                // Column-major scales: index y * ori_rows + x.
                groupId = y * ori_rows + x;
            }
            scale[groupId] = inverted_scale;
        }
    }

    // vec_convert receives the scale itself on the FP4 path, its reciprocal otherwise.
    inverted_scale    = std::is_same_v ? inverted_scale : 1.0f / inverted_scale;
    using DTYPE_STORE = typename ck_tile::vector_traits::scalar_type;
    auto* out_ptr     = reinterpret_cast(out);
    auto buffer_o     = ck_tile::make_buffer_view(out_ptr, oob_o);
    buffer_o.init_raw();
    auto out_s =
        ck_tile::vec_convert(thread_data, inverted_scale).template get_as();
    if constexpr(thread_data_size <= 16)
    {
        buffer_o.template set(row_offset, 0, true, out_s);
    }
    else
    {
        // Wider slices are stored as thread_data_size / 16 chunks of 16 input
        // elements each (o_step output units apart).
        static constexpr int32_t o_step = std::is_same_v ? 8 : 16;
        assert(thread_data_size % 16 == 0);
        using vecT = ck_tile::vec_t;
        auto vec   = out_s.template get_as();
        static constexpr int32_t num_iter = thread_data_size / 16;
        for(size_t j = 0; j < num_iter; j++)
        {
            buffer_o.template set(row_offset + j * o_step, 0, true, vec[j]);
        }
    }
}

// Computes the quantization scale of row blockIdx.x of a row-major [rows, cols]
// input (one block per row; strides assume BlockSize threads).
//
//   thread_data_size == 0 : every thread streams 16 / sizeof(DTYPE_O)-element
//                           vectors with stride BlockSize (any cols).
//   thread_data_size  > 0 : every thread loads exactly one thread_data_size-element
//                           vector, so cols must be <= thread_data_size * BlockSize
//                           (guaranteed by the per-token dispatch macro).
//
// Returns (row_scale, pointer to this thread's last loaded vector). row_scale is
// absMax / max(DTYPE_O), or power-of-two-rounded absMax * 0.25 on the FP4 path.
//
// NOTE(review): the returned pointer addresses the local `vec_cur`, whose lifetime
// ends when this function returns; callers presumably rely on the function being
// inlined. Returning the vector by value would be safer - confirm.
template
__device__ std::tuple data_to_per_row_scale(const DTYPE_I* __restrict__ input,
                                            const int32_t cols)
{
    static constexpr int32_t vec_size_i =
        thread_data_size == 0 ? 16 / sizeof(DTYPE_O) : thread_data_size;
    static constexpr int32_t vec_size_o =
        std::is_same_v ? vec_size_i / 2 : vec_size_i;
    using vec_i = ck_tile::vec_t;
    const float inverted_DTYPE_MAX =
        std::is_same_v ? 0.25
                       : (1. / ck_tile::type_convert(ck_tile::numeric::max()));

    // Widen before multiplying: blockIdx.x * cols is otherwise a 32-bit unsigned
    // product and wraps for tensors with >= 2^32 elements.
    const int64_t row_offset = static_cast<int64_t>(blockIdx.x) * cols;
    auto const* ptr_i        = reinterpret_cast(input + row_offset);
    static constexpr int32_t ooba_i = 4 / sizeof(DTYPE_I);
    const int32_t oob_i             = (cols + ooba_i - 1) / ooba_i * ooba_i;
    auto buffer_i = ck_tile::make_buffer_view(ptr_i, oob_i);
    buffer_i.init_raw();

    // double load core loop start
    const int32_t num_vecs = (cols + vec_size_i - 1) / vec_size_i;
    vec_i vec_cur;
    int32_t vec_idx    = threadIdx.x;
    int32_t vec_stride = BlockSize;
    // Vectors wider than 16 bytes are loaded as vec_i_iter 16-byte pieces.
    static constexpr int32_t max_vec_size_i = 16 / sizeof(DTYPE_I);
    static constexpr int32_t vec_i_iter =
        vec_size_i > max_vec_size_i ? vec_size_i / max_vec_size_i : 1;
    if(vec_idx < num_vecs)
    {
#pragma unroll
        for(int i = 0; i < vec_i_iter; i++)
        {
            if constexpr(vec_size_i > max_vec_size_i)
            {
                using max_vec_i = ck_tile::vec_t;
                max_vec_i vec_tmp;
                vec_tmp = buffer_i.template get(vec_idx * vec_size_i, i * max_vec_size_i, true);
#pragma unroll
                for(int j = 0; j < max_vec_size_i; j++)
                {
                    vec_cur[i * max_vec_size_i + j] = vec_tmp[j];
                }
            }
            else
            {
                vec_cur = buffer_i.template get(vec_idx * vec_size_i, 0, true);
            }
        }
    }
    float absMax = 0.f;
    if constexpr(thread_data_size == 0)
    {
        // Load the next vector before reducing the current one.
        vec_i vec_nxt;
        for(vec_idx += vec_stride; vec_idx < num_vecs; vec_idx += vec_stride)
        {
            vec_nxt = buffer_i.template get(vec_idx * vec_size_i, 0, true);
            for(size_t j = 0; j < vec_size_i; j++)
            {
                absMax = max(absMax, abs(ck_tile::type_convert(vec_cur[j])));
            }
            vec_cur = vec_nxt;
        }
        // Step back to the index of the vector still held in vec_cur.
        vec_idx -= vec_stride;
    }
    if(vec_idx < num_vecs)
    {
#pragma unroll
        for(size_t j = 0; j < vec_size_i; j++)
        {
            absMax = max(absMax, abs(ck_tile::type_convert(vec_cur[j])));
        }
    }
    // double load core loop end

    absMax = block_reduce(absMax, hipcub::Max());

    // Power-of-two rounding used for FP4 scales (see dynamic_per_group_scaled_quant_kernel).
    auto fp4_scale = [](float tmp) {
        uint32_t u32      = ck_tile::bit_cast(tmp);
        uint32_t exponent = (u32 >> 23) & 0b11111111;
        if(exponent == 0b11111111)
        {
            return ck_tile::bit_cast(exponent << 23);
        }
        if(((u32 & 0x400000)) && (((u32 & 0x200000)) || ((u32 & 0x1FFFFF)) || (exponent)))
            exponent += 1;
        return ck_tile::bit_cast(exponent << 23);
    };
    float row_scale = std::is_same_v ? fp4_scale(absMax) * inverted_DTYPE_MAX
                                     : absMax * inverted_DTYPE_MAX;
    return std::make_tuple(row_scale, reinterpret_cast(&vec_cur));
}

// Per-tensor scale: every block (one per row) folds its row scale into *scale with
// an atomic max. The host launcher resets *scale to 0.0f (initializeScale) first.
template
__global__ void
data_to_scale_kernel(float* __restrict__ scale, const DTYPE_I* __restrict__ input, const int cols)
{
    auto res        = data_to_per_row_scale(input, cols);
    float row_scale = std::get<0>(res);
    if(threadIdx.x == 0)
    {
        vllm::atomicMaxFloat(scale, row_scale);
    }
}

// Quantizes row blockIdx.x of a row-major [rows, cols] input with the single scale
// *scale. Threads walk 16 / sizeof(DTYPE_O)-element vectors with stride BlockSize,
// loading the next vector before storing the converted current one.
template
__device__ void scaled_quant_impl(DTYPE_O* __restrict__ out,
                                  const DTYPE_I* __restrict__ input,
                                  const float* __restrict__ scale,
                                  const int32_t cols)
{
    // vec_convert receives the scale itself on the FP4 path, its reciprocal otherwise.
    const float inverted_scale = std::is_same_v ? (*scale) : 1.0f / (*scale);

    static constexpr int32_t vec_size_i = 16 / sizeof(DTYPE_O);
    static constexpr int32_t vec_size_o =
        std::is_same_v ? vec_size_i / 2 : vec_size_i;
    using vec_i       = ck_tile::vec_t;
    using vec_o       = ck_tile::vec_t;
    using DTYPE_STORE = typename ck_tile::vector_traits::scalar_type;

    // Widen before multiplying (32-bit unsigned product otherwise).
    const int64_t row_offset = static_cast<int64_t>(blockIdx.x) * cols;
    auto const* ptr_i        = reinterpret_cast(input + row_offset);
    // FP4 packs two values per output element, hence the halved output offset.
    auto* ptr_o = std::is_same_v ? reinterpret_cast(out + row_offset / 2)
                                 : reinterpret_cast(out + row_offset);

    static constexpr int32_t ooba_i = 4 / sizeof(DTYPE_I);
    static constexpr int32_t ooba_o = 4 / sizeof(DTYPE_O);
    const int32_t oob_i             = (cols + ooba_i - 1) / ooba_i * ooba_i;
    const int32_t oob_o             = (cols + ooba_o - 1) / ooba_o * ooba_o;
    auto buffer_i = ck_tile::make_buffer_view(ptr_i, oob_i);
    buffer_i.init_raw();
    auto buffer_o = ck_tile::make_buffer_view(ptr_o, oob_o);
    buffer_o.init_raw();

    // double load core loop start
    const int32_t num_vecs = (cols + vec_size_i - 1) / vec_size_i;
    vec_i vec_nxt;
    vec_i vec_cur;
    int32_t vec_idx    = threadIdx.x;
    int32_t vec_stride = BlockSize;
    if(vec_idx < num_vecs)
    {
        vec_cur = buffer_i.template get(vec_idx * vec_size_i, 0, true);
    }
    for(vec_idx += vec_stride; vec_idx < num_vecs; vec_idx += vec_stride)
    {
        vec_nxt = buffer_i.template get(vec_idx * vec_size_i, 0, true);
        buffer_o.template set(
            (vec_idx - vec_stride) * vec_size_o,
            0,
            true,
            ck_tile::vec_convert(vec_cur, inverted_scale).template get_as());
        vec_cur = vec_nxt;
    }
    // Store the last vector still held in vec_cur.
    if(vec_idx - vec_stride < num_vecs)
    {
        buffer_o.template set(
            (vec_idx - vec_stride) * vec_size_o,
            0,
            true,
            ck_tile::vec_convert(vec_cur, inverted_scale).template get_as());
    }
    // double load core loop end
}

// Quantizes one already-loaded vector per thread. `input` points at the calling
// thread's thread_data_size elements (the pointer returned by the *_to_per_row_scale
// helpers); thread t stores its result at output offset t * vec_size_o of row
// blockIdx.x. Only threads with threadIdx.x < ceil(cols / thread_data_size) store.
template
__device__ void scaled_quant_vgpr_impl(DTYPE_O* __restrict__ out,
                                       DTYPE_I* __restrict__ input,
                                       const float* __restrict__ scale,
                                       const int cols)
{
    const float inverted_scale = std::is_same_v ? (*scale) : 1.0f / (*scale);

    static constexpr int32_t vec_size_i = thread_data_size;
    static constexpr int32_t vec_size_o =
        std::is_same_v ? vec_size_i / 2 : vec_size_i;
    using vec_i       = ck_tile::vec_t;
    using vec_o       = ck_tile::vec_t;
    using DTYPE_STORE = typename ck_tile::vector_traits::scalar_type;

    // Widen before multiplying (32-bit unsigned product otherwise).
    const int64_t row_offset = static_cast<int64_t>(blockIdx.x) * cols;
    auto const* ptr_i        = reinterpret_cast(input);
    auto const* input_vecs   = reinterpret_cast(ptr_i);
    auto* ptr_o = std::is_same_v ? reinterpret_cast(out + row_offset / 2)
                                 : reinterpret_cast(out + row_offset);

    static constexpr int32_t ooba_o = 4 / sizeof(DTYPE_O);
    const int32_t oob_o             = (cols + ooba_o - 1) / ooba_o * ooba_o;
    auto buffer_o = ck_tile::make_buffer_view(ptr_o, oob_o);
    buffer_o.init_raw();

    const int32_t num_vecs = (cols + vec_size_i - 1) / vec_size_i;
    if(threadIdx.x < num_vecs)
    {
        // Named out_s so it does not shadow the `out` parameter.
        auto out_s =
            ck_tile::vec_convert(*input_vecs, inverted_scale).template get_as();
        if constexpr(vec_size_i <= 16)
        {
            buffer_o.template set(threadIdx.x * vec_size_o, 0, true, out_s);
        }
        else
        {
            // Stored as vec_size_i / 16 chunks of 16 input elements each.
            static constexpr int32_t o_step = std::is_same_v ? 8 : 16;
            assert(vec_size_i % 16 == 0);
            using vecT = ck_tile::vec_t;
            auto vec   = out_s.template get_as();
            static constexpr int32_t num_iter = vec_size_i / 16;
            for(size_t j = 0; j < num_iter; j++)
            {
                buffer_o.template set(threadIdx.x * vec_size_o + j * o_step, 0, true, vec[j]);
            }
        }
    }
}

// Per-tensor quantization with a precomputed scale: block b quantizes row b.
template
__global__ void scaled_quant_kernel(DTYPE_O* __restrict__ out,
                                    const DTYPE_I* __restrict__ input,
                                    const float* __restrict__ scale,
                                    const int cols)
{
    scaled_quant_impl(out, input, scale, cols);
}

// Dynamic per-token quantization: block b computes the scale of row b, stores it
// in scale[b] (FP4: the 8-bit exponent byte) and quantizes row b.
// When num_rows is given, blocks with b >= *num_rows * num_rows_factor exit.
// scale_ub is accepted but not read.
template
__global__ void dynamic_per_token_scaled_quant_kernel(DTYPE_O* __restrict__ out,
                                                      float* __restrict__ scale,
                                                      DTYPE_I* __restrict__ input,
                                                      float const* __restrict__ scale_ub,
                                                      const int32_t cols,
                                                      int32_t const* __restrict__ num_rows = nullptr,
                                                      const int32_t num_rows_factor = 1)
{
    const int token_idx = blockIdx.x;
    if(num_rows != nullptr)
    {
        // token_idx is uniform across the block, so the whole block exits together.
        int32_t rows = *num_rows * num_rows_factor;
        if(token_idx >= rows)
            return;
    }

    auto res         = data_to_per_row_scale(input, cols);
    float row_scale  = std::get<0>(res);
    DTYPE_I* vec_ptr = std::get<1>(res);
    if(threadIdx.x == 0)
    {
        if constexpr(std::is_same_v)
        {
            auto* tmp        = reinterpret_cast(scale);
            uint8_t exponent = (ck_tile::bit_cast(row_scale) >> 23) & 0b11111111;
            tmp[token_idx]   = exponent;
        }
        else
        {
            scale[token_idx] = row_scale;
        }
    }
    if constexpr(thread_data_size == 0)
    {
        // Streaming path: the row is read again from global memory.
        scaled_quant_impl(out, input, &row_scale, cols);
    }
    else
    {
        // Register path: quantize the vector loaded while computing the scale.
        scaled_quant_vgpr_impl(out, vec_ptr, &row_scale, cols);
    }
}

// Smooth-quant row scale. Each thread loads one vec_size_i-element vector of input
// row `token_idx` and multiplies it element-wise by row
// smooth_scale_map[blockIdx.x] (row 0 when the map is null) of the row-major
// [*, cols] smooth_scale matrix; abs-max is then reduced across the block. Since
// each thread loads a single vector, cols must be <= vec_size_i * blockDim.x.
// Returns (row_scale, pointer to this thread's smoothed float values).
// NOTE(review): as in data_to_per_row_scale, the pointer refers to the local
// `smscale_cur`; callers presumably rely on inlining - confirm.
template
__device__ std::tuple
smooth_data_to_per_row_scale(const DTYPE_I* __restrict__ input,
                             const float* __restrict__ smooth_scale,
                             const int32_t* __restrict__ smooth_scale_map,
                             const int32_t cols,
                             const int32_t token_idx)
{
    static constexpr int32_t vec_size_i =
        thread_data_size == 0 ? 16 / sizeof(DTYPE_O) : thread_data_size;
    static constexpr int32_t vec_size_o =
        std::is_same_v ? vec_size_i / 2 : vec_size_i;
    using vec_i = ck_tile::vec_t;
    using vec_s = ck_tile::vec_t;
    const float inverted_DTYPE_MAX =
        std::is_same_v ? 0.25
                       : (1. / ck_tile::type_convert(ck_tile::numeric::max()));

    const int32_t smscale_map_idx =
        smooth_scale_map == nullptr ? 0 : smooth_scale_map[blockIdx.x];
    // Widen before multiplying: token_idx * cols overflows int32 at 2^31 elements.
    const int64_t row_offset = static_cast<int64_t>(token_idx) * cols;
    auto const* ptr_i        = reinterpret_cast(input + row_offset);
    static constexpr int32_t ooba_i = 4 / sizeof(DTYPE_I);
    const int32_t oob_i             = (cols + ooba_i - 1) / ooba_i * ooba_i;
    auto buffer_i = ck_tile::make_buffer_view(ptr_i, oob_i);
    buffer_i.init_raw();
    auto const* ptr_smscale =
        reinterpret_cast(smooth_scale + static_cast<int64_t>(smscale_map_idx) * cols);
    auto buffer_s = ck_tile::make_buffer_view(ptr_smscale, cols);
    buffer_s.init_raw();

    const int32_t num_vecs = (cols + vec_size_i - 1) / vec_size_i;
    vec_i vec_cur;
    vec_s smscale_cur;
    int32_t vec_idx = threadIdx.x;
    float absMax    = 0.f;
    if(vec_idx < num_vecs)
    {
        vec_cur     = buffer_i.template get(vec_idx * vec_size_i, 0, true);
        smscale_cur = buffer_s.template get(vec_idx * vec_size_i, 0, true);
#pragma unroll
        for(size_t j = 0; j < vec_size_i; j++)
        {
            // smscale_cur is reused to hold input * smooth_scale.
            smscale_cur[j] = ck_tile::type_convert(vec_cur[j]) * smscale_cur[j];
            absMax         = max(absMax, abs(smscale_cur[j]));
        }
    }
    absMax = block_reduce(absMax, hipcub::Max());

    // Power-of-two rounding used for FP4 scales (see dynamic_per_group_scaled_quant_kernel).
    auto fp4_scale = [](float tmp) {
        uint32_t u32      = ck_tile::bit_cast(tmp);
        uint32_t exponent = (u32 >> 23) & 0b11111111;
        if(exponent == 0b11111111)
        {
            return ck_tile::bit_cast(exponent << 23);
        }
        if(((u32 & 0x400000)) && (((u32 & 0x200000)) || ((u32 & 0x1FFFFF)) || (exponent)))
            exponent += 1;
        return ck_tile::bit_cast(exponent << 23);
    };
    float row_scale = std::is_same_v ? fp4_scale(absMax) * inverted_DTYPE_MAX
                                     : absMax * inverted_DTYPE_MAX;
    return std::make_tuple(row_scale, reinterpret_cast(&smscale_cur));
}

// Smooth per-token quantization. Block b quantizes output row b and stores its
// scale in scale[b] (FP4: exponent byte). The input row is located through
// (input_dim0, input_dim1, input_stride0, input_stride1), which describe an input
// of up to 3 dims; strides are assumed to be multiples of cols.
template
__global__ void smooth_per_token_scaled_quant_kernel(DTYPE_O* __restrict__ out,
                                                     float* __restrict__ scale,
                                                     DTYPE_I* __restrict__ input,
                                                     float* __restrict__ smooth_scale,
                                                     int* __restrict__ smooth_scale_map,
                                                     const int32_t cols,
                                                     int32_t const* __restrict__ num_rows = nullptr,
                                                     const int32_t num_rows_factor = 1,
                                                     const int32_t input_dim0      = 1,
                                                     const int32_t input_dim1      = 1,
                                                     const int32_t input_stride0   = 1,
                                                     const int32_t input_stride1   = 1)
{
    int token_idx = blockIdx.x;
    if(num_rows != nullptr)
    {
        int32_t rows = *num_rows * num_rows_factor;
        if(token_idx >= rows)
            return;
    }
    // Split the flat token index over [input_dim0, input_dim1] and convert the
    // element strides into row units.
    int real_token_idx = token_idx % input_dim1 * (input_stride1 / cols) +
                         (token_idx / input_dim1) % input_dim0 * (input_stride0 / cols);
    auto res = smooth_data_to_per_row_scale(
        input, smooth_scale, smooth_scale_map, cols, real_token_idx);
    float row_scale = std::get<0>(res);
    float* vec_ptr  = std::get<1>(res);
    if(threadIdx.x == 0)
    {
        if constexpr(std::is_same_v)
        {
            auto* tmp        = reinterpret_cast(scale);
            uint8_t exponent = (ck_tile::bit_cast(row_scale) >> 23) & 0b11111111;
            tmp[token_idx]   = exponent;
        }
        else
        {
            scale[token_idx] = row_scale;
        }
    }
    scaled_quant_vgpr_impl(out, vec_ptr, &row_scale, cols);
}

// out = quantize(input, scale) with a precomputed per-tensor scale.
// input/out are processed as contiguous [numel / d, d] rows, one block per row.
void static_per_tensor_quant(torch::Tensor& out,         // [..., d]
                             torch::Tensor const& input, // [..., d]
                             torch::Tensor const& scale) // [1]
{
    // The kernels index flat rows, so both tensors must be contiguous.
    TORCH_CHECK(input.is_contiguous());
    TORCH_CHECK(out.is_contiguous());
    const int cols = input.size(-1);
    int rows       = input.numel() / cols;
    dim3 grid(rows);
    dim3 block(BlockSize);
    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input));
    const hipStream_t stream = at::hip::getCurrentHIPStream();
    if(out.dtype() == torch::kInt8)
    {
        AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "scaled_quant_kernel", [&] {
            using input_dtype = typename t2ck::type;
            aiter::scaled_quant_kernel<<>>(
                reinterpret_cast(out.data_ptr()),
                reinterpret_cast(input.data_ptr()),
                scale.data_ptr(),
                cols);
        });
    }
#ifdef GPU_ENABLE_FP8
    else if(out.dtype() == torch_fp8)
    {
        AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "scaled_quant_kernel", [&] {
            using input_dtype = typename t2ck::type;
            aiter::scaled_quant_kernel<<>>(
                reinterpret_cast(out.data_ptr()),
                reinterpret_cast(input.data_ptr()),
                scale.data_ptr(),
                cols);
        });
    }
#endif
    else
    {
        TORCH_CHECK(false, __func__, " not support output type: ", out.dtype());
    }
}

// Launches `quant_kernel` for one output type with THREAD_DATA elements per thread
// (0 = streaming). Expects out, scales, input, scale_ub, cols, num_rows_ptr,
// num_rows_factor and the launch configuration to be in scope.
#define DYNAMIC_PER_TOKEN_SCALED_QUANT_KERNEL_IMPL(quant_kernel, DTYPE_O, THREAD_DATA) \
    AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "quant_kernel", [&] {          \
        using input_dtype = typename t2ck::type;                                         \
        aiter::quant_kernel<<>>(                                                         \
            reinterpret_cast(out.data_ptr()),                                            \
            scales.data_ptr(),                                                           \
            reinterpret_cast(input.data_ptr()),                                          \
            scale_ub.has_value() ? scale_ub->data_ptr() : nullptr,                       \
            cols,                                                                        \
            num_rows_ptr,                                                                \
            num_rows_factor);                                                            \
    });

// Chooses THREAD_DATA: a single register-resident vector of 8/16/32 elements per
// thread when a row fits in one BlockSize-thread pass, else the streaming kernel.
#define DYNAMIC_PER_TOKEN_SCALED_QUANT_KERNEL_DISPATCH(quant_kernel, DTYPE_O, cols) \
    if(cols <= 8 * BlockSize)                                                       \
    {                                                                               \
        DYNAMIC_PER_TOKEN_SCALED_QUANT_KERNEL_IMPL(quant_kernel, DTYPE_O, 8)        \
    }                                                                               \
    else if(cols <= 16 * BlockSize)                                                 \
    {                                                                               \
        DYNAMIC_PER_TOKEN_SCALED_QUANT_KERNEL_IMPL(quant_kernel, DTYPE_O, 16)       \
    }                                                                               \
    else if(cols <= 32 * BlockSize)                                                 \
    {                                                                               \
        DYNAMIC_PER_TOKEN_SCALED_QUANT_KERNEL_IMPL(quant_kernel, DTYPE_O, 32)       \
    }                                                                               \
    else                                                                            \
    {                                                                               \
        DYNAMIC_PER_TOKEN_SCALED_QUANT_KERNEL_IMPL(quant_kernel, DTYPE_O, 0)        \
    }

// Dynamic per-tensor quantization in three launches on the current stream:
// reset *scale to 0, atomic-max every row scale into it, then quantize all rows
// with it. input/out are processed as contiguous [numel / d, d] rows.
void dynamic_per_tensor_quant(torch::Tensor& out,         // [..., d]
                              torch::Tensor const& input, // [..., d]
                              torch::Tensor& scale)       // [1]
{
    // The kernels index flat rows, so both tensors must be contiguous.
    TORCH_CHECK(input.is_contiguous());
    TORCH_CHECK(out.is_contiguous());
    const int cols = input.size(-1);
    int rows       = input.numel() / cols;
    dim3 grid(rows);
    dim3 block(BlockSize);
    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input));
    const hipStream_t stream = at::hip::getCurrentHIPStream();
    if(out.dtype() == torch::kInt8)
    {
        AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "scaled_quant_kernel", [&] {
            using input_dtype = typename t2ck::type;
            vllm::initializeScale<<>>(
                scale.data_ptr(), 1, 0.0f);
            aiter::data_to_scale_kernel<<>>(
                scale.data_ptr(),
                reinterpret_cast(input.data_ptr()),
                cols);
            aiter::scaled_quant_kernel<<>>(
                reinterpret_cast(out.data_ptr()),
                reinterpret_cast(input.data_ptr()),
                scale.data_ptr(),
                cols);
        });
    }
#ifdef GPU_ENABLE_FP8
    else if(out.dtype() == torch_fp8)
    {
        AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "scaled_quant_kernel", [&] {
            using input_dtype = typename t2ck::type;
            vllm::initializeScale<<>>(
                scale.data_ptr(), 1, 0.0f);
            aiter::data_to_scale_kernel<<>>(
                scale.data_ptr(),
                reinterpret_cast(input.data_ptr()),
                cols);
            aiter::scaled_quant_kernel<<>>(
                reinterpret_cast(out.data_ptr()),
                reinterpret_cast(input.data_ptr()),
                scale.data_ptr(),
                cols);
        });
    }
#endif
    else
    {
        TORCH_CHECK(false, __func__, " not support output type: ", out.dtype());
    }
}

// Dynamic per-token quantization of a contiguous [..., d] input.
//
// d in {32, 64, 128}: handled by dynamic_per_group_scaled_quant_kernel with
//   group_size = d. For FP8 / FP4 outputs ori_cols is derived from out's last dim,
//   so out may describe rows that span several input rows (groups).
// Other d: one BlockSize-thread block per row (dynamic_per_token_scaled_quant_kernel).
//
// scales receives one scale per row/group (FP4: one exponent byte each).
// num_rows (optional device scalar) * num_rows_factor limits the rows processed.
void dynamic_per_token_scaled_quant(torch::Tensor& out,         // [..., d]
                                    torch::Tensor const& input, // [..., d]
                                    torch::Tensor& scales,
                                    std::optional const& scale_ub,
                                    bool shuffle_scale                     = false,
                                    std::optional const& num_rows          = std::nullopt,
                                    int num_rows_factor                    = 1)
{
    TORCH_CHECK(input.is_contiguous());
    TORCH_CHECK(out.is_contiguous());
    int const cols        = input.size(-1);
    int const rows        = input.numel() / cols;
    int32_t* num_rows_ptr = num_rows.has_value() ? num_rows->data_ptr() : nullptr;
    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input));
    const hipStream_t stream = at::hip::getCurrentHIPStream();
    if(cols == 32 || cols == 64 || cols == 128)
    {
        int group_size           = cols;
        int thread_data_size     = 32;
        int num_thread_per_group = group_size / thread_data_size;
        int num_group_per_tg     = groupQuantBlockSize / num_thread_per_group;
        if(out.dtype() == torch::kInt8)
        {
            // NOTE(review): unlike the FP8 branch, ori_cols is taken from input
            // (scaleN == 1), so with shuffle_scale the scale layout differs from FP8
            // when out's last dim is wider than d - confirm this is intended.
            int ori_cols  = cols;
            int scaleN    = ori_cols / cols;
            int ori_rows  = rows / scaleN;
            int num_group = rows;
            dim3 const grid((num_group + num_group_per_tg - 1) / num_group_per_tg);
            dim3 const block(groupQuantBlockSize);
            AITER_DISPATCH_FLOATING16_TYPES(
                input.scalar_type(), "dynamic_per_group_scaled_quant_kernel", [&] {
                    using input_dtype = typename t2ck::type;
                    aiter::dynamic_per_group_scaled_quant_kernel<<>>(
                        reinterpret_cast(out.data_ptr()),
                        scales.data_ptr(),
                        reinterpret_cast(input.data_ptr()),
                        scale_ub.has_value() ? scale_ub->data_ptr() : nullptr,
                        group_size,
                        ori_rows,
                        ori_cols,
                        ori_cols,
                        shuffle_scale,
                        num_rows_ptr,
                        num_rows_factor);
                });
        }
#ifdef GPU_ENABLE_FP8
        else if(out.dtype() == torch_fp8)
        {
            int ori_cols  = out.size(-1);
            int scaleN    = ori_cols / cols;
            int ori_rows  = rows / scaleN;
            int num_group = rows;
            dim3 const grid((num_group + num_group_per_tg - 1) / num_group_per_tg);
            dim3 const block(groupQuantBlockSize);
            AITER_DISPATCH_FLOATING16_TYPES(
                input.scalar_type(), "dynamic_per_group_scaled_quant_kernel", [&] {
                    using input_dtype = typename t2ck::type;
                    aiter::dynamic_per_group_scaled_quant_kernel<<>>(
                        reinterpret_cast(out.data_ptr()),
                        scales.data_ptr(),
                        reinterpret_cast(input.data_ptr()),
                        scale_ub.has_value() ? scale_ub->data_ptr() : nullptr,
                        group_size,
                        ori_rows,
                        ori_cols,
                        ori_cols,
                        shuffle_scale,
                        num_rows_ptr,
                        num_rows_factor);
                });
        }
#endif
#if defined(__Float4_e2m1fn_x2)
        else if(out.dtype() == torch_fp4x2)
        {
            // out holds two FP4 values per element.
            int ori_cols = out.size(-1) * 2;
            int scaleN   = ori_cols / cols;
            int ori_rows = rows / scaleN;
            // Shuffled scales pad the per-row group count to a multiple of 8.
            int num_group = shuffle_scale ? ori_rows * ((scaleN + 7) / 8 * 8) : rows;
            dim3 const grid((num_group + num_group_per_tg - 1) / num_group_per_tg);
            dim3 const block(groupQuantBlockSize);
            AITER_DISPATCH_FLOATING16_TYPES(
                input.scalar_type(), "dynamic_per_group_scaled_quant_kernel", [&] {
                    using input_dtype = typename t2ck::type;
                    aiter::dynamic_per_group_scaled_quant_kernel<<>>(
                        reinterpret_cast(out.data_ptr()),
                        reinterpret_cast(scales.data_ptr()),
                        reinterpret_cast(input.data_ptr()),
                        scale_ub.has_value() ? scale_ub->data_ptr() : nullptr,
                        group_size,
                        ori_rows,
                        ori_cols,
                        ori_cols,
                        shuffle_scale,
                        num_rows_ptr,
                        num_rows_factor);
                });
        }
#endif
        else
        {
            TORCH_CHECK(false, __func__, " not support output type: ", out.dtype());
        }
    }
    else
    {
        dim3 const grid(rows);
        dim3 const block(BlockSize);
        if(out.dtype() == torch::kInt8)
        {
            DYNAMIC_PER_TOKEN_SCALED_QUANT_KERNEL_DISPATCH(
                dynamic_per_token_scaled_quant_kernel, ck_tile::int8_t, cols);
        }
#ifdef GPU_ENABLE_FP8
        else if(out.dtype() == torch_fp8)
        {
            DYNAMIC_PER_TOKEN_SCALED_QUANT_KERNEL_DISPATCH(
                dynamic_per_token_scaled_quant_kernel, FP8_TYPE, cols);
        }
#endif
#if defined(__Float4_e2m1fn_x2)
        else if(out.dtype() == torch_fp4x2)
        {
            DYNAMIC_PER_TOKEN_SCALED_QUANT_KERNEL_DISPATCH(
                dynamic_per_token_scaled_quant_kernel, ck_tile::fp4x2_t, cols);
        }
#endif
        else
        {
            TORCH_CHECK(false, __func__, " not support output type: ", out.dtype());
        }
    }
}

// Per-group FP4 quantization (group_size in {32, 64, 128}, d % group_size == 0).
// Input rows may be strided (input.stride(-2)); out must be contiguous.
// scales receives one exponent byte per group, in the shuffled layout when
// shuffle_scale is set (per-row group count padded to a multiple of 8).
void dynamic_per_group_scaled_quant_fp4(torch::Tensor& out,         // [..., d]
                                        torch::Tensor const& input, // [..., d]
                                        torch::Tensor& scales,
                                        int group_size                         = 32,
                                        bool shuffle_scale                     = true,
                                        std::optional const& num_rows          = std::nullopt,
                                        int num_rows_factor                    = 1)
{
    TORCH_CHECK(group_size == 32 || group_size == 64 || group_size == 128,
                __func__,
                " only support group_size [32, 64 , 128]");
    TORCH_CHECK(out.is_contiguous());
    int const cols        = input.size(-1);
    int const rows        = input.numel() / cols;
    int const row_stride  = input.stride(-2);
    int32_t* num_rows_ptr = num_rows.has_value() ? num_rows->data_ptr() : nullptr;
    TORCH_CHECK(cols % group_size == 0, __func__, " cols is not divisible by group_size");
    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input));
    const hipStream_t stream = at::hip::getCurrentHIPStream();

    int thread_data_size     = 32;
    int num_thread_per_group = group_size / thread_data_size;
    int num_group_per_tg     = groupQuantBlockSize / num_thread_per_group;
    int scaleN               = cols / group_size;
    int num_group = shuffle_scale ? rows * ((scaleN + 7) / 8 * 8) : rows * scaleN;
    dim3 const grid((num_group + num_group_per_tg - 1) / num_group_per_tg);
    dim3 const block(groupQuantBlockSize);
#if defined(__Float4_e2m1fn_x2)
    AITER_DISPATCH_FLOATING16_TYPES(
        input.scalar_type(), "dynamic_per_group_scaled_quant_kernel", [&] {
            using input_dtype = typename t2ck::type;
            aiter::dynamic_per_group_scaled_quant_kernel<<>>(
                reinterpret_cast(out.data_ptr()),
                reinterpret_cast(scales.data_ptr()),
                reinterpret_cast(input.data_ptr()),
                nullptr,
                group_size,
                rows,
                cols,
                row_stride,
                shuffle_scale,
                num_rows_ptr,
                num_rows_factor);
        });
#else
    TORCH_CHECK(false, __func__, " device not support Float4_e2m1fn_x2 dtype");
#endif
}

// Launches the smooth per-token kernel with THREAD_DATA elements per thread and a
// BLOCK_SIZE-thread block. Expects out, scales, input, smooth_scale,
// smooth_scale_map_ptr, cols, num_rows_ptr, num_rows_factor, input_dim0/1,
// input_stride0/1 and the launch configuration to be in scope.
#define SMOOTH_PER_TOKEN_SCALED_QUANT_KERNEL_IMPL(quant_kernel, DTYPE_O, THREAD_DATA, BLOCK_SIZE) \
    AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "quant_kernel", [&] {                     \
        using input_dtype = typename t2ck::type;                                                    \
        aiter::quant_kernel                                                                         \
            <<>>(                                                                                   \
                reinterpret_cast(out.data_ptr()),                                                   \
                scales.data_ptr(),                                                                  \
                reinterpret_cast(input.data_ptr()),                                                 \
                smooth_scale.data_ptr(),                                                            \
                smooth_scale_map_ptr,                                                               \
                cols,                                                                               \
                num_rows_ptr,                                                                       \
                num_rows_factor,                                                                    \
                input_dim0,                                                                         \
                input_dim1,                                                                         \
                input_stride0,                                                                      \
                input_stride1);                                                                     \
    });

// The smooth kernel loads one vector per thread, so a row must fit in
// THREAD_DATA * BLOCK_SIZE elements; d > 32 * BlockSize is rejected.
#define SMOOTH_PER_TOKEN_SCALED_QUANT_KERNEL_DISPATCH(quant_kernel, DTYPE_O, cols)          \
    if(cols <= 8 * BlockSize)                                                                \
    {                                                                                        \
        SMOOTH_PER_TOKEN_SCALED_QUANT_KERNEL_IMPL(quant_kernel, DTYPE_O, 8, BlockSize)       \
    }                                                                                        \
    else if(cols <= 16 * BlockSize)                                                          \
    {                                                                                        \
        SMOOTH_PER_TOKEN_SCALED_QUANT_KERNEL_IMPL(quant_kernel, DTYPE_O, 16, BlockSize)      \
    }                                                                                        \
    else if(cols <= 16 * BlockSize * 2)                                                      \
    {                                                                                        \
        SMOOTH_PER_TOKEN_SCALED_QUANT_KERNEL_IMPL(quant_kernel, DTYPE_O, 16, BlockSize * 2)  \
    }                                                                                        \
    else                                                                                     \
    {                                                                                        \
        TORCH_CHECK(false, "input last dim has exceeded the maximum value ", 32 * BlockSize) \
    }

// Smooth per-token quantization: out[r] = quantize(input_row(r) * smooth_scale[map[r]]).
// input may have up to 3 dims and arbitrary outer strides (multiples of d); out
// must be contiguous. shuffle_scale is accepted but not used by this launcher.
void smooth_per_token_scaled_quant(
    torch::Tensor& out,         // [..., d]
    torch::Tensor const& input, // [..., d]
    torch::Tensor& scales,
    torch::Tensor const& smooth_scale,
    std::optional const& smooth_scale_map = std::nullopt,
    bool shuffle_scale                     = false,
    std::optional const& num_rows          = std::nullopt,
    int num_rows_factor                    = 1)
{
    TORCH_CHECK(out.is_contiguous());
    int const cols        = input.size(-1);
    int const rows        = input.numel() / cols;
    int32_t* num_rows_ptr = num_rows.has_value() ? num_rows->data_ptr() : nullptr;
    int32_t* smooth_scale_map_ptr =
        smooth_scale_map.has_value() ? smooth_scale_map->data_ptr() : nullptr;
    TORCH_CHECK(
        input.dim() < 4, __func__, " only support input dim <=3, but get dim: ", input.dim());
    int32_t input_dim0    = input.size(0);
    int32_t input_dim1    = input.dim() > 2 ? input.size(1) : 1;
    int32_t input_stride0 = input.stride(0);
    int32_t input_stride1 = input.dim() > 2 ? input.stride(1) : cols;
    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input));
    const hipStream_t stream = at::hip::getCurrentHIPStream();
    dim3 const grid(rows);
    dim3 const block(BlockSize);
    if(out.dtype() == torch::kInt8)
    {
        SMOOTH_PER_TOKEN_SCALED_QUANT_KERNEL_DISPATCH(
            smooth_per_token_scaled_quant_kernel, ck_tile::int8_t, cols);
    }
#ifdef GPU_ENABLE_FP8
    else if(out.dtype() == torch_fp8)
    {
        SMOOTH_PER_TOKEN_SCALED_QUANT_KERNEL_DISPATCH(
            smooth_per_token_scaled_quant_kernel, FP8_TYPE, cols);
    }
#endif
#if defined(__Float4_e2m1fn_x2)
    else if(out.dtype() == torch::kFloat4_e2m1fn_x2 || out.dtype() == torch::kUInt8)
    {
        SMOOTH_PER_TOKEN_SCALED_QUANT_KERNEL_DISPATCH(
            smooth_per_token_scaled_quant_kernel, ck_tile::fp4x2_t, cols);
    }
#endif
    else
    {
        TORCH_CHECK(false, __func__, " not support output type: ", out.dtype());
    }
}

// Transposes the first *num_rows rows of a row-major [*, cols] input:
// out[x * ori_rows + y] = input[y * cols + x]. Each thread handles
// thread_data_size consecutive columns of one row per iteration, and MAX_ITERS
// iterations of gridDim.x * BLOCK_SIZE threads are made; rows beyond that
// coverage are not processed.
template
__global__ void partial_transpose_kernel(DTYPE* __restrict__ out,
                                         DTYPE* __restrict__ input,
                                         const int* __restrict__ num_rows,
                                         const int cols)
{
    using vec_i        = ck_tile::vec_t;
    int GRID_SIZE      = gridDim.x;
    int ori_rows       = *num_rows;
    int thread_per_row = (cols + thread_data_size - 1) / thread_data_size;
    auto const* ptr_i  = reinterpret_cast(input);
    static constexpr int32_t ooba_i = 4 / sizeof(DTYPE);
    const int32_t oob_i             = (ori_rows * cols + ooba_i - 1) / ooba_i * ooba_i;
    auto buffer_i = ck_tile::make_buffer_view(ptr_i, oob_i);
    buffer_i.init_raw();
    for(int i = 0; i < MAX_ITERS; i++)
    {
        int64_t y = i * GRID_SIZE * BLOCK_SIZE + blockIdx.x * BLOCK_SIZE + threadIdx.x;
        int x     = y % thread_per_row * thread_data_size; // first column of this slice
        y         = y / thread_per_row;                     // row
        if(y >= ori_rows)
            return;
        vec_i input_vecs = buffer_i.template get(y * cols + x, 0, true);
        // Widen before multiplying: x * ori_rows can overflow int32.
        int64_t out_offset = static_cast<int64_t>(x) * ori_rows + y;
        for(int j = 0; j < thread_data_size; j++)
        {
            if((x + j) < cols)
            {
                out[out_offset + j * ori_rows] = input_vecs[j];
            }
        }
    }
}

// Host launcher for partial_transpose_kernel on contiguous tensors. The block
// size, grid size (a multiple of the CU count) and per-thread vector size are
// picked from d; d > 8192 is rejected.
void partial_transpose(torch::Tensor& out,         // [rows, d]
                       torch::Tensor const& input, // [rows, d]
                       torch::Tensor const& num_rows)
{
    TORCH_CHECK(out.is_contiguous());
    TORCH_CHECK(input.is_contiguous());
    uint32_t num_cu       = get_num_cu_func();
    int const cols        = input.size(-1);
    int const rows        = input.numel() / cols;
    int32_t* num_rows_ptr = num_rows.data_ptr();
    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input));
    const hipStream_t stream = at::hip::getCurrentHIPStream();
    if(cols <= 1024)
    {
        const int BlockSize        = 256;
        const int GridSize         = num_cu * 8; // Adjust as needed
        const int thread_data_size = 1024 / BlockSize;
        dim3 grid(GridSize);
        dim3 block(BlockSize);
        VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "partial_transpose_kernel", [&] {
            using input_dtype = typename t2ck::type;
            aiter::partial_transpose_kernel
                <<>>(reinterpret_cast(out.data_ptr()),
                     reinterpret_cast(input.data_ptr()),
                     num_rows_ptr,
                     cols);
        });
    }
    else if(cols <= 2048)
    {
        const int BlockSize        = 256;
        const int GridSize         = num_cu * 4;
        const int thread_data_size = 2048 / BlockSize;
        dim3 grid(GridSize);
        dim3 block(BlockSize);
        VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "partial_transpose_kernel", [&] {
            using input_dtype = typename t2ck::type;
            aiter::partial_transpose_kernel
                <<>>(reinterpret_cast(out.data_ptr()),
                     reinterpret_cast(input.data_ptr()),
                     num_rows_ptr,
                     cols);
        });
    }
    else if(cols <= 4096)
    {
        const int BlockSize        = 256;
        const int GridSize         = num_cu * 2;
        const int thread_data_size = 4096 / BlockSize;
        dim3 grid(GridSize);
        dim3 block(BlockSize);
        VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "partial_transpose_kernel", [&] {
            using input_dtype = typename t2ck::type;
            aiter::partial_transpose_kernel
                <<>>(reinterpret_cast(out.data_ptr()),
                     reinterpret_cast(input.data_ptr()),
                     num_rows_ptr,
                     cols);
        });
    }
    else if(cols <= 8192)
    {
        const int BlockSize        = 512;
        const int GridSize         = num_cu;
        const int thread_data_size = 8192 / BlockSize;
        dim3 grid(GridSize);
        dim3 block(BlockSize);
        VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "partial_transpose_kernel", [&] {
            using input_dtype = typename t2ck::type;
            aiter::partial_transpose_kernel
                <<>>(reinterpret_cast(out.data_ptr()),
                     reinterpret_cast(input.data_ptr()),
                     num_rows_ptr,
                     cols);
        });
    }
    else
    {
        TORCH_CHECK(false, __func__, " cols is not supported: ", cols);
    }
}

} // namespace aiter