Commit 4d3a2c28 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.5' into v0.6.5-dev

parents 92ec5d8e 2d1b9baa
...@@ -9,11 +9,13 @@ bool call_marlin_moe_kernel_ku8b128( ...@@ -9,11 +9,13 @@ bool call_marlin_moe_kernel_ku8b128(
bool has_act_order, int group_blocks, int num_threads, int blocks, bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr, int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
bool replicate_input, bool apply_weights, int m_block, int max_par, int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int cfg_max_m_blocks) { int m_block, int max_par, int cfg_max_m_blocks) {
bool has_zp = false;
if (false) { if (false) {
} }
GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256)
......
...@@ -9,10 +9,10 @@ bool call_marlin_moe_kernel_ku8b128( ...@@ -9,10 +9,10 @@ bool call_marlin_moe_kernel_ku8b128(
bool has_act_order, int group_blocks, int num_threads, int blocks, bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr, int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
bool replicate_input, bool apply_weights, int m_block, int max_par, int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int cfg_max_m_blocks); int m_block, int max_par, int cfg_max_m_blocks);
} }
...@@ -25,9 +25,12 @@ ...@@ -25,9 +25,12 @@
#include <iostream> #include <iostream>
#include "core/exception.hpp"
#include "core/scalar_type.hpp" #include "core/scalar_type.hpp"
#include "core/registration.h"
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h" #include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h" #include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
#include "marlin_kernels/marlin_moe_kernel_ku4.h"
template <typename T> template <typename T>
inline std::string str(T x) { inline std::string str(T x) {
...@@ -155,6 +158,7 @@ thread_config_t small_batch_thread_configs[] = { ...@@ -155,6 +158,7 @@ thread_config_t small_batch_thread_configs[] = {
{128, 64, 128}, // Reduce N 2X, same K {128, 64, 128}, // Reduce N 2X, same K
{64, 256, 256}, // Reduce K 2X, increase N 2X {64, 256, 256}, // Reduce K 2X, increase N 2X
{64, 128, 128}, // Reduce K 2X, same N {64, 128, 128}, // Reduce K 2X, same N
{64, 64, 128}, // Reduce both 2X
}; };
thread_config_t large_batch_thread_configs[] = { thread_config_t large_batch_thread_configs[] = {
...@@ -165,6 +169,7 @@ thread_config_t large_batch_thread_configs[] = { ...@@ -165,6 +169,7 @@ thread_config_t large_batch_thread_configs[] = {
{128, 128, 256}, // Reduce N 2X, increase K 2X {128, 128, 256}, // Reduce N 2X, increase K 2X
{64, 128, 128}, // Reduce N 2X, same K {64, 128, 128}, // Reduce N 2X, same K
{128, 64, 128}, // Reduce N 4X, increase K 2X {128, 64, 128}, // Reduce N 4X, increase K 2X
{64, 64, 128}, // Reduce N 4X, same K
}; };
int get_scales_cache_size(thread_config_t const& th_config, int prob_m, int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
...@@ -189,7 +194,7 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m, ...@@ -189,7 +194,7 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
int load_groups = int load_groups =
tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K
load_groups = max(load_groups, 32); // We load at least 32 scale groups load_groups = max(load_groups, 32); // We load at least 32 scale groups
return load_groups * tb_n * 2; return load_groups * tb_n * 4;
} else { } else {
int tb_scales = tb_groups * tb_n * 2; int tb_scales = tb_groups * tb_n * 2;
...@@ -310,27 +315,28 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, ...@@ -310,27 +315,28 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
return exec_config_t{0, {-1, -1, -1}}; return exec_config_t{0, {-1, -1, -1}};
} }
#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \ #define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \
else if (KERNEL_FUNCTION(q_type, thread_n_blocks, thread_k_blocks, \ else if (KERNEL_FUNCTION( \
has_act_order, group_blocks, num_threads, blocks, \ q_type, thread_n_blocks, thread_k_blocks, has_act_order, \
max_shared_mem, stream, A_ptr, B_ptr, C_ptr, \ group_blocks, num_threads, blocks, max_shared_mem, stream, \
sorted_ids_ptr, topk_weights_ptr, s_ptr, g_idx_ptr, \ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
expert_offsets_ptr, num_groups, expert_idx, \ zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
num_experts, topk, prob_m, prob_n, prob_k, tot_m, \ num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
locks, replicate_input, apply_weights, m_block, \ replicate_input, apply_weights, m_block, max_par, \
max_par, exec_cfg.max_m_blocks)) { \ exec_cfg.max_m_blocks)) { \
} }
void marlin_mm_moe(const void* A, const void* B, void* C, void marlin_mm_moe(const void* A, const void* B, void* C,
const void* sorted_ids, const void* topk_weights, const void* sorted_ids, const void* topk_weights,
const void* topk_ids, const void* s, const void* g_idx, const void* topk_ids, const void* s, void* zp,
const void* perm, void* a_tmp, void* expert_offsets, const void* g_idx, const void* perm, void* a_tmp,
int prob_m, int prob_n, int prob_k, void* workspace, void* expert_offsets, int prob_m, int prob_n, int prob_k,
vllm::ScalarType const& q_type, bool has_act_order, void* workspace, vllm::ScalarType const& q_type,
bool is_k_full, int num_groups, int group_size, bool has_act_order, bool is_k_full, bool has_zp,
int num_experts, int topk, int moe_block_size, int dev, int num_groups, int group_size, int num_experts, int topk,
cudaStream_t stream, int thread_k, int thread_n, int sms, int moe_block_size, int dev, cudaStream_t stream,
int max_par, bool replicate_input, bool apply_weights) { int thread_k, int thread_n, int sms, int max_par,
bool replicate_input, bool apply_weights) {
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]"); ", ", prob_n, ", ", prob_k, "]");
...@@ -433,11 +439,9 @@ void marlin_mm_moe(const void* A, const void* B, void* C, ...@@ -433,11 +439,9 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
int4* C_ptr = (int4*)C; int4* C_ptr = (int4*)C;
const float* topk_weights_ptr = (const float*)topk_weights; const float* topk_weights_ptr = (const float*)topk_weights;
const int* sorted_ids_ptr = (const int*)sorted_ids; const int* sorted_ids_ptr = (const int*)sorted_ids;
const int4* s_ptr = const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx;
(const int4*)s + const int4* zp_ptr =
(((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) * (const int4*)zp + num_groups * prob_n / (pack_factor * 4) * expert_idx;
prob_n / 8) *
expert_idx;
const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
const int* perm_ptr = (const int*)perm + prob_k * expert_idx; const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
int* locks = (int*)workspace; int* locks = (int*)workspace;
...@@ -458,6 +462,7 @@ void marlin_mm_moe(const void* A, const void* B, void* C, ...@@ -458,6 +462,7 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
} }
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8) CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8)
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128) CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128)
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4)
else { else {
TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
str(prob_n) + ", " + str(prob_k) + "]" + str(prob_n) + ", " + str(prob_k) + "]" +
...@@ -477,15 +482,24 @@ torch::Tensor marlin_gemm_moe( ...@@ -477,15 +482,24 @@ torch::Tensor marlin_gemm_moe(
const torch::Tensor& a, const torch::Tensor& b_q_weights, const torch::Tensor& a, const torch::Tensor& b_q_weights,
const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
const torch::Tensor& g_idx, const torch::Tensor& perm, torch::Tensor& b_zeros, const torch::Tensor& g_idx,
torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, const torch::Tensor& perm, torch::Tensor& workspace,
int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n,
int64_t num_experts, int64_t topk, int64_t moe_block_size, int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
bool replicate_input, bool apply_weights) { int64_t moe_block_size, bool replicate_input, bool apply_weights) {
TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
"b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str()); bool has_zp = b_zeros.size(1) != 0;
if (has_zp) {
TORCH_CHECK(
b_q_type == vllm::kU4,
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
} else {
TORCH_CHECK(
b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
"b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type.str());
}
int pack_factor = 32 / b_q_type->size_bits(); int pack_factor = 32 / b_q_type.size_bits();
int max_par = 4; int max_par = 4;
...@@ -521,6 +535,9 @@ torch::Tensor marlin_gemm_moe( ...@@ -521,6 +535,9 @@ torch::Tensor marlin_gemm_moe(
" is not size_n = ", size_n); " is not size_n = ", size_n);
num_groups = b_scales.size(1); num_groups = b_scales.size(1);
TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order),
"if is_k_full is false, has_act_order must be true");
if (has_act_order) { if (has_act_order) {
if (is_k_full) { if (is_k_full) {
TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
...@@ -542,13 +559,30 @@ torch::Tensor marlin_gemm_moe( ...@@ -542,13 +559,30 @@ torch::Tensor marlin_gemm_moe(
} }
} }
// Verify b_zeros
if (has_zp) {
int rank = b_zeros.sizes().size();
TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
TORCH_CHECK(b_zeros.size(1) == num_groups,
"b_zeros dim 1 = ", b_zeros.size(1),
" is not num_groups = ", num_groups);
TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
"b_zeros dim 2 = ", b_zeros.size(2),
" is not size_n / pack_factor = ", size_n / pack_factor);
}
marlin_moe::marlin_mm_moe( marlin_moe::marlin_mm_moe(
a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(),
topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(),
*b_q_type, has_act_order, is_k_full, num_groups, group_size, num_experts, b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size,
topk, moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, num_experts, topk, moe_block_size, dev,
thread_n, sms, max_par, replicate_input, apply_weights); at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par,
replicate_input, apply_weights);
return c; return c;
} }
// Registers marlin_gemm_moe as the implementation of the "marlin_gemm_moe"
// op under the CUDA key passed to TORCH_LIBRARY_IMPL_EXPAND. The op schema
// itself is not defined in this block.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("marlin_gemm_moe", &marlin_gemm_moe);
}
#pragma once
#include <torch/all.h>
#include "core/scalar_type.hpp"
// Declaration of the Marlin MoE GEMM entry point; the definition is not part
// of this header.
//
// Takes the activations `a`, the quantized weights `b_q_weights` with their
// scales `b_scales`, the routing tensors (`sorted_ids`, `topk_weights`,
// `topk_ids`), the act-order tensors (`g_idx`, `perm`), a mutable `workspace`,
// the quantized weight type `b_q_type`, the problem sizes
// (`size_m`, `size_n`, `size_k`) and the MoE launch parameters, and returns
// the result as a torch::Tensor.
// NOTE(review): tensor shapes and dtypes are not visible from this
// declaration — confirm against the definition before relying on them.
torch::Tensor marlin_gemm_moe(
const torch::Tensor& a, const torch::Tensor& b_q_weights,
const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
const torch::Tensor& g_idx, const torch::Tensor& perm,
torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full,
int64_t num_experts, int64_t topk, int64_t moe_block_size,
bool replicate_input, bool apply_weights);
#include <torch/all.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <THC/THCAtomics.cuh> #include <THC/THCAtomics.cuh>
#include "cuda_compat.h" #include "../cuda_compat.h"
#include "dispatch_utils.h" #include "../dispatch_utils.h"
#define CEILDIV(x, y) (((x) + (y) - 1) / (y)) #define CEILDIV(x, y) (((x) + (y) - 1) / (y))
#define MAX_SHARED_MEM_SIZE 64 * 1024 #define MAX_SHARED_MEM_SIZE 64 * 1024
namespace vllm { namespace vllm {
namespace moe {
namespace { namespace {
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
...@@ -37,14 +39,14 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, ...@@ -37,14 +39,14 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
int32_t* tokens_cnts = nullptr; int32_t* tokens_cnts = nullptr;
int32_t* cumsum = nullptr; int32_t* cumsum = nullptr;
if (experts_num_exceed_limit) { if (experts_num_exceed_limit) {
// 2d tensor with shape (num_experts + 1, num_experts) // 2d tensor with shape (blockDim.x + 1, num_experts)
tokens_cnts = global_tokens_cnts_ptr; tokens_cnts = global_tokens_cnts_ptr;
// 1d tensor with shape (num_experts + 1) // 1d tensor with shape (num_experts + 1)
cumsum = shared_mem; cumsum = shared_mem;
} else { } else {
tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) tokens_cnts = shared_mem; // 2d tensor with shape (blockDim.x + 1, num_experts)
cumsum = shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1) cumsum = shared_mem + (blockDim.x + 1) * num_experts; // 1d tensor with shape (num_experts + 1)
} }
for (int i = 0; i < num_experts; ++i) { for (int i = 0; i < num_experts; ++i) {
...@@ -63,10 +65,12 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, ...@@ -63,10 +65,12 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
__syncthreads(); __syncthreads();
// For each expert we accumulate the token counts from the different threads. // For each expert we accumulate the token counts from the different threads.
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; if (threadIdx.x < num_experts) {
for (int i = 1; i <= blockDim.x; ++i) { tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
tokens_cnts[index(num_experts, i, threadIdx.x)] += for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; tokens_cnts[index(num_experts, i, threadIdx.x)] +=
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
}
} }
__syncthreads(); __syncthreads();
...@@ -89,9 +93,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, ...@@ -89,9 +93,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
* For each expert, each thread processes the tokens of the corresponding * For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block. * blocks and stores the corresponding expert_id for each block.
*/ */
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; if (threadIdx.x < num_experts) {
i += block_size) { for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
expert_ids[i / block_size] = threadIdx.x; i += block_size) {
expert_ids[i / block_size] = threadIdx.x;
}
} }
/** /**
...@@ -116,6 +122,24 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, ...@@ -116,6 +122,24 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
} }
} }
// Reduces the TOPK rows that belong to one token into a single output row.
//
// Launch layout: blockIdx.x selects the token; the threads of the block walk
// the `d` columns in steps of blockDim.x. No shared memory is used.
//
//   out:   [..., d]        one row written per block
//   input: [..., topk, d]  TOPK consecutive rows of length d read per block
//   d:     row length
template <typename scalar_t, int TOPK>
__global__ void moe_sum_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., topk, d]
    const int d) {
  const int64_t token_idx = blockIdx.x;
  // Base pointers of this token's TOPK input rows and of its output row.
  const scalar_t* token_in = input + token_idx * TOPK * d;
  scalar_t* token_out = out + token_idx * d;

  for (int64_t col = threadIdx.x; col < d; col += blockDim.x) {
    // Accumulate in scalar_t, adding the rows in increasing k order.
    scalar_t acc = 0.0;
#pragma unroll
    for (int k = 0; k < TOPK; ++k) {
      acc += VLLM_LDG(&token_in[k * d + col]);
    }
    token_out[col] = acc;
  }
}
} // namespace moe
} // namespace vllm } // namespace vllm
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
...@@ -125,7 +149,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, ...@@ -125,7 +149,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_INTEGRAL_TYPES( VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
int32_t shared_mem_normal = ((num_experts + 1) * num_experts + (num_experts + 1)) * const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
int32_t shared_mem_normal = ((num_thread + 1) * num_experts + (num_experts + 1)) *
sizeof(int32_t); sizeof(int32_t);
const bool experts_num_exceed_limit = shared_mem_normal > MAX_SHARED_MEM_SIZE; const bool experts_num_exceed_limit = shared_mem_normal > MAX_SHARED_MEM_SIZE;
...@@ -146,8 +171,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, ...@@ -146,8 +171,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
kernel<<<1, num_experts, shared_mem, stream>>>( kernel<<<1, num_experts, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(), topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(), experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), key_cache_ptrs_tensor.data_ptr<int32_t>(), num_experts, num_tokens_post_pad.data_ptr<int32_t>(), key_cache_ptrs_tensor.data_ptr<int32_t>(), num_experts, block_size,
block_size, topk_ids.numel()); topk_ids.numel());
} else { } else {
// set dynamic shared mem // set dynamic shared mem
auto kernel = vllm::moe_align_block_size_kernel<scalar_t, false>; auto kernel = vllm::moe_align_block_size_kernel<scalar_t, false>;
...@@ -159,6 +184,48 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, ...@@ -159,6 +184,48 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
num_tokens_post_pad.data_ptr<int32_t>(), nullptr, num_experts, block_size, num_tokens_post_pad.data_ptr<int32_t>(), nullptr, num_experts, block_size,
topk_ids.numel()); topk_ids.numel());
} }
});
}
// Sums the per-expert partial results of every token into `output`.
//
//   input:  [num_tokens, topk, hidden_size]
//   output: [num_tokens, hidden_size]
//
// topk values of 2, 3 and 4 launch the templated vllm::moe::moe_sum_kernel
// with one block per token and min(hidden_size, 1024) threads per block on
// the current CUDA stream of `output`'s device. Any other topk falls back to
// at::sum_out over dimension 1.
//
// An empty `output` is a no-op: there is nothing to write, and returning
// early avoids both a division by zero (hidden_size == 0) and a kernel launch
// with a zero-sized grid (num_tokens == 0), which is an invalid launch
// configuration.
void moe_sum(torch::Tensor& input,   // [num_tokens, topk, hidden_size]
             torch::Tensor& output)  // [num_tokens, hidden_size]
{
  if (output.numel() == 0) {
    return;
  }

  const int hidden_size = input.size(-1);
  // A non-empty output with a zero-width input is a shape mismatch; fail
  // loudly instead of dividing by zero below.
  TORCH_CHECK(hidden_size > 0,
              "moe_sum: input hidden_size must be > 0 for a non-empty output");
  const int num_tokens = output.numel() / hidden_size;
  const int topk = input.size(1);

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  switch (topk) {
    case 2:
      VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
        vllm::moe::moe_sum_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
            hidden_size);
      });
      break;

    case 3:
      VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
        vllm::moe::moe_sum_kernel<scalar_t, 3><<<grid, block, 0, stream>>>(
            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
            hidden_size);
      });
      break;

    case 4:
      VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
        vllm::moe::moe_sum_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
            hidden_size);
      });
      break;

    default:
      // No specialised kernel for this topk; use the generic reduction.
      at::sum_out(output, input, 1);
      break;
  }
}
...@@ -5,3 +5,10 @@ ...@@ -5,3 +5,10 @@
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices, torch::Tensor& token_expert_indices,
torch::Tensor& gating_output); torch::Tensor& gating_output);
// Computes the MoE result by summing the partial results of all selected
// experts: `input` is [num_tokens, topk, hidden_size] and the sum over the
// topk dimension is written to `output`, which is [num_tokens, hidden_size].
void moe_sum(torch::Tensor& input, torch::Tensor& output);
// Aligns the number of tokens processed by each expert so that it is
// divisible by `block_size`. The op schema marks `sorted_token_ids`,
// `experts_ids` and `num_tokens_post_pad` as mutated (output) tensors;
// `topk_ids` is the input.
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);
#include "core/registration.h" #include "core/registration.h"
#include "moe_ops.h" #include "moe_ops.h"
#include "marlin_moe_ops.h"
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Apply topk softmax to the gating outputs. // Apply topk softmax to the gating outputs.
...@@ -9,16 +8,31 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -9,16 +8,31 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"token_expert_indices, Tensor gating_output) -> ()"); "token_expert_indices, Tensor gating_output) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax); m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
// Calculate the result of moe by summing up the partial results
// from all selected experts.
m.def("moe_sum(Tensor! input, Tensor output) -> ()");
m.impl("moe_sum", torch::kCUDA, &moe_sum);
// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size.
m.def(
"moe_align_block_size(Tensor topk_ids, int num_experts,"
" int block_size, Tensor! sorted_token_ids,"
" Tensor! experts_ids,"
" Tensor! num_tokens_post_pad) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
#ifndef USE_ROCM #ifndef USE_ROCM
m.def( m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"g_idx, Tensor! perm, Tensor! workspace, " "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, " "int b_q_type, SymInt size_m, "
"int size_n, int size_k, bool is_k_full, int num_experts, int topk, " "SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
"topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)" "int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor"); " -> Tensor");
m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); // conditionally compiled so impl registration is in source file
#endif #endif
} }
......
...@@ -5,6 +5,30 @@ ...@@ -5,6 +5,30 @@
#include "core/scalar_type.hpp" #include "core/scalar_type.hpp"
#include <vector>
// Returns a new tensor built with torch::from_blob over `tensor`'s data
// pointer, using the same sizes, strides and options (dtype, device).
//
// Throws std::runtime_error if `tensor` is not a CUDA tensor.
// NOTE(review): per torch::from_blob semantics the result is presumably a
// non-owning alias of `tensor`'s memory, so `tensor` must outlive it —
// confirm against the from_blob documentation.
torch::Tensor weak_ref_tensor(torch::Tensor& tensor) {
  // Only CUDA tensors are supported.
  if (!tensor.is_cuda()) {
    throw std::runtime_error("Tensor must be on CUDA device");
  }

  // Wrap the existing storage with the source tensor's own layout and
  // options; from_blob stores the sizes/strides in the new tensor.
  return torch::from_blob(tensor.data_ptr(), tensor.sizes(), tensor.strides(),
                          tensor.options());
}
void paged_attention_v1( void paged_attention_v1(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
...@@ -158,6 +182,24 @@ void rms_norm_opt(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weigh ...@@ -158,6 +182,24 @@ void rms_norm_opt(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weigh
void fused_add_rms_norm_opt(torch::Tensor& input, torch::Tensor& residual, void fused_add_rms_norm_opt(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon); torch::Tensor& weight, double epsilon);
// void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& weight, torch::Tensor& scale,
// double epsilon);
// void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
// torch::Tensor& input,
// torch::Tensor& residual,
// torch::Tensor& weight,
// torch::Tensor& scale, double epsilon);
// Declaration only; the implementation is not part of this chunk.
// `input` and `weight` are read-only (const references); `out` and `scales`
// are taken by mutable reference — presumably the outputs, confirm against
// the implementation. `scale_ub` and `residual` are optional tensors.
void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor const& weight,
torch::Tensor& scales,
double const epsilon,
std::optional<torch::Tensor> scale_ub,
std::optional<torch::Tensor> residual);
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int64_t head_size, torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox); torch::Tensor& cos_sin_cache, bool is_neox);
...@@ -187,6 +229,9 @@ void gelu_and_mul_opt(torch::Tensor& out, torch::Tensor& input); ...@@ -187,6 +229,9 @@ void gelu_and_mul_opt(torch::Tensor& out, torch::Tensor& input);
void gelu_tanh_and_mul_opt(torch::Tensor& out, torch::Tensor& input); void gelu_tanh_and_mul_opt(torch::Tensor& out, torch::Tensor& input);
// Declaration only; the implementation is not part of this chunk.
// NOTE(review): by analogy with the neighbouring *_and_mul activations this
// presumably writes the gated activation of `input` into `out`, with
// `threshold` as the FATReLU cut-off — confirm against the kernel.
void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input,
double threshold);
void gelu_new(torch::Tensor& out, torch::Tensor& input); void gelu_new(torch::Tensor& out, torch::Tensor& input);
void gelu_fast(torch::Tensor& out, torch::Tensor& input); void gelu_fast(torch::Tensor& out, torch::Tensor& input);
...@@ -231,62 +276,8 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel, ...@@ -231,62 +276,8 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor _zeros, int64_t split_k_iters, torch::Tensor _zeros, int64_t split_k_iters,
int64_t thx, int64_t thy); int64_t thx, int64_t thy);
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t size_m, int64_t size_n, int64_t size_k);
namespace machete {
std::vector<std::string> supported_schedules(
vllm::ScalarTypeTorchPtr const& btype);
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
vllm::ScalarTypeTorchPtr const& btype,
c10::optional<torch::Tensor> const& scales,
c10::optional<torch::Tensor> const& zeros,
c10::optional<int64_t> group_size,
c10::optional<torch::Tensor> const& C,
c10::optional<double> alpha, c10::optional<double> beta,
c10::optional<std::string> schedule);
torch::Tensor prepack_B(torch::Tensor const& B,
vllm::ScalarTypeTorchPtr const& btype);
}; // namespace machete
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm); torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
#endif
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
torch::Tensor& b_scales,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n,
int64_t size_k);
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& b_zeros,
torch::Tensor& g_idx, torch::Tensor& perm,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp,
bool use_fp32_reduce);
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits);
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
torch::Tensor& perm, c10::SymInt size_k,
c10::SymInt size_n, int64_t num_bits);
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits);
torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
c10::SymInt size_k, c10::SymInt size_n,
int64_t num_bits);
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
int64_t n); int64_t n);
...@@ -297,11 +288,7 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, ...@@ -297,11 +288,7 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
int64_t row); int64_t row);
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, #ifndef USE_ROCM
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k);
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
...@@ -316,14 +303,6 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, ...@@ -316,14 +303,6 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& azp_adj, torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp, c10::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias); c10::optional<torch::Tensor> const& bias);
torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
torch::Tensor const& b_q_weight,
torch::Tensor const& s_tok,
torch::Tensor const& s_ch,
torch::Tensor const& s_group,
torch::Tensor& workspace, int64_t size_m,
int64_t size_n, int64_t size_k);
#endif #endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
...@@ -351,48 +330,46 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, ...@@ -351,48 +330,46 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale, // torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
// c10::optional<torch::Tensor> const& scale_ub); // c10::optional<torch::Tensor> const& scale_ub);
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
int64_t block_size, torch::Tensor sorted_token_ids, const torch::Tensor& A, const torch::Tensor& B,
torch::Tensor experts_ids, const torch::Tensor& C,
torch::Tensor num_tokens_post_pad); const c10::optional<torch::Tensor>& D_,
const c10::optional<torch::Tensor>& z_,
std::vector<torch::Tensor> selective_scan_fwd( const c10::optional<torch::Tensor>& delta_bias_,
const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A, bool delta_softplus,
const torch::Tensor& B, const torch::Tensor& C, const c10::optional<torch::Tensor>& query_start_loc,
const c10::optional<torch::Tensor>& D_, const c10::optional<torch::Tensor>& cache_indices,
const c10::optional<torch::Tensor>& z_, const c10::optional<torch::Tensor>& has_initial_state,
const c10::optional<torch::Tensor>& delta_bias_, bool delta_softplus, const torch::Tensor& ssm_states, int64_t pad_slot_id);
const c10::optional<torch::Tensor>& index_,
const c10::optional<torch::Tensor>& x); void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
const at::Tensor& weight,
at::Tensor causal_conv1d_update( const c10::optional<at::Tensor>& bias_,
const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight, bool silu_activation,
const c10::optional<at::Tensor>& bias, bool silu_activation, const c10::optional<at::Tensor>& cache_seqlens_,
const c10::optional<at::Tensor>& conv_state_indices); const c10::optional<at::Tensor>& conv_state_indices_,
int64_t pad_slot_id);
at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_, void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
const c10::optional<at::Tensor>& seq_idx_, const c10::optional<at::Tensor>& bias_,
const c10::optional<at::Tensor>& initial_states_, const c10::optional<at::Tensor>& conv_states,
const c10::optional<at::Tensor>& final_states_out_, const c10::optional<at::Tensor>& query_start_loc,
bool silu_activation); const c10::optional<at::Tensor>& cache_indices,
const c10::optional<at::Tensor>& has_initial_state,
bool silu_activation, int64_t pad_slot_id);
#ifndef USE_ROCM #ifndef USE_ROCM
using fptr_t = int64_t; using fptr_t = int64_t;
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
const std::vector<std::string>& handles, torch::Tensor& rank_data, int64_t rank, bool full_nvlink);
const std::vector<int64_t>& offsets, int64_t rank, void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
bool full_nvlink); fptr_t reg_buffer, int64_t reg_buffer_sz_bytes);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor& out);
void dispose(fptr_t _fa); void dispose(fptr_t _fa);
int64_t meta_size(); int64_t meta_size();
void register_buffer(fptr_t _fa, torch::Tensor& t, void register_buffer(fptr_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
const std::vector<std::string>& handles, std::tuple<std::vector<int64_t>, std::vector<int64_t>>
const std::vector<int64_t>& offsets); get_graph_buffer_ipc_meta(fptr_t _fa);
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta( void register_graph_buffers(fptr_t _fa,
fptr_t _fa); const std::vector<std::vector<int64_t>>& handles,
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets); const std::vector<std::vector<int64_t>>& offsets);
#endif #endif
...@@ -107,8 +107,41 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { ...@@ -107,8 +107,41 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
return (T)(0.5f * f * (1.0f + ::tanhf(inner))); return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
} }
template <typename T>
__device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) {
const float f = (float)x;
return (T)(f > threshold ? f : 0.0f);
}
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
__global__ void act_and_mul_kernel_with_param(
scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
const float param) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = ACT_FN(x, param) * y;
}
} // namespace vllm } // namespace vllm
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel_with_param", [&] { \
vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d, \
PARAM); \
});
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ #define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \ int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \ int64_t num_tokens = input.numel() / input.size(-1); \
...@@ -163,4 +196,10 @@ void gelu_tanh_and_mul_opt(torch::Tensor& out, // [..., d] ...@@ -163,4 +196,10 @@ void gelu_tanh_and_mul_opt(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d] torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
}
void fatrelu_and_mul(torch::Tensor& out, // [..., d],
torch::Tensor& input, // [..., 2 * d]
double threshold) {
LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
} }
\ No newline at end of file
#include <torch/all.h> #include "type_convert.cuh"
#include <ATen/cuda/CUDAContext.h> #include "dispatch_utils.h"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh> #include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/cuda/CUDAMathCompat.h> #include <c10/cuda/CUDAMathCompat.h>
#include <ATen/AccumulateType.h> #include <ATen/AccumulateType.h>
#include <THC/THCDeviceUtils.cuh> #include <THC/THCDeviceUtils.cuh>
#include "../dispatch_utils.h"
#ifndef USE_ROCM #ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cub/util_type.cuh>
#include <cub/cub.cuh> #include <cub/cub.cuh>
#else #else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp> #include <hipcub/hipcub.hpp>
using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
#endif #endif
namespace vllm { namespace vllm {
// TODO(woosuk): Further optimize this kernel. // TODO(woosuk): Further optimize this kernel.
...@@ -55,154 +48,6 @@ __global__ void rms_norm_kernel( ...@@ -55,154 +48,6 @@ __global__ void rms_norm_kernel(
} }
} }
/* Converter structs for the conversion from torch types to HIP/CUDA types,
and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion
operators/constructors are not consistently implemented by HIP/CUDA, so
a generic conversion via type casts cannot be implemented.
Each struct should have the member static constexpr bool `exists`:
If false, the optimized kernel is not used for the corresponding torch type.
If true, the struct should be fully defined as shown in the examples below.
*/
template <typename torch_type>
struct _typeConvert {
static constexpr bool exists = false;
};
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template <>
struct _typeConvert<c10::Half> {
static constexpr bool exists = true;
using hip_type = __half;
using packed_hip_type = __half2;
__device__ static inline float convert(hip_type x) { return __half2float(x); }
__device__ static inline float2 convert(packed_hip_type x) {
return __half22float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2half_rn(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22half2_rn(x);
}
};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template <>
struct _typeConvert<c10::BFloat16> {
static constexpr bool exists = true;
using hip_type = __nv_bfloat16;
using packed_hip_type = __nv_bfloat162;
__device__ static inline float convert(hip_type x) {
return __bfloat162float(x);
}
__device__ static inline float2 convert(packed_hip_type x) {
return __bfloat1622float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2bfloat16(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22bfloat162_rn(x);
}
};
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
for appropriate specializations of fused_add_rms_norm_kernel.
Only functions that are necessary in that kernel are implemented.
Alignment to 16 bytes is required to use 128-bit global memory ops.
*/
template <typename scalar_t, int width>
struct alignas(16) _f16Vec {
/* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */
static_assert(width > 0 && (width & (width - 1)) == 0,
"Width is not a positive power of 2!");
using Converter = _typeConvert<scalar_t>;
using T1 = typename Converter::hip_type;
using T2 = typename Converter::packed_hip_type;
T1 data[width];
__device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i + 1]};
temp += T2{other.data[i], other.data[i + 1]};
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) data[i] += other.data[i];
}
return *this;
}
__device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i + 1]};
temp *= T2{other.data[i], other.data[i + 1]};
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) data[i] *= other.data[i];
}
return *this;
}
__device__ _f16Vec& operator*=(const float scale) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
temp_f.x *= scale;
temp_f.y *= scale;
T2 temp = Converter::convert(temp_f);
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) {
float temp = Converter::convert(data[i]) * scale;
data[i] = Converter::convert(temp);
}
}
return *this;
}
__device__ float sum_squares() const {
float result = 0.0f;
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i + 1]});
result += z.x * z.x + z.y * z.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) {
float x = Converter::convert(data[i]);
result += x * x;
}
}
return result;
}
};
/* Function specialization in the case of FP16/BF16 tensors. /* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are Additional optimizations we can make in this case are
......
...@@ -17,6 +17,17 @@ __global__ void advance_step_flashattn_kernel( ...@@ -17,6 +17,17 @@ __global__ void advance_step_flashattn_kernel(
long const* sampled_token_ids_ptr, long* input_positions_ptr, long const* sampled_token_ids_ptr, long* input_positions_ptr,
int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr, int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
int64_t const block_tables_stride) { int64_t const block_tables_stride) {
int const n_pad = num_seqs - num_queries;
if (n_pad && blockIdx.x == 0) {
// Handle cuda graph padding
int const offset = num_queries;
for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
input_tokens_ptr[offset + i] = 0;
input_positions_ptr[offset + i] = 0;
slot_mapping_ptr[offset + i] = -1;
}
}
int num_query_blocks = div_ceil(num_queries, num_threads); int num_query_blocks = div_ceil(num_queries, num_threads);
if (blockIdx.x >= num_query_blocks) { if (blockIdx.x >= num_query_blocks) {
...@@ -52,7 +63,7 @@ __global__ void advance_step_flashattn_kernel( ...@@ -52,7 +63,7 @@ __global__ void advance_step_flashattn_kernel(
slot_mapping_ptr[cur_query_id] = slot_num; slot_mapping_ptr[cur_query_id] = slot_num;
} }
inline void verify_tensor(std::string const& name, torch::Tensor& t, inline void verify_tensor(std::string const& name, torch::Tensor const& t,
int64_t const size_0, int64_t const size_1, int64_t const size_0, int64_t const size_1,
c10::ScalarType const type) { c10::ScalarType const type) {
bool size_0_cond = true; bool size_0_cond = true;
...@@ -77,6 +88,7 @@ inline void verify_tensor(std::string const& name, torch::Tensor& t, ...@@ -77,6 +88,7 @@ inline void verify_tensor(std::string const& name, torch::Tensor& t,
} }
} }
/// each thread processes a block per query
__global__ void advance_step_flashinfer_kernel( __global__ void advance_step_flashinfer_kernel(
int num_threads, int num_seqs, int num_queries, int block_size, int num_threads, int num_seqs, int num_queries, int block_size,
long* input_tokens_ptr, long const* sampled_token_ids_ptr, long* input_tokens_ptr, long const* sampled_token_ids_ptr,
...@@ -123,8 +135,10 @@ __global__ void advance_step_flashinfer_indptr_kernel( ...@@ -123,8 +135,10 @@ __global__ void advance_step_flashinfer_indptr_kernel(
int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr, int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr,
int* block_table_bound_ptr) { int* block_table_bound_ptr) {
int idx = blockIdx.x * num_threads + threadIdx.x; int idx = blockIdx.x * num_threads + threadIdx.x;
// Update paged_kv_indptr // Update paged_kv_indptr
if (idx == 0) {
paged_kv_indptr_ptr[idx] = 0;
}
if (idx < num_queries) { if (idx < num_queries) {
int sum = 0; int sum = 0;
for (int i = 0; i <= idx; ++i) { for (int i = 0; i <= idx; ++i) {
...@@ -135,20 +149,33 @@ __global__ void advance_step_flashinfer_indptr_kernel( ...@@ -135,20 +149,33 @@ __global__ void advance_step_flashinfer_indptr_kernel(
} }
__global__ void advance_step_flashinfer_indices_kernel( __global__ void advance_step_flashinfer_indices_kernel(
int num_threads, int num_seqs, int num_queries, int const* block_tables_ptr, int num_seqs, int num_queries, int const* block_tables_ptr,
int64_t const block_tables_stride, int* paged_kv_indices_ptr, int64_t const max_num_blocks_per_seq, int* paged_kv_indices_ptr,
int* paged_kv_indptr_ptr, int* block_table_bound_ptr) { int* paged_kv_indptr_ptr, int* block_table_bound_ptr) {
int idx = blockIdx.x * num_threads + threadIdx.x; // note: max_num_blocks_per_seq = block_tables.stride(0)
int row = idx / block_tables_stride; int tid = blockIdx.x * blockDim.x + threadIdx.x;
int col = idx % block_tables_stride;
// when cuda graphs are enabled, paged_kv_indptr tensor
if (row < num_queries && col < block_table_bound_ptr[row]) { // has to be updated for the padded queries
paged_kv_indices_ptr[paged_kv_indptr_ptr[row] + col] = // tid represents a query# for paged_kv_indptr tensor
block_tables_ptr[row * block_tables_stride + col]; if (num_queries < tid && tid <= num_seqs) {
paged_kv_indptr_ptr[tid] = paged_kv_indptr_ptr[num_queries];
} }
// if cudagraph, fill padded seqs with the last valid seq's indptr
if (num_queries < row && row <= num_seqs) { // each thread processes a block_ptr in block_tables
paged_kv_indptr_ptr[row] = paged_kv_indptr_ptr[num_queries]; // block_tables shape: [num_queries, max_num_blocks_per_seq]
// paged_kv_indices is flattened block_tables.
for (int idx = tid; idx < (num_seqs * max_num_blocks_per_seq);
idx += (gridDim.x * blockDim.x)) {
// block_tables-row = paged_kv_indptr[queryNum]
int queryNum = idx / max_num_blocks_per_seq;
int col = idx % max_num_blocks_per_seq;
if (queryNum < num_queries && col < block_table_bound_ptr[queryNum]) {
int indices_arr_idx = paged_kv_indptr_ptr[queryNum] + col;
int block_tables_idx = queryNum * max_num_blocks_per_seq + col;
paged_kv_indices_ptr[indices_arr_idx] =
block_tables_ptr[block_tables_idx];
}
} }
} }
...@@ -211,7 +238,7 @@ void advance_step_flashinfer( ...@@ -211,7 +238,7 @@ void advance_step_flashinfer(
printf(" num_seqs = %d\n", num_seqs); printf(" num_seqs = %d\n", num_seqs);
printf(" num_queries = %d\n", num_queries); printf(" num_queries = %d\n", num_queries);
printf(" block_size = %d\n", block_size); printf(" block_size = %d\n", block_size);
printf(" block_tables.stride(0) = %d\n", block_tables.stride(0)); printf(" block_tables.stride(0) = %zu\n", block_tables.stride(0));
} }
// Verify all tensors // Verify all tensors
verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong); verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
...@@ -236,22 +263,16 @@ void advance_step_flashinfer( ...@@ -236,22 +263,16 @@ void advance_step_flashinfer(
int threads; int threads;
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev); cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev);
if (logging) {
printf("launching kernel with %d blocks\n", blocks);
}
// TODO(will): support arbitrary block_tables stride int block_tables_stride = block_tables.stride(0);
if ((blocks * threads) / block_tables.stride(0) < num_queries) { TORCH_CHECK((blocks * threads > num_queries),
TORCH_CHECK(false, "multi-step: not enough threads to map to num_queries = ",
"multi-step: not enough threads to map block_table to" num_queries, " block_tables.stride(0) = ", block_tables.stride(0),
"FlashInfer's paged_kv_indices on GPU. Try reducing the number " " blocks = ", blocks, " max_threads = ", threads);
"of seqs,", if (logging) {
" increasing the block size or take smaller steps.", printf("launching kernels with %d blocks and %d threads\n", blocks,
" num_queries = ", num_queries, threads);
" block_tables.stride(0) = ", block_tables.stride(0),
" blocks = ", blocks, " max_threads = ", threads);
} }
advance_step_flashinfer_kernel<<<blocks, threads, 0, stream>>>( advance_step_flashinfer_kernel<<<blocks, threads, 0, stream>>>(
threads, num_seqs, num_queries, block_size, threads, num_seqs, num_queries, block_size,
reinterpret_cast<long*>(input_tokens.data_ptr()), reinterpret_cast<long*>(input_tokens.data_ptr()),
...@@ -270,7 +291,7 @@ void advance_step_flashinfer( ...@@ -270,7 +291,7 @@ void advance_step_flashinfer(
reinterpret_cast<int*>(block_table_bound.data_ptr())); reinterpret_cast<int*>(block_table_bound.data_ptr()));
advance_step_flashinfer_indices_kernel<<<blocks, threads, 0, stream>>>( advance_step_flashinfer_indices_kernel<<<blocks, threads, 0, stream>>>(
threads, num_seqs, num_queries, num_seqs, num_queries,
reinterpret_cast<int const*>(block_tables.data_ptr()), reinterpret_cast<int const*>(block_tables.data_ptr()),
block_tables.stride(0), block_tables.stride(0),
reinterpret_cast<int*>(paged_kv_indices.data_ptr()), reinterpret_cast<int*>(paged_kv_indices.data_ptr()),
...@@ -303,4 +324,4 @@ void advance_step_flashinfer( ...@@ -303,4 +324,4 @@ void advance_step_flashinfer(
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices, input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices,
paged_kv_indptr, paged_kv_last_page_len, block_table_bound); paged_kv_indptr, paged_kv_last_page_len, block_table_bound);
} }
\ No newline at end of file
...@@ -96,12 +96,15 @@ __global__ void static_scaled_int8_quant_kernel( ...@@ -96,12 +96,15 @@ __global__ void static_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out, scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type const* scale_ptr, const int hidden_size) { scale_type const* scale_ptr, const int hidden_size) {
int const tid = threadIdx.x; int const tid = threadIdx.x;
int const token_idx = blockIdx.x; int64_t const token_idx = blockIdx.x;
scale_type const scale = *scale_ptr; scale_type const scale = *scale_ptr;
// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;
for (int i = tid; i < hidden_size; i += blockDim.x) { for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] = float_to_int8_rn( out[i] = float_to_int8_rn(static_cast<float>(input[i]) / scale);
static_cast<float>(input[token_idx * hidden_size + i]) / scale);
} }
} }
...@@ -111,14 +114,18 @@ __global__ void static_scaled_int8_azp_quant_kernel( ...@@ -111,14 +114,18 @@ __global__ void static_scaled_int8_azp_quant_kernel(
scale_type const* scale_ptr, azp_type const* azp_ptr, scale_type const* scale_ptr, azp_type const* azp_ptr,
const int hidden_size) { const int hidden_size) {
int const tid = threadIdx.x; int const tid = threadIdx.x;
int const token_idx = blockIdx.x; int64_t const token_idx = blockIdx.x;
scale_type const scale = *scale_ptr; scale_type const scale = *scale_ptr;
azp_type const azp = *azp_ptr; azp_type const azp = *azp_ptr;
// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;
for (int i = tid; i < hidden_size; i += blockDim.x) { for (int i = tid; i < hidden_size; i += blockDim.x) {
auto const val = static_cast<float>(input[token_idx * hidden_size + i]); auto const val = static_cast<float>(input[i]);
auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp); auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
out[token_idx * hidden_size + i] = quant_val; out[i] = quant_val;
} }
} }
...@@ -127,12 +134,16 @@ __global__ void dynamic_scaled_int8_quant_kernel( ...@@ -127,12 +134,16 @@ __global__ void dynamic_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out, scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type* scale, const int hidden_size) { scale_type* scale, const int hidden_size) {
int const tid = threadIdx.x; int const tid = threadIdx.x;
int const token_idx = blockIdx.x; int64_t const token_idx = blockIdx.x;
float absmax_val = 0.0f; float absmax_val = 0.0f;
float const zero = 0.0f; float const zero = 0.0f;
// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;
for (int i = tid; i < hidden_size; i += blockDim.x) { for (int i = tid; i < hidden_size; i += blockDim.x) {
float val = static_cast<float>(input[token_idx * hidden_size + i]); float val = static_cast<float>(input[i]);
val = val > zero ? val : -val; val = val > zero ? val : -val;
absmax_val = val > absmax_val ? val : absmax_val; absmax_val = val > absmax_val ? val : absmax_val;
} }
...@@ -150,8 +161,7 @@ __global__ void dynamic_scaled_int8_quant_kernel( ...@@ -150,8 +161,7 @@ __global__ void dynamic_scaled_int8_quant_kernel(
float const tmp_scale = 127.0f / block_absmax_val; float const tmp_scale = 127.0f / block_absmax_val;
for (int i = tid; i < hidden_size; i += blockDim.x) { for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] = float_to_int8_rn( out[i] = float_to_int8_rn(static_cast<float>(input[i]) * tmp_scale);
static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
} }
} }
...@@ -159,13 +169,17 @@ template <typename scalar_t, typename scale_type, typename azp_type> ...@@ -159,13 +169,17 @@ template <typename scalar_t, typename scale_type, typename azp_type>
__global__ void dynamic_scaled_int8_azp_quant_kernel( __global__ void dynamic_scaled_int8_azp_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out, scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type* scale, azp_type* azp, const int hidden_size) { scale_type* scale, azp_type* azp, const int hidden_size) {
int const token_idx = blockIdx.x; int64_t const token_idx = blockIdx.x;
// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;
// Scan for the min and max value for this token // Scan for the min and max value for this token
float max_val = std::numeric_limits<float>::min(); float max_val = std::numeric_limits<float>::min();
float min_val = std::numeric_limits<float>::max(); float min_val = std::numeric_limits<float>::max();
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
auto val = static_cast<float>(input[token_idx * hidden_size + i]); auto val = static_cast<float>(input[i]);
max_val = std::max(max_val, val); max_val = std::max(max_val, val);
min_val = std::min(min_val, val); min_val = std::min(min_val, val);
} }
...@@ -200,10 +214,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( ...@@ -200,10 +214,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
// Quantize the values // Quantize the values
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
auto const val = static_cast<float>(input[token_idx * hidden_size + i]); auto const val = static_cast<float>(input[i]);
auto const quant_val = auto const quant_val =
int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val); int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
out[token_idx * hidden_size + i] = quant_val; out[i] = quant_val;
} }
} }
......
...@@ -8,6 +8,10 @@ ...@@ -8,6 +8,10 @@
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh" #include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh" #include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
using namespace vllm;
/* /*
This file defines quantized GEMM operations using the CUTLASS 2.x API, for This file defines quantized GEMM operations using the CUTLASS 2.x API, for
NVIDIA GPUs with SM versions prior to sm90 (Hopper). NVIDIA GPUs with SM versions prior to sm90 (Hopper).
...@@ -22,12 +26,11 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a, ...@@ -22,12 +26,11 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kInt8); TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) { if (out.dtype() == torch::kBFloat16) {
return vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); TORCH_CHECK(out.dtype() == torch::kFloat16);
return vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>( return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} }
} }
...@@ -42,10 +45,10 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a, ...@@ -42,10 +45,10 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
if (bias) { if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(), TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype()); "currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBias>( return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias); out, a, b, a_scales, b_scales, *bias);
} else { } else {
return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogue>( return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} }
} }
...@@ -61,10 +64,10 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a, ...@@ -61,10 +64,10 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (azp) { if (azp) {
return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBiasAzpToken>( return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias); out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else { } else {
return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBiasAzp>( return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias); out, a, b, a_scales, b_scales, azp_adj, bias);
} }
} }
...@@ -78,12 +81,11 @@ void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a, ...@@ -78,12 +81,11 @@ void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kInt8); TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) { if (out.dtype() == torch::kBFloat16) {
return vllm::cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); TORCH_CHECK(out.dtype() == torch::kFloat16);
return vllm::cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>( return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} }
} }
...@@ -98,10 +100,10 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a, ...@@ -98,10 +100,10 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
if (bias) { if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(), TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype()); "currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBias>( return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias); out, a, b, a_scales, b_scales, *bias);
} else { } else {
return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogue>( return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} }
} }
...@@ -117,10 +119,10 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a, ...@@ -117,10 +119,10 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (azp) { if (azp) {
return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBiasAzpToken>( return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias); out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else { } else {
return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBiasAzp>( return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias); out, a, b, a_scales, b_scales, azp_adj, bias);
} }
} }
...@@ -134,13 +136,12 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a, ...@@ -134,13 +136,12 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kInt8); TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) { if (out.dtype() == torch::kBFloat16) {
return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t, return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>( Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else { } else {
assert(out.dtype() == torch::kFloat16); assert(out.dtype() == torch::kFloat16);
return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} }
} else { } else {
...@@ -148,13 +149,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a, ...@@ -148,13 +149,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) { if (out.dtype() == torch::kBFloat16) {
return vllm::cutlass_gemm_sm89_fp8_dispatch< return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>( cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); TORCH_CHECK(out.dtype() == torch::kFloat16);
return vllm::cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t, return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>( cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} }
} }
...@@ -170,10 +171,10 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a, ...@@ -170,10 +171,10 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
if (bias) { if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(), TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype()); "currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBias>( return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias); out, a, b, a_scales, b_scales, *bias);
} else { } else {
return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogue>( return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} }
} }
...@@ -189,10 +190,10 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a, ...@@ -189,10 +190,10 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (azp) { if (azp) {
return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBiasAzpToken>( return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias); out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else { } else {
return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBiasAzp>( return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias); out, a, b, a_scales, b_scales, azp_adj, bias);
} }
} }
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" #include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "broadcast_load_epilogue_c2x.hpp"
#include "common.hpp" #include "common.hpp"
// clang-format on // clang-format on
...@@ -71,307 +70,6 @@ struct enable_sm89_to_sm90 : Kernel { ...@@ -71,307 +70,6 @@ struct enable_sm89_to_sm90 : Kernel {
#endif #endif
} }
}; };
/*
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
template <typename T>
using ColOrScalarLoad =
cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowOrScalarLoad =
cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
template <typename T>
using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
template <typename T>
using RowOrZeroLoad =
cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
return Arguments{data_ptr, tensor.numel() != 1};
} else {
// it would technically work but no use case as data_ptr is never nullptr
static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
return Arguments{data_ptr};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
return Arguments{data_ptr};
}
};
/*
 * Epilogue for a quantized GEMM in the spirit of torch._scaled_mm.
 *
 * A and B are both int8 or both fp8_e4m3. A may be quantized per-tensor or
 * per-row, B per-tensor or per-column, in any combination. Quantization must
 * be symmetric (zero point == 0), so the operation is
 *
 *   D = (a_scales * A) (b_scales * B)
 *
 * with the scales applied elementwise using numpy-style broadcasting.
 *
 * The epilogue is a two-level visitor tree: the accumulator is first combined
 * with the B scale, and that result is then combined with the A scale.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using Parent = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using AccFetch = typename Parent::Accum;

  // Scale loads: each is either a broadcast vector or a single scalar.
  using AScaleLoad = typename Parent::template ColOrScalarLoad<float>;
  using BScaleLoad = typename Parent::template RowOrScalarLoad<float>;

  // Inner node: multiplies the B scale with the accumulator.
  using MulByScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using ScaledByB =
      cutlass::epilogue::threadblock::Sm80EVT<MulByScaleB, BScaleLoad,
                                              AccFetch>;

  // Outer node: multiplies the A scale with the inner result, with ElementD
  // as the first element-type argument of the compute visitor.
  using MulByScaleA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<MulByScaleA, AScaleLoad,
                                              ScaledByB>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Packs the two scale tensors into the argument tree of EVTCompute.
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
    auto scale_a =
        Parent::template args_from_tensor<AScaleLoad, float>(a_scales);
    auto scale_b =
        Parent::template args_from_tensor<BScaleLoad, float>(b_scales);

    typename ScaledByB::Arguments inner{scale_b};
    return ArgumentType{scale_a, inner};
  }
};
/*
 * Same operation as ScaledEpilogue, with a bias added on top.
 *
 * The bias can double as the carrier of the per-tensor azp correction: the
 * activation zero point (azp) yields a correction term that is folded into
 * the bias before it reaches this epilogue.
 *
 * The bias tensor must be per-output channel. ScaleA and ScaleB can be
 * per-tensor or per-token/per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 protected:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD>;

  // Inner node: B scale multiplied with the accumulator.
  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  // Outer node: a multiply_add over (ScaleA, inner result, Bias).
  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0,
                                              Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Packs both scale tensors and the (mandatory) bias tensor into the
  // argument tree of EVTCompute.
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    auto scale_a_in =
        SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto scale_b_in =
        SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_in = SUPER::template args_from_tensor<Bias, ElementD>(bias);

    typename EVTCompute0::Arguments inner{scale_b_in};
    return ArgumentType{scale_a_in, inner, bias_in};
  }
};
/*
 * Epilogue with direct support for per-tensor azp in int32 form.
 *
 * Unlike the per-token variant, only an azp_adj term is taken, and it must
 * already be multiplied by the scalar azp: azp_adj is a 1D tensor of shape
 * (1,n) computed as azp * J @ B.
 *
 * A bias is also supported and remains per-channel; it may be absent.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzp
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using Parent = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using AccFetch = typename Parent::Accum;
  using AScaleLoad = typename Parent::template ColOrScalarLoad<float>;
  using BScaleLoad = typename Parent::template RowOrScalarLoad<float>;

  // Optional bias: a null pointer stands for "no bias".
  using BiasLoad = typename Parent::template RowOrZeroLoad<ElementD>;

  // The full AZP term, azp * J @ B, shape (1,n).
  using AzpTermLoad = typename Parent::template RowLoad<int32_t>;

  // float(accum - azp_adj); both operands are int32_t.
  using SubtractAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using AccMinusAzp =
      cutlass::epilogue::threadblock::Sm80EVT<SubtractAzp, AccFetch,
                                              AzpTermLoad>;

  // B scale multiplied with the azp-corrected accumulator.
  using MulByScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using ScaledByB =
      cutlass::epilogue::threadblock::Sm80EVT<MulByScaleB, BScaleLoad,
                                              AccMinusAzp>;

  // Root node: a multiply_add over (A scale, scaled result, bias).
  using ScaleAThenBias = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ScaleAThenBias, AScaleLoad,
                                              ScaledByB, BiasLoad>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Packs the scales, the pre-multiplied azp term and the optional bias into
  // the argument tree of EVTCompute.
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   c10::optional<torch::Tensor> const& bias) {
    auto scale_a =
        Parent::template args_from_tensor<AScaleLoad, float>(a_scales);
    auto scale_b =
        Parent::template args_from_tensor<BScaleLoad, float>(b_scales);
    auto bias_in = Parent::template args_from_tensor<BiasLoad, ElementD>(bias);
    auto azp_term =
        Parent::template args_from_tensor<AzpTermLoad, int32_t>(azp_adj);

    // The accumulator fetch takes no arguments, hence the empty braces.
    typename AccMinusAzp::Arguments corrected_acc{{}, azp_term};
    typename ScaledByB::Arguments scaled_acc{scale_b, corrected_acc};
    return ArgumentType{scale_a, scaled_acc, bias_in};
  }
};
/*
 * Epilogue with per-token azp support. The correction term is computed and
 * applied as a rank-1 update: materializing it would need O(m*n) space,
 * whereas this form needs only O(m+n).
 *
 * azp is a 1D tensor of shape (m,1) holding the unscaled zero point of each
 * row of A. azp_adj is a 1D tensor of shape (1,n), computed as J @ B.
 *
 * A bias is also supported and remains per-channel; it may be absent.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzpToken
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using Parent = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using AccFetch = typename Parent::Accum;
  using AScaleLoad = typename Parent::template ColOrScalarLoad<float>;
  using BScaleLoad = typename Parent::template RowOrScalarLoad<float>;

  // Optional bias: a null pointer stands for "no bias".
  using BiasLoad = typename Parent::template RowOrZeroLoad<ElementD>;

  // Per-token azp term, shape (m,1).
  using AzpLoad = typename Parent::template ColLoad<int32_t>;

  // AZP adjustment term, J @ B, shape (1,n).
  using AzpAdjLoad = typename Parent::template RowLoad<int32_t>;

  // azp * azp_adj
  using MulAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, int32_t, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using AzpProduct =
      cutlass::epilogue::threadblock::Sm80EVT<MulAzp, AzpLoad, AzpAdjLoad>;

  // float(accum - azp*azp_adj); all operands are int32_t.
  using SubtractAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using AccMinusAzp =
      cutlass::epilogue::threadblock::Sm80EVT<SubtractAzp, AccFetch,
                                              AzpProduct>;

  // B scale multiplied with the azp-corrected accumulator.
  using MulByScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using ScaledByB =
      cutlass::epilogue::threadblock::Sm80EVT<MulByScaleB, BScaleLoad,
                                              AccMinusAzp>;

  // Root node: a multiply_add over (A scale, scaled result, bias).
  using ScaleAThenBias = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ScaleAThenBias, AScaleLoad,
                                              ScaledByB, BiasLoad>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Packs the scales, both azp tensors and the optional bias into the
  // argument tree of EVTCompute.
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   torch::Tensor const& azp,
                                   c10::optional<torch::Tensor> const& bias) {
    auto scale_a =
        Parent::template args_from_tensor<AScaleLoad, float>(a_scales);
    auto scale_b =
        Parent::template args_from_tensor<BScaleLoad, float>(b_scales);
    auto bias_in = Parent::template args_from_tensor<BiasLoad, ElementD>(bias);
    auto azp_in = Parent::template args_from_tensor<AzpLoad, int32_t>(azp);
    auto azp_adj_in =
        Parent::template args_from_tensor<AzpAdjLoad, int32_t>(azp_adj);

    typename AzpProduct::Arguments azp_product{azp_in, azp_adj_in};
    // The accumulator fetch takes no arguments, hence the empty braces.
    typename AccMinusAzp::Arguments corrected_acc{{}, azp_product};
    typename ScaledByB::Arguments scaled_acc{scale_b, corrected_acc};
    return ArgumentType{scale_a, scaled_acc, bias_in};
  }
};
template <typename Arch, template <typename> typename ArchGuard, template <typename Arch, template <typename> typename ArchGuard,
typename ElementAB_, typename ElementD_, typename ElementAB_, typename ElementD_,
template <typename, typename> typename Epilogue_, typename TileShape, template <typename, typename> typename Epilogue_, typename TileShape,
......
...@@ -23,11 +23,12 @@ ...@@ -23,11 +23,12 @@
#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp"
#include "broadcast_load_epilogue_c3x.hpp" #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "common.hpp" #include "common.hpp"
// clang-format on // clang-format on
using namespace cute; using namespace cute;
using namespace vllm;
/* /*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for This file defines quantized GEMM operations using the CUTLASS 3.x API, for
...@@ -56,305 +57,6 @@ struct enable_sm90_or_later : Kernel { ...@@ -56,305 +57,6 @@ struct enable_sm90_or_later : Kernel {
#endif #endif
} }
}; };
/*
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
template <typename T>
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<0>, Int<1>, Int<0>>>;
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
return Arguments{data_ptr, tensor.numel() != 1};
} else {
static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
!std::is_same_v<Descriptor, RowLoad<T, true>>);
return Arguments{data_ptr};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
std::is_same_v<Descriptor, RowLoad<T, true>>);
return Arguments{data_ptr};
}
};
/*
 * Epilogue for a quantized GEMM in the spirit of torch._scaled_mm.
 *
 * A and B are both int8 or both fp8_e4m3. A may be quantized per-tensor or
 * per-row, B per-tensor or per-column, in any combination. Quantization must
 * be symmetric (zero point == 0), so the operation is
 *
 *   D = (a_scales * A) (b_scales * B)
 *
 * with the scales applied elementwise using numpy-style broadcasting.
 *
 * The epilogue is a two-level fusion tree: the accumulator is first combined
 * with the B scale, and that result is then combined with the A scale.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using Parent = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using AccFetch = typename Parent::Accum;

  // Scale loads: each is either a broadcast vector or a single scalar.
  using AScaleLoad = typename Parent::template ColOrScalarLoad<float>;
  using BScaleLoad = typename Parent::template RowOrScalarLoad<float>;

  // Inner node: multiplies the B scale with the accumulator.
  using MulByScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using ScaledByB =
      cutlass::epilogue::fusion::Sm90EVT<MulByScaleB, BScaleLoad, AccFetch>;

  // Outer node: multiplies the A scale with the inner result, with ElementD
  // as the first element-type argument of the compute node.
  using MulByScaleA = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<MulByScaleA, AScaleLoad, ScaledByB>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Packs the two scale tensors into the argument tree of EVTCompute.
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
    auto scale_a =
        Parent::template args_from_tensor<AScaleLoad, float>(a_scales);
    auto scale_b =
        Parent::template args_from_tensor<BScaleLoad, float>(b_scales);

    typename ScaledByB::Arguments inner{scale_b};
    return ArgumentType{scale_a, inner};
  }
};
/*
 * Same operation as ScaledEpilogue, with a bias added on top.
 *
 * The bias can double as the carrier of the per-tensor azp correction: the
 * activation zero point (azp) yields a correction term that is folded into
 * the bias before it reaches this epilogue.
 *
 * The bias tensor must be per-output channel. ScaleA and ScaleB can be
 * per-tensor or per-token/per-channel.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBias
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using Parent = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using AccFetch = typename Parent::Accum;
  using AScaleLoad = typename Parent::template ColOrScalarLoad<float>;
  using BScaleLoad = typename Parent::template RowOrScalarLoad<float>;

  // Mandatory bias (the non-nullable row load).
  using BiasLoad = typename Parent::template RowLoad<ElementD>;

  // Inner node: B scale multiplied with the accumulator.
  using MulByScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using ScaledByB =
      cutlass::epilogue::fusion::Sm90EVT<MulByScaleB, BScaleLoad, AccFetch>;

  // Outer node: a multiply_add over (A scale, inner result, bias).
  using ScaleAThenBias = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ScaleAThenBias, AScaleLoad, ScaledByB,
                                         BiasLoad>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Packs both scale tensors and the bias tensor into the argument tree of
  // EVTCompute.
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    auto scale_a =
        Parent::template args_from_tensor<AScaleLoad, float>(a_scales);
    auto scale_b =
        Parent::template args_from_tensor<BScaleLoad, float>(b_scales);
    auto bias_in = Parent::template args_from_tensor<BiasLoad, ElementD>(bias);

    typename ScaledByB::Arguments inner{scale_b};
    return ArgumentType{scale_a, inner, bias_in};
  }
};
/*
 * Epilogue with direct support for per-tensor azp in int32 form.
 *
 * Unlike the per-token variant, only an azp_adj term is taken, and it must
 * already be multiplied by the scalar azp: azp_adj is a 1D tensor of shape
 * (1,n) computed as azp * J @ B.
 *
 * A bias is also supported and remains per-channel; it may be absent.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzp
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using Parent = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using AccFetch = typename Parent::Accum;
  using AScaleLoad = typename Parent::template ColOrScalarLoad<float>;
  using BScaleLoad = typename Parent::template RowOrScalarLoad<float>;

  // Optional bias: the nullable row load, where a null pointer stands for
  // "no bias".
  using BiasLoad = typename Parent::template RowLoad<ElementD, true>;

  // The full AZP term, azp * J @ B, shape (1,n).
  using AzpTermLoad = typename Parent::template RowLoad<int32_t>;

  // float(accum - azp_adj); both operands are int32_t.
  using SubtractAzp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using AccMinusAzp =
      cutlass::epilogue::fusion::Sm90EVT<SubtractAzp, AccFetch, AzpTermLoad>;

  // B scale multiplied with the azp-corrected accumulator.
  using MulByScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using ScaledByB =
      cutlass::epilogue::fusion::Sm90EVT<MulByScaleB, BScaleLoad, AccMinusAzp>;

  // Root node: a multiply_add over (A scale, scaled result, bias).
  using ScaleAThenBias = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ScaleAThenBias, AScaleLoad, ScaledByB,
                                         BiasLoad>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Packs the scales, the pre-multiplied azp term and the optional bias into
  // the argument tree of EVTCompute.
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   c10::optional<torch::Tensor> const& bias) {
    auto scale_a =
        Parent::template args_from_tensor<AScaleLoad, float>(a_scales);
    auto scale_b =
        Parent::template args_from_tensor<BScaleLoad, float>(b_scales);
    auto bias_in = Parent::template args_from_tensor<BiasLoad, ElementD>(bias);
    auto azp_term =
        Parent::template args_from_tensor<AzpTermLoad, int32_t>(azp_adj);

    // The accumulator fetch takes no arguments, hence the empty braces.
    typename AccMinusAzp::Arguments corrected_acc{{}, azp_term};
    typename ScaledByB::Arguments scaled_acc{scale_b, corrected_acc};
    return ArgumentType{scale_a, scaled_acc, bias_in};
  }
};
/*
 * Epilogue with per-token azp support. The correction term is computed and
 * applied as a rank-1 update: materializing it would need O(m*n) space,
 * whereas this form needs only O(m+n).
 *
 * azp is a 1D tensor of shape (m,1) holding the unscaled zero point of each
 * row of A. azp_adj is a 1D tensor of shape (1,n), computed as J @ B.
 *
 * A bias is also supported and remains per-channel; it may be absent.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzpToken
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using Parent = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using AccFetch = typename Parent::Accum;
  using AScaleLoad = typename Parent::template ColOrScalarLoad<float>;
  using BScaleLoad = typename Parent::template RowOrScalarLoad<float>;

  // Optional bias: the nullable row load, where a null pointer stands for
  // "no bias".
  using BiasLoad = typename Parent::template RowLoad<ElementD, true>;

  // Per-token azp term, shape (m,1).
  using AzpLoad = typename Parent::template ColLoad<int32_t>;

  // AZP adjustment term, J @ B, shape (1,n).
  using AzpAdjLoad = typename Parent::template RowLoad<int32_t>;

  // azp * azp_adj
  using MulAzp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, int32_t, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using AzpProduct =
      cutlass::epilogue::fusion::Sm90EVT<MulAzp, AzpLoad, AzpAdjLoad>;

  // float(accum - azp*azp_adj); all operands are int32_t.
  using SubtractAzp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using AccMinusAzp =
      cutlass::epilogue::fusion::Sm90EVT<SubtractAzp, AccFetch, AzpProduct>;

  // B scale multiplied with the azp-corrected accumulator.
  using MulByScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using ScaledByB =
      cutlass::epilogue::fusion::Sm90EVT<MulByScaleB, BScaleLoad, AccMinusAzp>;

  // Root node: a multiply_add over (A scale, scaled result, bias).
  using ScaleAThenBias = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ScaleAThenBias, AScaleLoad, ScaledByB,
                                         BiasLoad>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Packs the scales, both azp tensors and the optional bias into the
  // argument tree of EVTCompute.
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   torch::Tensor const& azp,
                                   c10::optional<torch::Tensor> const& bias) {
    auto scale_a =
        Parent::template args_from_tensor<AScaleLoad, float>(a_scales);
    auto scale_b =
        Parent::template args_from_tensor<BScaleLoad, float>(b_scales);
    auto bias_in = Parent::template args_from_tensor<BiasLoad, ElementD>(bias);
    auto azp_in = Parent::template args_from_tensor<AzpLoad, int32_t>(azp);
    auto azp_adj_in =
        Parent::template args_from_tensor<AzpAdjLoad, int32_t>(azp_adj);

    typename AzpProduct::Arguments azp_product{azp_in, azp_adj_in};
    // The accumulator fetch takes no arguments, hence the empty braces.
    typename AccMinusAzp::Arguments corrected_acc{{}, azp_product};
    typename ScaledByB::Arguments scaled_acc{scale_b, corrected_acc};
    return ArgumentType{scale_a, scaled_acc, bias_in};
  }
};
template <typename ElementAB_, typename ElementD_, template <typename ElementAB_, typename ElementD_,
template <typename, typename, typename> typename Epilogue_, template <typename, typename, typename> typename Epilogue_,
typename TileShape, typename ClusterShape, typename KernelSchedule, typename TileShape, typename ClusterShape, typename KernelSchedule,
...@@ -721,11 +423,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, ...@@ -721,11 +423,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
if (bias) { if (bias) {
TORCH_CHECK(bias->dtype() == c.dtype(), TORCH_CHECK(bias->dtype() == c.dtype(),
"currently bias dtype must match output dtype ", c.dtype()); "currently bias dtype must match output dtype ", c.dtype());
return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBias>( return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBias>(
c, a, b, a_scales, b_scales, *bias); c, a, b, a_scales, b_scales, *bias);
} else { } else {
return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogue>(c, a, b, a_scales, return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogue>(
b_scales); c, a, b, a_scales, b_scales);
} }
} }
...@@ -740,10 +442,10 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a, ...@@ -740,10 +442,10 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (azp) { if (azp) {
return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzpToken>( return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias); out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else { } else {
return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzp>( return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias); out, a, b, a_scales, b_scales, azp_adj, bias);
} }
} }
......
...@@ -21,7 +21,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a, ...@@ -21,7 +21,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias); c10::optional<torch::Tensor> const& bias);
#if defined CUDA_VERSION && CUDA_VERSION >= 12000 #if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
...@@ -114,26 +114,41 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, ...@@ -114,26 +114,41 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
if (version_num >= 90) { // Hopper
// Hopper
// Guard against compilation issues for sm90 kernels // Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000 #if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
if (version_num >= 90) {
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias); cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
#else return;
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias); }
#endif #endif
} else if (version_num == 89) {
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if (version_num == 89) {
// Ada Lovelace // Ada Lovelace
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias); cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
} else if (version_num >= 80) { return;
}
if (version_num >= 80) {
// Ampere // Ampere
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias); cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
} else { return;
}
if (version_num >= 75) {
// Turing // Turing
TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias); cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
return;
} }
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm for a compute capability less than "
"CUDA device capability: ",
version_num);
} }
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
...@@ -174,25 +189,38 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, ...@@ -174,25 +189,38 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
"currently bias dtype must match output dtype ", c.dtype()); "currently bias dtype must match output dtype ", c.dtype());
at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
if (version_num >= 90) {
// Hopper
// Guard against compilation issues for sm90 kernels #if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
#if defined CUDA_VERSION && CUDA_VERSION >= 12000 if (version_num >= 90) {
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias); cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
#else return;
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias); }
#endif #endif
} else if (version_num == 89) {
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if (version_num == 89) {
// Ada Lovelace // Ada Lovelace
cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias); cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
} else if (version_num >= 80) { return;
}
if (version_num >= 80) {
// Ampere // Ampere
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias); cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
} else { return;
// Turing
TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
} }
// Turing
TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
"CUDA device capability: ",
version_num);
} }
\ No newline at end of file
#include <ATen/cuda/CUDAContext.h> #include "common.cuh"
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <cmath>
#include "cuda_compat.h"
#include "dispatch_utils.h" #include "dispatch_utils.h"
#include <c10/cuda/CUDAGuard.h>
#ifndef USE_ROCM #ifndef USE_ROCM
#include <cub/util_type.cuh>
#include <cub/cub.cuh> #include <cub/cub.cuh>
#else #else
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp> #include <hipcub/hipcub.hpp>
#endif #endif
#ifndef USE_ROCM
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
std::numeric_limits<FP8_TYPE>::max();
#else
#include "amd/hip_float8.h"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using the default max value from pytorch (240.0) will cause accuracy
// issue when running dynamic quantization. Here use 224.0f for rocm.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif
namespace vllm { namespace vllm {
// Atomic "max" for a float stored at *addr, built on the integer atomics.
// Returns the value held at *addr before the operation.
//
// Values that compare >= 0 are merged with a signed atomicMax on their bit
// pattern; all other values are merged with an unsigned atomicMin on their
// bit pattern.
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  if (value >= 0) {
    int* const signed_view = (int*)addr;
    return __int_as_float(atomicMax(signed_view, __float_as_int(value)));
  }
  unsigned int* const unsigned_view = (unsigned int*)addr;
  return __uint_as_float(atomicMin(unsigned_view, __float_as_uint(value)));
}
// Scales `val`, clamps it to [-FP8_E4M3_MAX, FP8_E4M3_MAX] and converts the
// result to FP8_TYPE.
//
//   is_scale_inverted == true  -> val * scale
//   is_scale_inverted == false -> val / scale
template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
                                                          float const scale) {
  // The selector is a template constant, so the choice is fixed per
  // instantiation.
  float const scaled = is_scale_inverted ? val * scale : val / scale;

  // Saturate before narrowing to the 8-bit type.
  float const clamped = fmax(-FP8_E4M3_MAX, fmin(scaled, FP8_E4M3_MAX));

#ifdef USE_ROCM
  // Use hardware cvt instruction for fp8 on rocm
  return c10::Float8_e4m3fnuz(hip_fp8(clamped).data,
                              c10::Float8_e4m3fnuz::from_bits());
#else
  return static_cast<c10::Float8_e4m3fn>(clamped);
#endif
}
// Compute the absolute maximum m of the input tensor and store
// m / FP8_E4M3_MAX in *scale. Each thread block reduces its partial maximum
// in shared memory and then merges it into *scale with atomicMaxFloat.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
//
// Launch requirements: blockDim.x must not exceed 1024 (the size of the
// shared scratch array). blockDim.x does not need to be a power of two.
//
// Parameters:
//   scale     - single float that receives the running maximum / FP8_E4M3_MAX
//   input     - elements to scan
//   num_elems - number of elements in `input`
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
                                        const scalar_t* __restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[1024];

  // Widen before multiplying: blockDim.x * blockIdx.x would otherwise be
  // evaluated in 32-bit unsigned arithmetic and could wrap on huge launches.
  int64_t const stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  int64_t i = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;

  // First store maximum for all values processed by
  // the current thread in cache[threadIdx.x].
  // Accumulate in float: every |x| originates from a scalar_t value, so the
  // result is unchanged, but there is no double literal and no narrowing
  // round trip on each iteration.
  float tmp = 0.0f;
  while (i < num_elems) {
    float x = static_cast<float>(input[i]);
    tmp = fmaxf(tmp, fabsf(x));
    i += stride;
  }
  cache[threadIdx.x] = tmp;

  __syncthreads();

  // Now perform parallel reduction within the thread block. `active` is the
  // number of live partial results; each round folds the upper part onto the
  // lower part. Rounding the fold distance up keeps the middle element alive
  // when `active` is odd, so no partial maximum is dropped for block sizes
  // that are not powers of two. For power-of-two sizes this is the classic
  // halving tree.
  for (unsigned int active = blockDim.x; active > 1;) {
    unsigned int const fold = (active + 1) / 2;
    if (threadIdx.x + fold < active &&
        cache[threadIdx.x + fold] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + fold];
    }
    // The loop bounds are uniform across the block, so every thread reaches
    // this barrier.
    __syncthreads();
    active = fold;
  }

  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
  }
}
// Four consecutive scalar_t values handled as one unit; the struct is
// declared with 8-byte alignment.
template <typename scalar_t>
struct __align__(8) vec4_t {
  scalar_t x, y, z, w;
};
// Four FP8_TYPE values handled as one unit; the struct is declared with
// 4-byte alignment.
struct __align__(4) float8x4_t {
  FP8_TYPE x, y, z, w;
};
// Returns the largest absolute value among the elements of `input` visited by
// this thread. The thread walks the 4-element packs tid, tid + step, ... and
// then the scalar leftovers that do not fill a whole pack.
//
// `input` is read through a vec4_t<scalar_t> pointer, so it has to satisfy
// that type's alignment.
template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                int64_t const num_elems, int const tid,
                                int const step) {
  // Number of complete 4-element packs.
  int64_t const num_vec_elems = num_elems >> 2;

  // Vectorized view of the input to better utilize memory bandwidth.
  auto const* packs = reinterpret_cast<vec4_t<scalar_t> const*>(input);

  float running_max = 0.0f;

#pragma unroll 4
  for (int64_t pack_idx = tid; pack_idx < num_vec_elems; pack_idx += step) {
    vec4_t<scalar_t> pack = packs[pack_idx];
    running_max = max(running_max, fabs(pack.x));
    running_max = max(running_max, fabs(pack.y));
    running_max = max(running_max, fabs(pack.z));
    running_max = max(running_max, fabs(pack.w));
  }

  // Scalar tail: elements past the last full pack when num_elems is not
  // divisible by 4.
  for (int64_t idx = num_vec_elems * 4 + tid; idx < num_elems; idx += step) {
    running_max = max(running_max, fabs(input[idx]));
  }

  return running_max;
}
// Converts this thread's share of `input` to FP8 and writes it to `out`.
// The thread handles the 4-element packs tid, tid + step, ... and then the
// scalar leftovers that do not fill a whole pack. Every element goes through
// scaled_fp8_conversion<is_scale_inverted> with the same `scale`.
//
// `input` and `out` are accessed through vec4_t<scalar_t> / float8x4_t
// pointers, so they have to satisfy those types' alignment.
template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
                                          scalar_t const* __restrict__ input,
                                          float const scale,
                                          int64_t const num_elems,
                                          int const tid, int const step) {
  // Single-element conversion shared by the packed loop and the tail loop.
  auto const quantize = [scale](scalar_t const v) {
    return scaled_fp8_conversion<is_scale_inverted>(static_cast<float>(v),
                                                    scale);
  };

  // Vectorized input/output to better utilize memory bandwidth.
  auto const* packs_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
  auto* packs_out = reinterpret_cast<float8x4_t*>(out);

  // Number of complete 4-element packs.
  int64_t const num_vec_elems = num_elems >> 2;

#pragma unroll 4
  for (int64_t pack_idx = tid; pack_idx < num_vec_elems; pack_idx += step) {
    vec4_t<scalar_t> src = packs_in[pack_idx];
    float8x4_t dst;
    dst.x = quantize(src.x);
    dst.y = quantize(src.y);
    dst.z = quantize(src.z);
    dst.w = quantize(src.w);
    packs_out[pack_idx] = dst;
  }

  // Scalar tail: elements past the last full pack when num_elems is not
  // divisible by 4.
  for (int64_t idx = num_vec_elems * 4 + tid; idx < num_elems; idx += step) {
    out[idx] = quantize(input[idx]);
  }
}
template <typename scalar_t> template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out, __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
const scalar_t* __restrict__ input, const scalar_t* __restrict__ input,
...@@ -204,8 +35,10 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel( ...@@ -204,8 +35,10 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
int const tid = threadIdx.x; int const tid = threadIdx.x;
int const token_idx = blockIdx.x; int const token_idx = blockIdx.x;
scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size]; // Use int64 to avoid overflowing an int32 when calculating this offset
FP8_TYPE* __restrict__ token_output = &out[token_idx * hidden_size]; int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
scalar_t const* __restrict__ token_input = &input[offset];
FP8_TYPE* __restrict__ token_output = &out[offset];
// For vectorization, token_input and token_output pointers need to be // For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively. // aligned at 8-byte and 4-byte addresses respectively.
......
#pragma once
#include "quantization/vectorization.cuh"
#include <cmath>
#include <c10/core/ScalarType.h>
#ifndef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
std::numeric_limits<FP8_TYPE>::max();
#else
#include <c10/util/Float8_e4m3fnuz.h>
#include "amd/hip_float8.h"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using the default max value from pytorch (240.0) will cause accuracy
// issue when running dynamic quantization. Here use 224.0f for rocm.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif
constexpr static auto kFp8Type = c10::CppTypeToScalarType<FP8_TYPE>::value;
namespace vllm {
// Atomic "max" for a float stored at *addr, built on the integer atomics.
// Returns the value held at *addr before the operation.
//
// Values that compare >= 0 are merged with a signed atomicMax on their bit
// pattern; all other values are merged with an unsigned atomicMin on their
// bit pattern.
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  if (value >= 0) {
    int* const signed_view = (int*)addr;
    return __int_as_float(atomicMax(signed_view, __float_as_int(value)));
  }
  unsigned int* const unsigned_view = (unsigned int*)addr;
  return __uint_as_float(atomicMin(unsigned_view, __float_as_uint(value)));
}
// Scales `val`, clamps it to [-FP8_E4M3_MAX, FP8_E4M3_MAX] and converts the
// result to FP8_TYPE.
//
//   is_scale_inverted == true  -> val * scale
//   is_scale_inverted == false -> val / scale
template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
                                                          float const scale) {
  // The selector is a template constant, so the choice is fixed per
  // instantiation.
  float const scaled = is_scale_inverted ? val * scale : val / scale;

  // Saturate before narrowing to the 8-bit type.
  float const clamped = fmax(-FP8_E4M3_MAX, fmin(scaled, FP8_E4M3_MAX));

#ifdef USE_ROCM
  // Use hardware cvt instruction for fp8 on rocm
  return c10::Float8_e4m3fnuz(hip_fp8(clamped).data,
                              c10::Float8_e4m3fnuz::from_bits());
#else
  return static_cast<c10::Float8_e4m3fn>(clamped);
#endif
}
// Compute the absolute maximum m of the input tensor and store
// m / FP8_E4M3_MAX in *scale. Each thread block reduces its partial maximum
// in shared memory and then merges it into *scale with atomicMaxFloat.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
//
// Launch requirements: blockDim.x must not exceed 1024 (the size of the
// shared scratch array). blockDim.x does not need to be a power of two.
//
// Parameters:
//   scale     - single float that receives the running maximum / FP8_E4M3_MAX
//   input     - elements to scan
//   num_elems - number of elements in `input`
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
                                        const scalar_t* __restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[1024];

  // Widen before multiplying: blockDim.x * blockIdx.x would otherwise be
  // evaluated in 32-bit unsigned arithmetic and could wrap on huge launches.
  int64_t const stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  int64_t i = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;

  // First store maximum for all values processed by
  // the current thread in cache[threadIdx.x].
  // Accumulate in float: every |x| originates from a scalar_t value, so the
  // result is unchanged, but there is no double literal and no narrowing
  // round trip on each iteration.
  float tmp = 0.0f;
  while (i < num_elems) {
    float x = static_cast<float>(input[i]);
    tmp = fmaxf(tmp, fabsf(x));
    i += stride;
  }
  cache[threadIdx.x] = tmp;

  __syncthreads();

  // Now perform parallel reduction within the thread block. `active` is the
  // number of live partial results; each round folds the upper part onto the
  // lower part. Rounding the fold distance up keeps the middle element alive
  // when `active` is odd, so no partial maximum is dropped for block sizes
  // that are not powers of two. For power-of-two sizes this is the classic
  // halving tree.
  for (unsigned int active = blockDim.x; active > 1;) {
    unsigned int const fold = (active + 1) / 2;
    if (threadIdx.x + fold < active &&
        cache[threadIdx.x + fold] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + fold];
    }
    // The loop bounds are uniform across the block, so every thread reaches
    // this barrier.
    __syncthreads();
    active = fold;
  }

  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
  }
}
// Returns the largest absolute value among the elements of `input` visited by
// this thread. The thread walks the 4-element packs tid, tid + step, ... and
// then the scalar leftovers that do not fill a whole pack.
//
// `input` is read through a vec4_t<scalar_t> pointer, so it has to satisfy
// that type's alignment (vec4_t is declared outside this block — confirm its
// alignment there).
template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                int64_t const num_elems, int const tid,
                                int const step) {
  // Number of complete 4-element packs.
  int64_t const num_vec_elems = num_elems >> 2;

  // Vectorized view of the input to better utilize memory bandwidth.
  auto const* packs = reinterpret_cast<vec4_t<scalar_t> const*>(input);

  float running_max = 0.0f;

#pragma unroll 4
  for (int64_t pack_idx = tid; pack_idx < num_vec_elems; pack_idx += step) {
    vec4_t<scalar_t> pack = packs[pack_idx];
    running_max = max(running_max, fabs(pack.x));
    running_max = max(running_max, fabs(pack.y));
    running_max = max(running_max, fabs(pack.z));
    running_max = max(running_max, fabs(pack.w));
  }

  // Scalar tail: elements past the last full pack when num_elems is not
  // divisible by 4.
  for (int64_t idx = num_vec_elems * 4 + tid; idx < num_elems; idx += step) {
    running_max = max(running_max, fabs(input[idx]));
  }

  return running_max;
}
// Converts this thread's share of `input` to FP8 and writes it to `out`.
// The thread handles the 4-element packs tid, tid + step, ... and then the
// scalar leftovers that do not fill a whole pack. Every element goes through
// scaled_fp8_conversion<is_scale_inverted> with the same `scale`.
//
// `input` and `out` are accessed through vec4_t<scalar_t> / q8x4_t<FP8_TYPE>
// pointers, so they have to satisfy those types' alignment (both types are
// declared outside this block — confirm their alignment there).
template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
                                          scalar_t const* __restrict__ input,
                                          float const scale,
                                          int64_t const num_elems,
                                          int const tid, int const step) {
  using float8x4_t = q8x4_t<FP8_TYPE>;

  // Single-element conversion shared by the packed loop and the tail loop.
  auto const quantize = [scale](scalar_t const v) {
    return scaled_fp8_conversion<is_scale_inverted>(static_cast<float>(v),
                                                    scale);
  };

  // Vectorized input/output to better utilize memory bandwidth.
  auto const* packs_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
  auto* packs_out = reinterpret_cast<float8x4_t*>(out);

  // Number of complete 4-element packs.
  int64_t const num_vec_elems = num_elems >> 2;

#pragma unroll 4
  for (int64_t pack_idx = tid; pack_idx < num_vec_elems; pack_idx += step) {
    vec4_t<scalar_t> src = packs_in[pack_idx];
    float8x4_t dst;
    dst.x = quantize(src.x);
    dst.y = quantize(src.y);
    dst.z = quantize(src.z);
    dst.w = quantize(src.w);
    packs_out[pack_idx] = dst;
  }

  // Scalar tail: elements past the last full pack when num_elems is not
  // divisible by 4.
  for (int64_t idx = num_vec_elems * 4 + tid; idx < num_elems; idx += step) {
    out[idx] = quantize(input[idx]);
  }
}
} // namespace vllm
\ No newline at end of file
...@@ -22,6 +22,8 @@ ...@@ -22,6 +22,8 @@
#include "../gptq_marlin/marlin.cuh" #include "../gptq_marlin/marlin.cuh"
#include "../gptq_marlin/marlin_dtypes.cuh" #include "../gptq_marlin/marlin_dtypes.cuh"
#include "core/registration.h"
using namespace marlin; using namespace marlin;
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
...@@ -1303,3 +1305,7 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, ...@@ -1303,3 +1305,7 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
} }
#endif #endif
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("fp8_marlin_gemm", &fp8_marlin_gemm);
}
\ No newline at end of file
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../../dispatch_utils.h"
#include "layernorm_utils.cuh"
#include "quant_conversions.cuh"
namespace vllm {
// Vectorized variant of the fused RMS norm + dynamic per-token quantization.
// Runs three stages from vllm::vectorized in order: compute_rms,
// compute_dynamic_per_token_scales, norm_and_quant.
//
// For int8 output the quantization stage receives the reciprocal of the token
// scale (and is instantiated with the "inverted scale" flag set); for any
// other output type it receives the token scale itself.
template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
__device__ void rms_norm_dynamic_per_token_quant_vec(
    scalar_out_t* __restrict__ out,       // [..., hidden_size]
    float* __restrict__ scales,           // [num_tokens]
    scalar_t const* __restrict__ input,   // [..., hidden_size]
    scalar_t const* __restrict__ weight,  // [hidden_size]
    float const* scale_ub, float const var_epsilon,
    float const min_scaling_factor, int32_t const hidden_size,
    scalar_t* __restrict__ residual = nullptr) {
  // Only the int8 path multiplies by an inverted scale.
  constexpr bool kInvertScale = std::is_same_v<scalar_out_t, int8_t>;

  // Stage 1: RMS statistic.
  float rms = 0.0f;
  vllm::vectorized::compute_rms<scalar_t, has_residual>(
      &rms, input, hidden_size, var_epsilon, residual);

  // Stage 2: per-token quantization scale.
  float token_scale = 0.0f;
  vllm::vectorized::compute_dynamic_per_token_scales<scalar_t, scalar_out_t,
                                                     has_residual>(
      &token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor,
      hidden_size, residual);

  // Stage 3: RMS Norm + Quant.
  // FP8 - Do not invert token_scale for exact match with FBGemm
  float const quant_scale = kInvertScale ? 1.0f / token_scale : token_scale;
  vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, kInvertScale,
                                   has_residual>(
      out, input, weight, rms, quant_scale, hidden_size, residual);
}
// RMS norm + quant kernel
//
// When hidden_size is a multiple of 4 the work is delegated to
// rms_norm_dynamic_per_token_quant_vec; otherwise the scalar helpers
// vllm::compute_rms, vllm::compute_dynamic_per_token_scales and
// vllm::norm_and_quant run in that order.
//
// For int8 output the quantization stage receives the reciprocal of the token
// scale (and is instantiated with the "inverted scale" flag set); for any
// other output type it receives the token scale itself.
template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
__global__ void rms_norm_dynamic_per_token_quant_kernel(
    scalar_out_t* __restrict__ out,       // [..., hidden_size]
    float* __restrict__ scales,           // [num_tokens]
    scalar_t const* __restrict__ input,   // [..., hidden_size]
    scalar_t const* __restrict__ weight,  // [hidden_size]
    float const* scale_ub, float const var_epsilon,
    float const min_scaling_factor, int32_t const hidden_size,
    scalar_t* __restrict__ residual = nullptr) {
  // For vectorization, token_input and token_output pointers need to be
  // aligned at 8-byte and 4-byte addresses respectively.
  if (hidden_size % 4 == 0) {
    rms_norm_dynamic_per_token_quant_vec<scalar_t, scalar_out_t, has_residual>(
        out, scales, input, weight, scale_ub, var_epsilon, min_scaling_factor,
        hidden_size, residual);
    return;
  }

  // Only the int8 path multiplies by an inverted scale.
  constexpr bool kInvertScale = std::is_same_v<scalar_out_t, int8_t>;

  // Compute RMS
  float rms = 0.0f;
  vllm::compute_rms<scalar_t, has_residual>(&rms, input, hidden_size,
                                            var_epsilon, residual);

  // Compute Scale
  float token_scale = 0.0f;
  vllm::compute_dynamic_per_token_scales<scalar_t, scalar_out_t, has_residual>(
      &token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor,
      hidden_size, residual);

  // RMS Norm + Quant
  // FP8 - Do not invert s_token_scale for exact match with FBGemm
  float const quant_scale = kInvertScale ? 1.0f / token_scale : token_scale;
  vllm::norm_and_quant<scalar_t, scalar_out_t, kInvertScale, has_residual>(
      out, input, weight, rms, quant_scale, hidden_size, residual);
}
} // namespace vllm
// Residual add + RMS norm + dynamic per token
//
// Launches rms_norm_dynamic_per_token_quant_kernel for input element type
// scalar_in_t; the output element type is selected from out.scalar_type() by
// VLLM_DISPATCH_QUANT_TYPES. The launch uses one block per token
// (num_tokens = input.numel() / hidden_size) and min(hidden_size, 1024)
// threads per block, on the current CUDA stream of input's device.
//
// Parameters:
//   out         - quantized output, [..., hidden_size]
//   input       - activations, [..., hidden_size]
//   weight      - norm weight, [hidden_size]
//   scales      - per-token scales written by the kernel, [num_tokens]
//   var_epsilon - variance epsilon used in norm calculation
//   scale_ub    - optional float tensor forwarded to the kernel as scale_ub
//   residual    - optional residual tensor; selects the has_residual kernel
//
// An empty input is a no-op.
template <typename scalar_in_t>
void rms_norm_dynamic_per_token_quant_dispatch(
    torch::Tensor& out,           // [..., hidden_size]
    torch::Tensor const& input,   // [..., hidden_size]
    torch::Tensor const& weight,  // [hidden_size]
    torch::Tensor& scales,        // [num_tokens]
    double const var_epsilon,     // Variance epsilon used in norm calculation
    std::optional<at::Tensor> const& scale_ub,
    std::optional<at::Tensor>& residual) {
  // Nothing to quantize. Returning here also avoids dividing by a zero
  // hidden_size below and launching a grid with zero blocks, which is an
  // invalid launch configuration.
  if (input.numel() == 0) {
    return;
  }

  int32_t hidden_size = input.size(-1);
  int32_t num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Lower bound handed to the kernel for the computed scale: float epsilon
  // for int8 output, otherwise 1 / (max(Float8_e4m3fn) * 512).
  // NOTE(review): this always uses c10::Float8_e4m3fn, also in builds where
  // the FP8 output type is a different 8-bit float type — confirm intended.
  const float min_scaling_factor =
      out.dtype() == torch::kInt8
          ? std::numeric_limits<float>::epsilon()
          : 1.0f / (std::numeric_limits<c10::Float8_e4m3fn>::max() * 512.f);

  // Resolve the optional upper bound once; nullptr is passed when absent.
  float const* scale_ub_ptr =
      scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr;

  if (residual.has_value()) {
    VLLM_DISPATCH_QUANT_TYPES(
        out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] {
          vllm::rms_norm_dynamic_per_token_quant_kernel<scalar_in_t, scalar_t,
                                                        true>
              <<<grid, block, 0, stream>>>(
                  out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
                  input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
                  scale_ub_ptr, var_epsilon, min_scaling_factor, hidden_size,
                  residual->data_ptr<scalar_in_t>());
        });
  } else {
    VLLM_DISPATCH_QUANT_TYPES(
        out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] {
          vllm::rms_norm_dynamic_per_token_quant_kernel<scalar_in_t, scalar_t,
                                                        false>
              <<<grid, block, 0, stream>>>(
                  out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
                  input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
                  scale_ub_ptr, var_epsilon, min_scaling_factor, hidden_size,
                  nullptr);
        });
  }
}
// Host entry point for the fused (optional residual add +) RMS norm +
// dynamic per-token quantization.
//
// Validates the arguments and forwards to
// rms_norm_dynamic_per_token_quant_dispatch, instantiated for input's
// floating-point element type.
//
// Requirements enforced here:
//   - out is FP8 (kFp8Type) or int8; scale_ub is only accepted for FP8 output
//   - out, input, weight and (if given) residual are contiguous, because the
//     kernel receives raw data pointers and cannot honour strides
//   - weight and (if given) residual have the same dtype as input
//   - scales is float32
// Violations raise through TORCH_CHECK.
void rms_norm_dynamic_per_token_quant(
    torch::Tensor& out,           // [..., hidden_size]
    torch::Tensor const& input,   // [..., hidden_size]
    torch::Tensor const& weight,  // [hidden_size]
    torch::Tensor& scales,        // [num_tokens]
    double const var_epsilon,     // Variance epsilon used in norm calculation
    std::optional<at::Tensor> scale_ub, std::optional<at::Tensor> residual) {
  TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
  TORCH_CHECK(out.is_contiguous() && input.is_contiguous());

  if (scale_ub.has_value()) {
    TORCH_CHECK(out.dtype() == kFp8Type);
  }

  // The dispatch reads weight through data_ptr<scalar_in_t>(); check the
  // dtype up front so a mismatch is reported against the right argument.
  TORCH_CHECK(weight.dtype() == input.dtype(),
              "weight dtype must match input dtype ", input.dtype());
  TORCH_CHECK(weight.is_contiguous(), "weight must be contiguous");
  TORCH_CHECK(scales.dtype() == torch::kFloat32);

  if (residual.has_value()) {
    TORCH_CHECK(residual->scalar_type() == input.scalar_type(),
                "residual dtype must match input dtype ", input.dtype());
    TORCH_CHECK(residual->is_contiguous(), "residual must be contiguous");
  }

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] {
        rms_norm_dynamic_per_token_quant_dispatch<scalar_t>(
            out, input, weight, scales, var_epsilon, scale_ub, residual);
      });
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment