Commit 4c676e3d authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.9.1' into v0.9.1-dev

parents b4c4464d b6553be1
...@@ -179,6 +179,31 @@ void merge_attn_states(torch::Tensor& output, ...@@ -179,6 +179,31 @@ void merge_attn_states(torch::Tensor& output,
const torch::Tensor& prefix_lse, const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output, const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse); const torch::Tensor& suffix_lse);
void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int64_t context_size, int64_t block_size_M, int64_t block_size_N,
bool causal);
void convert_vertical_slash_indexes_mergehead(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
torch::Tensor vertical_indices_count, // [N_HEADS, ]
torch::Tensor slash_indices_count, int64_t context_size,
int64_t block_size_M, int64_t block_size_N, bool causal);
#endif #endif
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
...@@ -193,6 +218,11 @@ void rms_norm_opt(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weigh ...@@ -193,6 +218,11 @@ void rms_norm_opt(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weigh
void fused_add_rms_norm_opt(torch::Tensor& input, torch::Tensor& residual, void fused_add_rms_norm_opt(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon); torch::Tensor& weight, double epsilon);
void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& prompt_mask,
const torch::Tensor& output_mask,
const torch::Tensor& repetition_penalties);
// void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, // void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& weight, torch::Tensor& scale, // torch::Tensor& weight, torch::Tensor& scale,
// double epsilon); // double epsilon);
...@@ -212,13 +242,13 @@ void fused_add_rms_norm_opt(torch::Tensor& input, torch::Tensor& residual, ...@@ -212,13 +242,13 @@ void fused_add_rms_norm_opt(torch::Tensor& input, torch::Tensor& residual,
// std::optional<torch::Tensor> residual); // std::optional<torch::Tensor> residual);
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int64_t head_size, std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox); torch::Tensor& cos_sin_cache, bool is_neox);
void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int64_t head_size, std::optional<torch::Tensor> key,
torch::Tensor& cos_sin_cache, bool is_neox, int64_t head_size, torch::Tensor& cos_sin_cache,
int64_t rot_dim, bool is_neox, int64_t rot_dim,
torch::Tensor& cos_sin_cache_offsets); torch::Tensor& cos_sin_cache_offsets);
void rotary_embedding_tgi( void rotary_embedding_tgi(
torch::Tensor& query, torch::Tensor& query,
...@@ -230,6 +260,9 @@ void rotary_embedding_tgi( ...@@ -230,6 +260,9 @@ void rotary_embedding_tgi(
void silu_and_mul(torch::Tensor& out, torch::Tensor& input); void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale);
void mul_and_silu(torch::Tensor& out, torch::Tensor& input); void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
...@@ -316,6 +349,10 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W, ...@@ -316,6 +349,10 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
torch::Tensor num_tokens_post_padded, int64_t type, torch::Tensor num_tokens_post_padded, int64_t type,
int64_t row, int64_t top_k, int64_t tokens); int64_t row, int64_t top_k, int64_t tokens);
torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W,
torch::Tensor topk_ids, int64_t top_k,
int64_t type, int64_t row, int64_t tokens);
int64_t ggml_moe_get_block_size(int64_t type); int64_t ggml_moe_get_block_size(int64_t type);
#ifndef USE_ROCM #ifndef USE_ROCM
...@@ -340,13 +377,29 @@ void cutlass_moe_mm( ...@@ -340,13 +377,29 @@ void cutlass_moe_mm(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides); torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);
void get_cutlass_moe_mm_data( void get_cutlass_moe_mm_data(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation, torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k); const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets);
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m, const int64_t n,
const int64_t k);
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
...@@ -369,6 +422,12 @@ std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a); ...@@ -369,6 +422,12 @@ std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
torch::Tensor& output_scale, torch::Tensor& output_scale,
torch::Tensor const& input_scale); torch::Tensor const& input_scale);
void scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif #endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
......
...@@ -161,6 +161,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { ...@@ -161,6 +161,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
int64_t num_tokens = input.numel() / input.size(-1); \ int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \ dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \ dim3 block(std::min(d, 1024)); \
if (num_tokens == 0) { \
return; \
}
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES( \
......
...@@ -38,12 +38,14 @@ inline __device__ void apply_rotary_embedding( ...@@ -38,12 +38,14 @@ inline __device__ void apply_rotary_embedding(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads, // head_size] or [num_tokens, num_heads,
// head_size] // head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, scalar_t* __restrict__ key, // nullptr or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads, // head_size] or [num_tokens, num_kv_heads,
// head_size] // head_size]
const scalar_t* cache_ptr, const int head_size, const int num_heads, const scalar_t* cache_ptr, const int head_size, const int num_heads,
const int num_kv_heads, const int rot_dim, const int token_idx, const int num_kv_heads, const int rot_dim, const int token_idx,
const int64_t query_stride, const int64_t key_stride) { const int64_t query_stride, const int64_t key_stride,
const int64_t head_stride) {
const int embed_dim = rot_dim / 2; const int embed_dim = rot_dim / 2;
const scalar_t* cos_ptr = cache_ptr; const scalar_t* cos_ptr = cache_ptr;
const scalar_t* sin_ptr = cache_ptr + embed_dim; const scalar_t* sin_ptr = cache_ptr + embed_dim;
...@@ -51,19 +53,23 @@ inline __device__ void apply_rotary_embedding( ...@@ -51,19 +53,23 @@ inline __device__ void apply_rotary_embedding(
const int nq = num_heads * embed_dim; const int nq = num_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) { for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / embed_dim; const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * query_stride + head_idx * head_size; const int64_t token_head =
token_idx * query_stride + head_idx * head_stride;
const int rot_offset = i % embed_dim; const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>( apply_token_rotary_embedding<scalar_t, IS_NEOX>(
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
} }
const int nk = num_kv_heads * embed_dim; if (key != nullptr) {
for (int i = threadIdx.x; i < nk; i += blockDim.x) { const int nk = num_kv_heads * embed_dim;
const int head_idx = i / embed_dim; for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int head_idx = i / embed_dim;
const int rot_offset = i % embed_dim; const int64_t token_head =
apply_token_rotary_embedding<scalar_t, IS_NEOX>( token_idx * key_stride + head_idx * head_stride;
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
}
} }
} }
...@@ -74,13 +80,15 @@ __global__ void rotary_embedding_kernel( ...@@ -74,13 +80,15 @@ __global__ void rotary_embedding_kernel(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads, // head_size] or [num_tokens, num_heads,
// head_size] // head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, scalar_t* __restrict__ key, // nullptr or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads, // head_size] or [num_tokens, num_kv_heads,
// head_size] // head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
// 2] // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size) { const int64_t head_stride, const int num_heads, const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token. // Each thread block is responsible for one token.
const int token_idx = blockIdx.x; const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
...@@ -88,7 +96,7 @@ __global__ void rotary_embedding_kernel( ...@@ -88,7 +96,7 @@ __global__ void rotary_embedding_kernel(
apply_rotary_embedding<scalar_t, IS_NEOX>( apply_rotary_embedding<scalar_t, IS_NEOX>(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
token_idx, query_stride, key_stride); token_idx, query_stride, key_stride, head_stride);
} }
template <typename scalar_t, bool IS_NEOX> template <typename scalar_t, bool IS_NEOX>
...@@ -98,15 +106,16 @@ __global__ void batched_rotary_embedding_kernel( ...@@ -98,15 +106,16 @@ __global__ void batched_rotary_embedding_kernel(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads, // head_size] or [num_tokens, num_heads,
// head_size] // head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, scalar_t* __restrict__ key, // nullptr or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads, // head_size] or [num_tokens, num_kv_heads,
// head_size] // head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
// 2] // 2]
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len]
// or [num_tokens]
const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size) { const int64_t head_stride, const int num_heads, const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token. // Each thread block is responsible for one token.
const int token_idx = blockIdx.x; const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
...@@ -116,7 +125,7 @@ __global__ void batched_rotary_embedding_kernel( ...@@ -116,7 +125,7 @@ __global__ void batched_rotary_embedding_kernel(
apply_rotary_embedding<scalar_t, IS_NEOX>( apply_rotary_embedding<scalar_t, IS_NEOX>(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
token_idx, query_stride, key_stride); token_idx, query_stride, key_stride, head_stride);
} }
} // namespace vllm } // namespace vllm
...@@ -127,10 +136,12 @@ void rotary_embedding( ...@@ -127,10 +136,12 @@ void rotary_embedding(
// [num_tokens, num_heads * head_size] or // [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or // [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size] // [num_tokens, num_heads, head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or std::optional<torch::Tensor> key,
// [num_tokens, num_kv_heads * head_size] or // null or
// [batch_size, seq_len, num_heads, head_size] or // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_heads, head_size] // [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t head_size, int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim] torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) { bool is_neox) {
...@@ -138,40 +149,46 @@ void rotary_embedding( ...@@ -138,40 +149,46 @@ void rotary_embedding(
int64_t num_tokens = positions.numel(); int64_t num_tokens = positions.numel();
int positions_ndim = positions.dim(); int positions_ndim = positions.dim();
// Make sure num_tokens dim is consistent across positions, query, and key. // Make sure num_tokens dim is consistent across positions, query, and key
TORCH_CHECK( TORCH_CHECK(
positions_ndim == 1 || positions_ndim == 2, positions_ndim == 1 || positions_ndim == 2,
"positions must have shape [num_tokens] or [batch_size, seq_len]"); "positions must have shape [num_tokens] or [batch_size, seq_len]");
if (positions_ndim == 1) { if (positions_ndim == 1) {
TORCH_CHECK( TORCH_CHECK(query.size(0) == positions.size(0) &&
query.size(0) == positions.size(0) && key.size(0) == positions.size(0), (!key.has_value() || key->size(0) == positions.size(0)),
"query, key and positions must have the same number of tokens"); "query, key and positions must have the same number of tokens");
} }
if (positions_ndim == 2) { if (positions_ndim == 2) {
TORCH_CHECK( TORCH_CHECK(
query.size(0) == positions.size(0) && query.size(0) == positions.size(0) &&
key.size(0) == positions.size(0) && (!key.has_value() || key->size(0) == positions.size(0)) &&
query.size(1) == positions.size(1) && query.size(1) == positions.size(1) &&
key.size(1) == positions.size(1), (!key.has_value() || key->size(1) == positions.size(1)),
"query, key and positions must have the same batch_size and seq_len"); "query, key and positions must have the same batch_size and seq_len");
} }
// Make sure head_size is valid for query and key // Make sure head_size is valid for query and key
// hidden_size = num_heads * head_size // hidden_size = num_heads * head_size
int query_hidden_size = query.numel() / num_tokens; int query_hidden_size = query.numel() / num_tokens;
int key_hidden_size = key.numel() / num_tokens; int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
TORCH_CHECK(query_hidden_size % head_size == 0); TORCH_CHECK(query_hidden_size % head_size == 0);
TORCH_CHECK(key_hidden_size % head_size == 0); TORCH_CHECK(key_hidden_size % head_size == 0);
// Make sure query and key have consistent number of heads // Make sure query and key have consistent number of heads
int num_heads = query_hidden_size / head_size; int num_heads = query_hidden_size / head_size;
int num_kv_heads = key_hidden_size / head_size; int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
TORCH_CHECK(num_heads % num_kv_heads == 0); TORCH_CHECK(num_heads % num_kv_heads == 0);
int rot_dim = cos_sin_cache.size(1); int rot_dim = cos_sin_cache.size(1);
int seq_dim_idx = positions_ndim - 1; int seq_dim_idx = positions_ndim - 1;
int64_t query_stride = query.stride(seq_dim_idx); int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.stride(seq_dim_idx); int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
// Determine head stride: for [*, heads, head_size] use stride of last dim;
// for flat [*, heads*head_size], heads blocks are contiguous of size
// head_size
int query_ndim = query.dim();
int64_t head_stride =
(query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512)); dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
...@@ -181,15 +198,16 @@ void rotary_embedding( ...@@ -181,15 +198,16 @@ void rotary_embedding(
if (is_neox) { if (is_neox) {
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>( vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim, key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
query_stride, key_stride, num_heads, num_kv_heads, head_size); cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride,
head_stride, num_heads, num_kv_heads, head_size);
} else { } else {
vllm::rotary_embedding_kernel<scalar_t, false> vllm::rotary_embedding_kernel<scalar_t, false>
<<<grid, block, 0, stream>>>( <<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
rot_dim, query_stride, key_stride, num_heads, num_kv_heads, cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
head_size); key_stride, head_stride, num_heads, num_kv_heads, head_size);
} }
}); });
} }
...@@ -204,10 +222,12 @@ void batched_rotary_embedding( ...@@ -204,10 +222,12 @@ void batched_rotary_embedding(
// [num_tokens, num_heads * head_size] or // [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or // [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size] // [num_tokens, num_heads, head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or std::optional<torch::Tensor>
// [num_tokens, num_kv_heads * head_size] or key, // null or
// [batch_size, seq_len, num_heads, head_size] or // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_heads, head_size] // [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t head_size, int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim] torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox, int64_t rot_dim, bool is_neox, int64_t rot_dim,
...@@ -221,38 +241,44 @@ void batched_rotary_embedding( ...@@ -221,38 +241,44 @@ void batched_rotary_embedding(
"cos_sin_cache_offsets"); "cos_sin_cache_offsets");
int positions_ndim = positions.dim(); int positions_ndim = positions.dim();
// Make sure num_tokens dim is consistent across positions, query, and key. // Make sure num_tokens dim is consistent across positions, query, and key
TORCH_CHECK( TORCH_CHECK(
positions_ndim == 1 || positions_ndim == 2, positions_ndim == 1 || positions_ndim == 2,
"positions must have shape [num_tokens] or [batch_size, seq_len]"); "positions must have shape [num_tokens] or [batch_size, seq_len]");
if (positions_ndim == 1) { if (positions_ndim == 1) {
TORCH_CHECK( TORCH_CHECK(query.size(0) == positions.size(0) &&
query.size(0) == positions.size(0) && key.size(0) == positions.size(0), (!key.has_value() || key->size(0) == positions.size(0)),
"query, key and positions must have the same number of tokens"); "query, key and positions must have the same number of tokens");
} }
if (positions_ndim == 2) { if (positions_ndim == 2) {
TORCH_CHECK( TORCH_CHECK(
query.size(0) == positions.size(0) && query.size(0) == positions.size(0) &&
key.size(0) == positions.size(0) && (!key.has_value() || key->size(0) == positions.size(0)) &&
query.size(1) == positions.size(1) && query.size(1) == positions.size(1) &&
key.size(1) == positions.size(1), (!key.has_value() || key->size(1) == positions.size(1)),
"query, key and positions must have the same batch_size and seq_len"); "query, key and positions must have the same batch_size and seq_len");
} }
// Make sure head_size is valid for query and key // Make sure head_size is valid for query and key
int query_hidden_size = query.numel() / num_tokens; int query_hidden_size = query.numel() / num_tokens;
int key_hidden_size = key.numel() / num_tokens; int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
TORCH_CHECK(query_hidden_size % head_size == 0); TORCH_CHECK(query_hidden_size % head_size == 0);
TORCH_CHECK(key_hidden_size % head_size == 0); TORCH_CHECK(key_hidden_size % head_size == 0);
// Make sure query and key have concistent number of heads // Make sure query and key have concistent number of heads
int num_heads = query_hidden_size / head_size; int num_heads = query_hidden_size / head_size;
int num_kv_heads = key_hidden_size / head_size; int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
TORCH_CHECK(num_heads % num_kv_heads == 0); TORCH_CHECK(num_heads % num_kv_heads == 0);
int seq_dim_idx = positions_ndim - 1; int seq_dim_idx = positions_ndim - 1;
int64_t query_stride = query.stride(seq_dim_idx); int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.stride(seq_dim_idx); int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
// Determine head stride: for [*, heads, head_size] use stride of last dim;
// for flat [*, heads*head_size], heads blocks are contiguous of size
// head_size
int query_ndim = query.dim();
int64_t head_stride =
(query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512)); dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
...@@ -263,16 +289,18 @@ void batched_rotary_embedding( ...@@ -263,16 +289,18 @@ void batched_rotary_embedding(
vllm::batched_rotary_embedding_kernel<scalar_t, true> vllm::batched_rotary_embedding_kernel<scalar_t, true>
<<<grid, block, 0, stream>>>( <<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride, cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
key_stride, num_heads, num_kv_heads, head_size); key_stride, head_stride, num_heads, num_kv_heads, head_size);
} else { } else {
vllm::batched_rotary_embedding_kernel<scalar_t, false> vllm::batched_rotary_embedding_kernel<scalar_t, false>
<<<grid, block, 0, stream>>>( <<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride, cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
key_stride, num_heads, num_kv_heads, head_size); key_stride, head_stride, num_heads, num_kv_heads, head_size);
} }
}); });
} }
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <cmath>
#include "core/math.hpp"
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/fp8/common.cuh"
namespace vllm {
template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x)
return (T)(((float)x) / (1.0f + expf((float)-x)));
}
// Activation and gating kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
typename fp8_type>
__global__ void act_and_mul_quant_kernel(
fp8_type* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const float* scale, const int d) {
const int32_t blocks_per_token = gridDim.y;
const int32_t elems_per_128bit_load = (128 / 8) / sizeof(scalar_t);
// We don't expect the hidden dimension to exceed 32 bits so int32 should
// be safe here.
const int32_t tgt_elems_per_block = div_ceil(d, blocks_per_token);
const int32_t elems_per_block =
round_to_next_multiple_of(tgt_elems_per_block, elems_per_128bit_load);
const int32_t block_start = blockIdx.y * elems_per_block;
int32_t block_end = block_start + elems_per_block;
block_end = block_end > d ? d : block_end;
// token_idx is 64 bit to prevent 32 bit overflow when the number of tokens
// is very large
const int64_t token_idx = blockIdx.x;
const scalar_t* __restrict__ x_ptr = input + token_idx * 2 * d;
const scalar_t* __restrict__ y_ptr = input + token_idx * 2 * d + d;
fp8_type* __restrict__ out_ptr = out + token_idx * d;
// 128-bit vectorized code
const int32_t vec_loop_end =
round_to_previous_multiple_of(elems_per_128bit_load, block_end);
const int32_t vec_end_idx = vec_loop_end / elems_per_128bit_load;
const int32_t vec_start_idx = block_start / elems_per_128bit_load;
const int4* __restrict__ x_128bit_ptr = reinterpret_cast<const int4*>(x_ptr);
const int4* __restrict__ y_128bit_ptr = reinterpret_cast<const int4*>(y_ptr);
int2* __restrict__ out_128bit_ptr = reinterpret_cast<int2*>(out_ptr);
float inverted_scale = 1 / *scale;
#pragma unroll
for (int32_t vec_idx = vec_start_idx + threadIdx.x; vec_idx < vec_end_idx;
vec_idx += blockDim.x) {
const int4 x_128bit = VLLM_LDG(&x_128bit_ptr[vec_idx]);
const int4 y_128bit = VLLM_LDG(&y_128bit_ptr[vec_idx]);
using scalar_128bit_vec_t = std::array<scalar_t, elems_per_128bit_load>;
using scalar_64bit_vec_t = std::array<fp8_type, elems_per_128bit_load>;
scalar_64bit_vec_t out_vec;
const auto x_vec = reinterpret_cast<scalar_128bit_vec_t const&>(x_128bit);
const auto y_vec = reinterpret_cast<scalar_128bit_vec_t const&>(y_128bit);
#pragma unroll
for (int i = 0; i < elems_per_128bit_load; i++) {
out_vec[i] = scaled_fp8_conversion<true, fp8_type>(
ACT_FN(x_vec[i]) * y_vec[i], inverted_scale);
}
out_128bit_ptr[vec_idx] = reinterpret_cast<const int2&>(out_vec);
}
// Scalar cleanup code
if (block_end > vec_loop_end) {
for (int64_t idx = vec_loop_end + threadIdx.x; idx < block_end;
idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
out_ptr[idx] =
scaled_fp8_conversion<true, fp8_type>(ACT_FN(x) * y, inverted_scale);
}
}
}
} // namespace vllm
// Launch activation, gating, and quantize kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens, num_tokens > 16 ? num_tokens > 32 ? 1 : 2 : 4); \
dim3 block(std::min(d, 512)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel", [&] { \
VLLM_DISPATCH_FP8_TYPES( \
out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \
vllm::act_and_mul_quant_kernel<scalar_t, KERNEL<scalar_t>, \
fp8_t> \
<<<grid, block, 0, stream>>>(out.data_ptr<fp8_t>(), \
input.data_ptr<scalar_t>(), \
scale.data_ptr<float>(), d); \
}); \
});
void silu_and_mul_quant(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., 2 * d]
torch::Tensor& scale) {
TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn ||
out.dtype() == torch::kFloat8_e4m3fnuz);
TORCH_CHECK(input.dtype() == torch::kFloat16 ||
input.dtype() == torch::kBFloat16);
TORCH_CHECK(input.size(-1) % 2 == 0);
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
...@@ -26,7 +26,13 @@ static inline __device__ int8_t float_to_int8_rn(float x) { ...@@ -26,7 +26,13 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
float dst = std::nearbyint(x); float dst = std::nearbyint(x);
// saturate // saturate
dst = std::clamp(dst, i8_min, i8_max);
// See https://github.com/pytorch/pytorch/issues/127666
// See https://github.com/llvm/llvm-project/issues/95183
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
// Arch/gcc14. The following replaces std::clamp usage with similar logic
// dst = std::clamp(dst, i8_min, i8_max);
dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst;
return static_cast<int8_t>(dst); return static_cast<int8_t>(dst);
#else #else
// CUDA path // CUDA path
...@@ -79,7 +85,13 @@ static inline __device__ int8_t int32_to_int8(int32_t x) { ...@@ -79,7 +85,13 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
static_cast<int32_t>(std::numeric_limits<int8_t>::max()); static_cast<int32_t>(std::numeric_limits<int8_t>::max());
// saturate // saturate
int32_t dst = std::clamp(x, i8_min, i8_max);
// See https://github.com/pytorch/pytorch/issues/127666
// See https://github.com/llvm/llvm-project/issues/95183
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
// Arch/gcc14. The following replaces std::clamp usage with similar logic
// int32_t dst = std::clamp(x, i8_min, i8_max);
int32_t dst = (x < i8_min) ? i8_min : (x > i8_max) ? i8_max : x;
return static_cast<int8_t>(dst); return static_cast<int8_t>(dst);
#else #else
// CUDA path // CUDA path
......
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
if (out.dtype() == torch::kBFloat16) {
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm
#pragma once
#include "cuda_utils.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_gemm_caller.cuh"
namespace vllm {
using namespace cute;
// clang-format off
template <class OutType, int ScaleGranularityM,
int ScaleGranularityN, int ScaleGranularityK,
class MmaTileShape, class ClusterShape,
class EpilogueScheduler, class MainloopScheduler,
bool swap_ab_ = false>
struct cutlass_3x_gemm_fp8_blockwise {
static constexpr bool swap_ab = swap_ab_;
using ElementAB = cutlass::float_e4m3_t;
using ElementA = ElementAB;
using LayoutA = cutlass::layout::RowMajor;
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
using ElementB = ElementAB;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
using ElementD = OutType;
using LayoutD = cutlass::layout::RowMajor;
using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
using ElementC = void; // TODO: support bias
using LayoutC = LayoutD;
using LayoutC_Transpose = LayoutD_Transpose;
static constexpr int AlignmentC = AlignmentD;
using ElementAccumulator = float;
using ElementCompute = float;
using ElementBlockScale = float;
using ScaleConfig = conditional_t<swap_ab,
cutlass::detail::Sm100BlockwiseScaleConfig<
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
cute::UMMA::Major::K, cute::UMMA::Major::MN>,
cutlass::detail::Sm100BlockwiseScaleConfig<
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
cute::UMMA::Major::MN, cute::UMMA::Major::K>>;
// layout_SFA and layout_SFB cannot be swapped since they are deduced.
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
using ArchTag = cutlass::arch::Sm100;
using OperatorClass = cutlass::arch::OpClassTensorOp;
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using ElementScalar = float;
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
MmaTileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementCompute,
ElementC,
conditional_t<swap_ab, LayoutC_Transpose, LayoutC>,
AlignmentC,
ElementD,
conditional_t<swap_ab, LayoutD_Transpose, LayoutD>,
AlignmentD,
EpilogueScheduler,
DefaultOperation
>::CollectiveOp;
using StageCountType = cutlass::gemm::collective::StageCountAuto;
using CollectiveMainloop = conditional_t<swap_ab,
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementB,
cute::tuple<LayoutB_Transpose, LayoutSFA>,
AlignmentB,
ElementA,
cute::tuple<LayoutA_Transpose, LayoutSFB>,
AlignmentA,
ElementAccumulator,
MmaTileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduler
>::CollectiveOp,
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementA,
cute::tuple<LayoutA, LayoutSFA>,
AlignmentA,
ElementB,
cute::tuple<LayoutB, LayoutSFB>,
AlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduler
>::CollectiveOp>;
using KernelType = enable_sm100_only<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
struct GemmKernel : public KernelType {};
};
template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
static constexpr bool swap_ab = Gemm::swap_ab;
using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideD = typename Gemm::GemmKernel::StrideD;
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutSFA = typename Gemm::LayoutSFA;
using LayoutSFB = typename Gemm::LayoutSFB;
using ScaleConfig = typename Gemm::ScaleConfig;
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
StrideA a_stride;
StrideB b_stride;
StrideC c_stride;
a_stride =
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
b_stride =
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
c_stride =
cutlass::make_cute_packed_stride(StrideC{}, swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1));
LayoutSFA layout_SFA = swap_ab ?
ScaleConfig::tile_atom_to_shape_SFA(make_shape(n, m, k, 1)) :
ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
LayoutSFB layout_SFB = swap_ab ?
ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) :
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
auto mainloop_args = [&](){
// layout_SFA and layout_SFB cannot be swapped since they are deduced.
if (swap_ab) {
return typename GemmKernel::MainloopArguments{
b_ptr, b_stride, a_ptr, a_stride,
b_scales_ptr, layout_SFA, a_scales_ptr, layout_SFB
};
}
else {
return typename GemmKernel::MainloopArguments{
a_ptr, a_stride, b_ptr, b_stride,
a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
};
}
}();
auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1);
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename GemmKernel::EpilogueArguments epilogue_args{
{}, c_ptr, c_stride, c_ptr, c_stride};
c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
epilogue_args);
}
template <typename OutType>
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
constexpr int TILE_K = 128;
// TODO: better heuristics
bool swap_ab = (m < 16) || (m % 4 != 0);
bool use_tma_epilogue = (m * n) % 4 == 0;
if (!swap_ab) {
constexpr int TILE_N = 128;
int tile_m = 256;
if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 64) <= sms) {
tile_m = 64;
}
else if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 128) <= sms) {
tile_m = 128;
}
if (tile_m == 64) {
if (use_tma_epilogue) {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
out, a, b, a_scales, b_scales);
} else {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
out, a, b, a_scales, b_scales);
}
} else if (tile_m == 128) {
if (use_tma_epilogue) {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
out, a, b, a_scales, b_scales);
} else {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
out, a, b, a_scales, b_scales);
}
} else { // tile_m == 256
if (use_tma_epilogue) {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
Shape<_2, _1, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
out, a, b, a_scales, b_scales);
} else {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
Shape<_2, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized2Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
out, a, b, a_scales, b_scales);
}
}
} else {
// TODO: Test more tile N configs
constexpr int TILE_M = 128;
constexpr int TILE_N = 16;
// TMA epilogue isn't compatible with Swap A/B
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, TILE_M, 1, TILE_K, Shape<Int<TILE_M>, Int<TILE_N>, Int<TILE_K>>,
Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100, true>>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm
#include <torch/all.h>
#include "cuda_utils.h"
#include "cutlass_extensions/common.hpp"
template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias,
Fp8Func fp8_func, Int8Func int8_func,
BlockwiseFunc blockwise_func) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
int M = a.size(0), N = b.size(1), K = a.size(1);
if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
(b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (a.dtype() == torch::kFloat8_e4m3fn) {
fp8_func(c, a, b, a_scales, b_scales, bias);
} else {
TORCH_CHECK(a.dtype() == torch::kInt8);
if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
int8_func(c, a, b, a_scales, b_scales, bias);
} else {
TORCH_CHECK(false, "Int8 not supported for this architecture");
}
}
} else {
TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
int32_t version_num = get_sm_version_num();
if (version_num >= 100) {
TORCH_CHECK(
a.size(0) == a_scales.size(0) &&
cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
"a_scale_group_shape must be [1, 128].");
TORCH_CHECK(
cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
"b_scale_group_shape must be [128, 128].");
} else {
// TODO: Remove this after using cutlass sm90 blockwise scaling gemm
// kernel, or introducing ceil_div to the load_init() of mainloop.
using GroupShape = std::array<int64_t, 2>;
auto make_group_shape = [](torch::Tensor const& x,
torch::Tensor const& s) -> GroupShape {
TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
return {cuda_utils::ceil_div(x.size(0), s.size(0)),
cuda_utils::ceil_div(x.size(1), s.size(1))};
};
GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
b_scale_group_shape == GroupShape{128, 128} &&
a.dtype() == torch::kFloat8_e4m3fn &&
b.dtype() == torch::kFloat8_e4m3fn),
"cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
"a_scale_group_shape must be [1, 128]. Got: [",
a_scale_group_shape[0], ", ", a_scale_group_shape[1],
"]\n"
"b_scale_group_shape must be [128, 128]. Got: [",
b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
}
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
blockwise_func(c, a, b, a_scales, b_scales);
}
}
...@@ -36,4 +36,9 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, ...@@ -36,4 +36,9 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias); std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
} // namespace vllm } // namespace vllm
...@@ -15,6 +15,7 @@ using c3x::cutlass_gemm_caller; ...@@ -15,6 +15,7 @@ using c3x::cutlass_gemm_caller;
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue> template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_default { struct sm100_fp8_config_default {
// M in (128, inf)
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
...@@ -25,6 +26,34 @@ struct sm100_fp8_config_default { ...@@ -25,6 +26,34 @@ struct sm100_fp8_config_default {
KernelSchedule, EpilogueSchedule>; KernelSchedule, EpilogueSchedule>;
}; };
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M128 {
// M in (64, 128]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_128, _128, _64>;
using ClusterShape = Shape<_2, _2, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M64 {
// M in [1, 64]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_64, _64, _256>;
using ClusterShape = Shape<_1, _8, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue, template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
...@@ -39,8 +68,28 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, ...@@ -39,8 +68,28 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
using Cutlass3xGemmDefault = using Cutlass3xGemmDefault =
typename sm100_fp8_config_default<InType, OutType, typename sm100_fp8_config_default<InType, OutType,
Epilogue>::Cutlass3xGemm; Epilogue>::Cutlass3xGemm;
return cutlass_gemm_caller<Cutlass3xGemmDefault>( using Cutlass3xGemmM64 =
out, a, b, std::forward<EpilogueArgs>(args)...); typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
typename sm100_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
if (mp2 <= 64) {
// m in [1, 64]
return cutlass_gemm_caller<Cutlass3xGemmM64>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// m in (64, 128]
return cutlass_gemm_caller<Cutlass3xGemmM128>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// m in (128, inf)
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
} }
template <template <typename, typename, typename> typename Epilogue, template <template <typename, typename, typename> typename Epilogue,
......
...@@ -84,7 +84,8 @@ void run_cutlass_moe_mm_sm90( ...@@ -84,7 +84,8 @@ void run_cutlass_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) { torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided."); TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided."); TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided."); TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
...@@ -113,19 +114,23 @@ void run_cutlass_moe_mm_sm90( ...@@ -113,19 +114,23 @@ void run_cutlass_moe_mm_sm90(
if (n >= 8192) { if (n >= 8192) {
cutlass_group_gemm_caller<Cutlass3xGemmN8192>( cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides); problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else if (k >= 8192) { } else if (k >= 8192) {
cutlass_group_gemm_caller<Cutlass3xGemmK8192>( cutlass_group_gemm_caller<Cutlass3xGemmK8192>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides); problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else if (m <= 16) { } else if (m <= 16) {
cutlass_group_gemm_caller<Cutlass3xGemmM16>( cutlass_group_gemm_caller<Cutlass3xGemmM16>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides); problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else { } else {
cutlass_group_gemm_caller<Cutlass3xGemmDefault>( cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides); problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} }
} }
...@@ -134,15 +139,18 @@ void dispatch_moe_mm_sm90( ...@@ -134,15 +139,18 @@ void dispatch_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) { torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
if (out_tensors.dtype() == torch::kBFloat16) { if (out_tensors.dtype() == torch::kBFloat16) {
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>( run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides); problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else { } else {
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::half_t>( run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::half_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides); problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} }
} }
...@@ -153,8 +161,9 @@ void cutlass_moe_mm_sm90( ...@@ -153,8 +161,9 @@ void cutlass_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) { torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides, expert_offsets, problem_sizes, a_strides, b_strides,
c_strides); c_strides, per_act_token, per_out_ch);
} }
...@@ -76,7 +76,8 @@ void cutlass_group_gemm_caller( ...@@ -76,7 +76,8 @@ void cutlass_group_gemm_caller(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) { torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
using ElementAB = typename Gemm::ElementAB; using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD; using ElementD = typename Gemm::ElementD;
...@@ -84,9 +85,6 @@ void cutlass_group_gemm_caller( ...@@ -84,9 +85,6 @@ void cutlass_group_gemm_caller(
int k_size = a_tensors.size(1); int k_size = a_tensors.size(1);
int n_size = out_tensors.size(1); int n_size = out_tensors.size(1);
bool per_act_token = a_scales.numel() != 1;
bool per_out_ch = b_scales.numel() != num_experts;
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
auto options_int = auto options_int =
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
constexpr uint64_t THREADS_PER_EXPERT = 512; constexpr uint64_t THREADS_PER_EXPERT = 512;
__global__ void compute_problem_sizes(const int* __restrict__ topk_ids, __global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids,
int32_t* problem_sizes1, int32_t* problem_sizes1,
int32_t* problem_sizes2, int32_t* problem_sizes2,
int32_t* atomic_buffer, int32_t* atomic_buffer,
...@@ -45,7 +45,24 @@ __global__ void compute_expert_offsets( ...@@ -45,7 +45,24 @@ __global__ void compute_expert_offsets(
} }
} }
__global__ void compute_arg_sorts(const int* __restrict__ topk_ids, __global__ void compute_expert_blockscale_offsets(
const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
int32_t* blockscale_offsets, int32_t* atomic_buffer,
const int num_experts) {
int32_t tot_offset = 0;
int32_t tot_offset_round = 0;
expert_offsets[0] = 0;
blockscale_offsets[0] = 0;
for (int i = 0; i < num_experts; ++i) {
atomic_buffer[i] = tot_offset;
tot_offset += problem_sizes1[i * 3];
expert_offsets[i + 1] = tot_offset;
tot_offset_round += (problem_sizes1[i * 3] + (128 - 1)) / 128 * 128;
blockscale_offsets[i + 1] = tot_offset_round;
}
}
__global__ void compute_arg_sorts(const uint32_t* __restrict__ topk_ids,
const int32_t* __restrict__ expert_offsets, const int32_t* __restrict__ expert_offsets,
int32_t* input_permutation, int32_t* input_permutation,
int32_t* output_permutation, int32_t* output_permutation,
...@@ -77,7 +94,8 @@ void get_cutlass_moe_mm_data_caller( ...@@ -77,7 +94,8 @@ void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation, torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k) { const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets) {
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
auto options_int32 = auto options_int32 =
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
...@@ -85,19 +103,61 @@ void get_cutlass_moe_mm_data_caller( ...@@ -85,19 +103,61 @@ void get_cutlass_moe_mm_data_caller(
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>( compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()), static_cast<const uint32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()), static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()), static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k); static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
compute_expert_offsets<<<1, 1, 0, stream>>>( if (blockscale_offsets.has_value()) {
static_cast<const int32_t*>(problem_sizes1.data_ptr()), compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
static_cast<int32_t*>(expert_offsets.data_ptr()), static_cast<const int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts); static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
} else {
compute_expert_offsets<<<1, 1, 0, stream>>>(
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
}
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>( compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()), static_cast<const uint32_t*>(topk_ids.data_ptr()),
static_cast<const int32_t*>(expert_offsets.data_ptr()), static_cast<const int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(input_permutation.data_ptr()), static_cast<int32_t*>(input_permutation.data_ptr()),
static_cast<int32_t*>(output_permutation.data_ptr()), static_cast<int32_t*>(output_permutation.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
topk_ids.size(1)); topk_ids.size(1));
} }
__global__ void compute_pplx_data(int32_t* expert_offsets,
int32_t* problem_sizes1,
int32_t* problem_sizes2,
const int32_t* __restrict__ expert_num_tokens,
const int padded_m, const int n,
const int k) {
int expert_idx = threadIdx.x;
expert_offsets[expert_idx] = expert_idx * padded_m;
problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx];
problem_sizes1[expert_idx * 3 + 1] = 2 * n;
problem_sizes1[expert_idx * 3 + 2] = k;
problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx];
problem_sizes2[expert_idx * 3 + 1] = k;
problem_sizes2[expert_idx * 3 + 2] = n;
}
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m,
const int64_t n, const int64_t k) {
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
compute_pplx_data<<<1, num_local_experts, 0, stream>>>(
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
k);
}
#include <cudaTypedefs.h> #include "c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp" #include "c3x/scaled_mm_kernels.hpp"
#include "cuda_utils.h"
/* /*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm100 (Blackwell). NVIDIA GPUs with sm100 (Blackwell).
...@@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a, ...@@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) { std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); vllm::cutlass_scaled_mm_sm100_fp8,
nullptr, // int8 not supported on SM100
int M = a.size(0), N = b.size(1), K = a.size(1); vllm::cutlass_scaled_mm_blockwise_sm100_fp8);
TORCH_CHECK(
(a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
(b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
"Currently, block scaled fp8 gemm is not implemented for Blackwell");
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
"Currently, only fp8 gemm is implemented for Blackwell");
vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
} }
#endif #endif
#include <cudaTypedefs.h> #include "c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp" #include "c3x/scaled_mm_kernels.hpp"
#include "cuda_utils.h"
/* /*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper). NVIDIA GPUs with sm90a (Hopper).
...@@ -15,49 +13,10 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, ...@@ -15,49 +13,10 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) { std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); vllm::cutlass_scaled_mm_sm90_fp8,
vllm::cutlass_scaled_mm_sm90_int8,
int M = a.size(0), N = b.size(1), K = a.size(1); vllm::cutlass_scaled_mm_blockwise_sm90_fp8);
if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
(b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (a.dtype() == torch::kFloat8_e4m3fn) {
vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
} else {
TORCH_CHECK(a.dtype() == torch::kInt8);
vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias);
}
} else {
using GroupShape = std::array<int64_t, 2>;
auto make_group_shape = [](torch::Tensor const& x,
torch::Tensor const& s) -> GroupShape {
TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
return {cuda_utils::ceil_div(x.size(0), s.size(0)),
cuda_utils::ceil_div(x.size(1), s.size(1))};
};
GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
b_scale_group_shape == GroupShape{128, 128} &&
a.dtype() == torch::kFloat8_e4m3fn &&
b.dtype() == torch::kFloat8_e4m3fn),
"cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
"a_scale_group_shape must be [1, 128]. Got: [",
a_scale_group_shape[0], ", ", a_scale_group_shape[1],
"]\n"
"b_scale_group_shape must be [128, 128]. Got: [",
b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
}
} }
void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
......
...@@ -29,19 +29,15 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, ...@@ -29,19 +29,15 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias); std::optional<torch::Tensor> const& bias);
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
void cutlass_moe_mm_sm90( void cutlass_moe_mm_sm90(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides); torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k);
#endif #endif
...@@ -53,6 +49,24 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a, ...@@ -53,6 +49,24 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
std::optional<torch::Tensor> const& bias); std::optional<torch::Tensor> const& bias);
#endif #endif
#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \
defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100
void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets);
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m,
const int64_t n, const int64_t k);
#endif
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
...@@ -110,6 +124,8 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) { ...@@ -110,6 +124,8 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
#if defined CUDA_VERSION #if defined CUDA_VERSION
if (cuda_device_capability >= 90 && cuda_device_capability < 100) { if (cuda_device_capability >= 90 && cuda_device_capability < 100) {
return CUDA_VERSION >= 12000; return CUDA_VERSION >= 12000;
} else if (cuda_device_capability >= 100) {
return CUDA_VERSION >= 12080;
} }
#endif #endif
...@@ -117,7 +133,7 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) { ...@@ -117,7 +133,7 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
} }
bool cutlass_group_gemm_supported(int64_t cuda_device_capability) { bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
// CUTLASS groped FP8 kernels need at least CUDA 12.3 // CUTLASS grouped FP8 kernels need at least CUDA 12.3
// and SM90 (Hopper) // and SM90 (Hopper)
#if defined CUDA_VERSION #if defined CUDA_VERSION
...@@ -200,12 +216,13 @@ void cutlass_moe_mm( ...@@ -200,12 +216,13 @@ void cutlass_moe_mm(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) { torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 #if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides, expert_offsets, problem_sizes, a_strides, b_strides,
c_strides); c_strides, per_act_token, per_out_ch);
return; return;
#endif #endif
TORCH_CHECK_NOT_IMPLEMENTED( TORCH_CHECK_NOT_IMPLEMENTED(
...@@ -218,14 +235,17 @@ void get_cutlass_moe_mm_data( ...@@ -218,14 +235,17 @@ void get_cutlass_moe_mm_data(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation, torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k) { const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets) {
// This function currently gets compiled only if we have a valid cutlass moe // This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for. // mm to run it for.
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, input_permutation, problem_sizes2, input_permutation,
output_permutation, num_experts, n, k); output_permutation, num_experts, n, k,
blockscale_offsets);
return; return;
#endif #endif
TORCH_CHECK_NOT_IMPLEMENTED( TORCH_CHECK_NOT_IMPLEMENTED(
...@@ -235,6 +255,29 @@ void get_cutlass_moe_mm_data( ...@@ -235,6 +255,29 @@ void get_cutlass_moe_mm_data(
version_num, ". Required capability: 90"); version_num, ". Required capability: 90");
} }
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m, const int64_t n,
const int64_t k) {
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens,
num_local_experts, padded_m, n, k);
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
"for CUDA device capability: ",
version_num, ". Required capability: 90");
}
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
......
#include <torch/all.h>
#include <cutlass/arch/arch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include <cassert>
using namespace cute;
template <typename ElementAB, typename ElementC, typename ElementSF,
typename ElementAccumulator, typename LayoutSFA, typename LayoutSFB,
typename ScaleConfig>
__global__ void __get_group_gemm_starts(
ElementAB** a_offsets, ElementAB** b_offsets, ElementC** out_offsets,
ElementSF** a_scales_offsets, ElementSF** b_scales_offsets,
ElementAccumulator** alpha_offsets, LayoutSFA* layout_sfa_base_as_int,
LayoutSFB* layout_sfb_base_as_int, ElementAB* a_base_as_int,
ElementAB* b_base_as_int, ElementC* out_base_as_int,
ElementSF* a_scales_base_as_int, ElementSF* b_scales_base_as_int,
ElementAccumulator* alphas_base_as_int, const int32_t* expert_offsets,
const int32_t* sf_offsets, const int32_t* problem_sizes_as_shapes,
const int K, const int N) {
int64_t expert_id = threadIdx.x;
if (expert_id >= gridDim.x * blockDim.x) {
return;
}
// Originally int32_t but upcasting to int64_t to avoid overflow
// during offset calculations
int64_t expert_offset = static_cast<int64_t>(expert_offsets[expert_id]);
int64_t sf_offset = static_cast<int64_t>(sf_offsets[expert_id]);
// size for block in block scale.
int64_t group_size = 16;
int64_t m = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3]);
int64_t n = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3 + 1]);
int64_t k = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3 + 2]);
assert((m >= 0 && n == N && k == K && k % 2 == 0) &&
"unexpected problem sizes");
int64_t half_k = static_cast<int64_t>(k / 2);
int64_t group_k = static_cast<int64_t>(k / group_size);
// Shape of A as uint8/byte = [M, K // 2]
// Shape of B as uint8/byte = [E, N, K // 2]
a_offsets[expert_id] = a_base_as_int + expert_offset * half_k;
b_offsets[expert_id] = b_base_as_int + expert_id * n * half_k;
// Shape of C = [M, N]
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
// Shape of a_scale = [sum(sf_sizes), K // group_size]
a_scales_offsets[expert_id] = a_scales_base_as_int + sf_offset * group_k;
assert((reinterpret_cast<uintptr_t>(a_scales_offsets[expert_id]) % 128) ==
0 &&
"TMA requires 128-byte alignment");
// Shape of B scale = [E, N, K // group_size]
b_scales_offsets[expert_id] = b_scales_base_as_int + expert_id * n * group_k;
assert((reinterpret_cast<uintptr_t>(b_scales_offsets[expert_id]) % 128) ==
0 &&
"TMA requires 128-byte alignment");
// Shape of alpha = [E]
alpha_offsets[expert_id] = alphas_base_as_int + expert_id;
LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id;
LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id;
*layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(
static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1));
*layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(
static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1));
}
#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE, \
TENSOR_C_TYPE, C_TYPE, LayoutSFA, \
LayoutSFB, ScaleConfig) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
__get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, \
LayoutSFA, LayoutSFB, ScaleConfig> \
<<<1, num_experts, 0, stream>>>( \
static_cast<ELEMENT_AB_TYPE**>(a_starts.data_ptr()), \
static_cast<ELEMENT_AB_TYPE**>(b_starts.data_ptr()), \
static_cast<C_TYPE**>(out_starts.data_ptr()), \
static_cast<SF_TYPE**>(a_scales_starts.data_ptr()), \
static_cast<SF_TYPE**>(b_scales_starts.data_ptr()), \
static_cast<float**>(alpha_starts.data_ptr()), \
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), \
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()), \
static_cast<ELEMENT_AB_TYPE*>(a_tensors.data_ptr()), \
static_cast<ELEMENT_AB_TYPE*>(b_tensors.data_ptr()), \
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
static_cast<SF_TYPE*>(a_scales.data_ptr()), \
static_cast<SF_TYPE*>(b_scales.data_ptr()), \
static_cast<float*>(alphas.data_ptr()), \
static_cast<int32_t*>(expert_offsets.data_ptr()), \
static_cast<int32_t*>(sf_offsets.data_ptr()), \
static_cast<int32_t*>(problem_sizes.data_ptr()), K, N); \
}
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
void run_get_group_gemm_starts(
const torch::Tensor& a_starts, const torch::Tensor& b_starts,
const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts,
const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts,
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
/*these are used for their base addresses*/
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
torch::Tensor const& out_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& alphas,
torch::Tensor const& expert_offsets, torch::Tensor const& sf_offsets,
torch::Tensor const& problem_sizes, int M, int N, int K) {
int num_experts = (int)expert_offsets.size(0);
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
TORCH_CHECK(out_tensors.size(1) == N,
"Output tensor shape doesn't match expected shape");
TORCH_CHECK(K / 2 == b_tensors.size(2),
"b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
" dimension must match");
if (false) {
}
//(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB,
// ScaleConfig)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(
cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kBFloat16,
cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(cutlass::float_e2m1_t,
cutlass::float_ue4m3_t, torch::kFloat16,
half, LayoutSFA, LayoutSFB, ScaleConfig)
else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
}
template <typename OutType>
void run_fp4_blockwise_scaled_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
int N, int K) {
using ProblemShape =
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
using ElementType = cutlass::float_e2m1_t;
using ElementSFType = cutlass::float_ue4m3_t;
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using ElementC = OutType;
using ElementD = ElementC;
using ElementAccumulator = float;
// Layout definitions
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = LayoutC;
// Alignment constraints
static constexpr int AlignmentA = 32;
static constexpr int AlignmentB = 32;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// Architecture definitions
using ArchTag = cutlass::arch::Sm100;
using EpilogueOperatorClass =
cutlass::arch::OpClassTensorOp; // Epilogue Operator class tag
using MainloopOperatorClass =
cutlass::arch::OpClassBlockScaledTensorOp; // Mainloop Operator class tag
using StageCountType =
cutlass::gemm::collective::StageCountAuto; // Stage count maximized based
// on the tile size
using ClusterShape = Shape<_1, _1, _1>;
struct MMA1SMConfig {
using MmaTileShape = Shape<_128, _128, _128>;
using KernelSchedule = cutlass::gemm::
KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100; // Kernel to launch
using EpilogueSchedule =
cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch
};
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, EpilogueOperatorClass, typename MMA1SMConfig::MmaTileShape,
ClusterShape, Shape<_128, _64>, ElementAccumulator,
ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD,
LayoutC*, AlignmentD,
typename MMA1SMConfig::EpilogueSchedule>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, MainloopOperatorClass, ElementA, LayoutA*, AlignmentA,
ElementB, LayoutB*, AlignmentB, ElementAccumulator,
typename MMA1SMConfig::MmaTileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
typename MMA1SMConfig::KernelSchedule>::CollectiveOp;
using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
CollectiveEpilogue>;
using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using Gemm = Gemm1SM;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
using LayoutSFA =
typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
using LayoutSFB =
typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
using ScaleConfig =
typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
int num_experts = static_cast<int>(expert_offsets.size(0));
auto options_int =
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
torch::Tensor c_strides1 =
torch::full({num_experts}, output.stride(0), options_int);
torch::Tensor a_strides1 =
torch::full({num_experts}, a.stride(0) * 2, options_int);
torch::Tensor b_strides1 =
torch::full({num_experts}, b.stride(1) * 2, options_int);
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas,
expert_offsets, sf_offsets, problem_sizes, M, N, K);
// Create an instance of the GEMM
Gemm gemm_op;
// Initialize problem_sizes_as_shapes correctly
UnderlyingProblemShape* problem_sizes_as_shapes =
static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
// Set the Scheduler info
cutlass::KernelHardwareInfo hw_info;
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::
PersistentTileSchedulerSm100GroupParams<
typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
scheduler.raster_order = RasterOrderOptions::AlongM;
hw_info.device_id = a.get_device();
static std::unordered_map<int, int> cached_sm_counts;
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
cached_sm_counts[hw_info.device_id] =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
hw_info.device_id);
}
hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX);
// Mainloop Arguments
typename GemmKernel::MainloopArguments mainloop_args{
static_cast<const ElementType**>(a_ptrs.data_ptr()),
static_cast<StrideA*>(a_strides1.data_ptr()),
static_cast<const ElementType**>(b_ptrs.data_ptr()),
static_cast<StrideB*>(b_strides1.data_ptr()),
static_cast<const ElementSFType**>(a_scales_ptrs.data_ptr()),
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
static_cast<const ElementSFType**>(b_scales_ptrs.data_ptr()),
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())};
// Epilogue Arguments
typename GemmKernel::EpilogueArguments epilogue_args{
{}, // epilogue.thread
nullptr,
static_cast<StrideC*>(c_strides1.data_ptr()),
static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<StrideC*>(c_strides1.data_ptr())};
auto& fusion_args = epilogue_args.thread;
fusion_args.alpha_ptr_array =
reinterpret_cast<float**>(alpha_ptrs.data_ptr());
fusion_args.dAlpha = {_0{}, _0{}, 1};
// Gemm Arguments
typename GemmKernel::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped,
{num_experts, problem_sizes_as_shapes, nullptr},
mainloop_args,
epilogue_args,
hw_info,
scheduler};
size_t workspace_size = Gemm::get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM");
// Run the GEMM
auto status = gemm_op.initialize(args, workspace.data_ptr());
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
status = gemm_op.run(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
// Input validation
CHECK_INPUT(a, FLOAT4_E2M1X2, "a");
CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale");
CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales");
CHECK_INPUT(alphas, at::ScalarType::Float, "alphas");
TORCH_CHECK(a_blockscale.dim() == 2,
"expected a_blockscale to be of shape [num_experts, rounded_m,"
" k // group_size], observed rank: ",
a_blockscale.dim())
TORCH_CHECK(b_blockscales.dim() == 3,
"expected b_blockscale to be of shape: "
" [num_experts, n, k // group_size], observed rank: ",
b_blockscales.dim())
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor");
TORCH_CHECK(problem_sizes.size(1) == 3,
"problem_sizes must have the shape (num_experts, 3)");
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
"Number of experts in problem_sizes must match expert_offsets");
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
"problem_sizes must be int32.");
int M = static_cast<int>(a.size(0));
int N = static_cast<int>(b.size(1));
int E = static_cast<int>(b.size(0));
int K = static_cast<int>(2 * b.size(2));
if (output.scalar_type() == torch::kBFloat16) {
run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
expert_offsets, sf_offsets, M, N, K);
} else {
run_fp4_blockwise_scaled_group_mm<cutlass::half_t>(
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
expert_offsets, sf_offsets, M, N, K);
}
#else
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_fp4_group_mm kernel, vLLM must "
"be compiled with ENABLE_NVFP4 for SM100+ and CUDA "
"12.8 or above.");
#endif
}
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <cuda_fp8.h>
template <typename T>
struct TypeConverter {
using Type = half2;
}; // keep for generality
template <>
struct TypeConverter<half2> {
using Type = half;
};
template <>
struct TypeConverter<half> {
using Type = half2;
};
template <>
struct TypeConverter<__nv_bfloat162> {
using Type = __nv_bfloat16;
};
template <>
struct TypeConverter<__nv_bfloat16> {
using Type = __nv_bfloat162;
};
#define ELTS_PER_THREAD 8
constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
constexpr int CVT_FP4_SF_VEC_SIZE = 16;
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
"f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
return val;
#else
return 0;
#endif
}
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
"f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
return val;
#else
return 0;
#endif
}
// Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) {
float b;
asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
return b;
}
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
int numCols,
SFType* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
CVT_FP4_NUM_THREADS_PER_SF == 2);
// One pair of threads write one SF to global memory.
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
// SF vector index (16 elements share one SF in the K dimension).
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
int32_t mIdx = rowIdx;
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
int32_t mTileIdx = mIdx / (32 * 4);
// SF vector size 16.
int factor = CVT_FP4_SF_VEC_SIZE * 4;
int32_t numKTiles = (numCols + factor - 1) / factor;
int64_t mTileStride = numKTiles * 32 * 4 * 4;
int32_t kTileIdx = (kIdx / 4);
int64_t kTileStride = 32 * 4 * 4;
// M tile layout [32, 4] is column-major.
int32_t outerMIdx = (mIdx % 32);
int64_t outerMStride = 4 * 4;
int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
int64_t innerMStride = 4;
int32_t innerKIdx = (kIdx % 4);
int64_t innerKStride = 1;
// Compute the global offset.
int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
outerMIdx * outerMStride + innerMIdx * innerMStride +
innerKIdx * innerKStride;
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
}
#endif
return nullptr;
}
// Define a 16 bytes packed data type.
template <class Type>
struct PackedVec {
typename TypeConverter<Type>::Type elts[4];
};
template <>
struct PackedVec<__nv_fp8_e4m3> {
__nv_fp8x2_e4m3 elts[8];
};
// Quantizes the provided PackedVec into the uint32_t output
template <class Type, bool UE8M0_SF = false>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
uint8_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
// Get absolute maximum values among the local 8 values.
auto localMax = __habs2(vec.elts[0]);
// Local maximum value.
#pragma unroll
for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
localMax = __hmax2(localMax, __habs2(vec.elts[i]));
}
// Get the absolute maximum among all 16 values (two threads).
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
// Get the final absolute maximum values.
float vecMax = float(__hmax(localMax.x, localMax.y));
// Get the SF (max value of the vector / max value of e2m1).
// maximum value of e2m1 = 6.0.
// TODO: use half as compute data type.
float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
// 8 bits representation of the SF.
uint8_t fp8SFVal;
// Write the SF to global memory (STG.8).
if constexpr (UE8M0_SF) {
// Extract the 8 exponent bits from float32.
// float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
fp8SFVal = tmp & 0xff;
// Convert back to fp32.
reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
} else {
// Here SFValue is always positive, so E4M3 is the same as UE4M3.
__nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
// Convert back to fp32.
SFValue = float(tmp);
}
// Get the output scale.
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
// reciprocal(SFScaleVal))
float outputScale =
SFValue != 0 ? reciprocal_approximate_ftz(
SFValue * reciprocal_approximate_ftz(SFScaleVal))
: 0.0f;
if (SFout) {
// Write the SF to global memory (STG.8).
*SFout = fp8SFVal;
}
// Convert the input to float.
float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
#pragma unroll
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
if constexpr (std::is_same_v<Type, half>) {
fp2Vals[i] = __half22float2(vec.elts[i]);
} else {
fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
}
fp2Vals[i].x *= outputScale;
fp2Vals[i].y *= outputScale;
}
// Convert to e2m1 values.
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
// Write the e2m1 values to global memory.
return e2m1Vec;
#else
return 0;
#endif
}
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(512, 4) cvt_fp16_to_fp4(
#else
cvt_fp16_to_fp4(
#endif
int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts,
uint32_t* output_scale_offset_by_experts, int n_experts) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using PackedVec = PackedVec<Type>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
"Vec size is not matched.");
// Input tensor row/col loops.
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD;
colIdx += blockDim.x) {
int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t outOffset = inOffset;
auto& out_pos = out[outOffset];
// Find index within the experts.
int rowIdx_in_expert = 0;
int expert_idx = 0;
for (int i = 0; i < n_experts; i++) {
if (rowIdx >= input_offset_by_experts[i] &&
rowIdx < input_offset_by_experts[i + 1]) {
rowIdx_in_expert = rowIdx - input_offset_by_experts[i];
expert_idx = i;
break;
}
}
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
int factor = CVT_FP4_SF_VEC_SIZE * 4;
// The actual output_scales dim is computed from the padded numCols.
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
uint32_t* SFout_in_expert =
SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
auto sf_out =
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
out_pos =
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
}
}
#endif
}
template <typename T>
void quant_impl(void* output, void* output_scale, void* input,
void* input_global_scale, void* input_offset_by_experts,
void* output_scale_offset_by_experts, int m_topk, int k,
int n_experts, cudaStream_t stream) {
// TODO: this multiProcessorCount should be cached.
int device;
cudaGetDevice(&device);
int multiProcessorCount;
cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount,
device);
// Grid, Block size.
// Each thread converts 8 values.
dim3 block(std::min(int(k / ELTS_PER_THREAD), 512));
// Get number of blocks per SM (assume we can fully utilize the SM).
int const numBlocksPerSM = 2048 / block.x;
dim3 grid(std::min(int(m_topk), multiProcessorCount * numBlocksPerSM));
cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
m_topk, k, reinterpret_cast<T*>(input),
reinterpret_cast<float*>(input_global_scale),
reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts), n_experts);
}
/*Quantization entry for fp4 experts quantization*/
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
#define CHECK_INPUT(x, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m);
constexpr auto HALF = at::ScalarType::Half;
constexpr auto BF16 = at::ScalarType::BFloat16;
constexpr auto FLOAT = at::ScalarType::Float;
constexpr auto INT = at::ScalarType::Int;
constexpr auto UINT8 = at::ScalarType::Byte;
void scaled_fp4_experts_quant_sm100a(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
CHECK_INPUT(output, "output must be a CUDA tensor");
CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor");
CHECK_INPUT(input, "input must be a CUDA tensor");
CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor");
CHECK_INPUT(input_offset_by_experts,
"input_offset_by_experts must be a CUDA tensor");
CHECK_INPUT(output_scale_offset_by_experts,
"output_scale_offset_by_experts must be a CUDA tensor");
TORCH_CHECK(output.dim() == 2);
TORCH_CHECK(output_scale.dim() == 2);
TORCH_CHECK(input.dim() == 2);
TORCH_CHECK(input_global_scale.dim() == 1);
TORCH_CHECK(input_offset_by_experts.dim() == 1);
TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
// output is uint8 (two nvfp4 values are packed into one uint8)
// output_scale is int32 (four fp8 values are packed into one int32)
TORCH_CHECK(output.scalar_type() == UINT8);
TORCH_CHECK(output_scale.scalar_type() == INT);
const int BLOCK_SIZE = 16;
auto m_topk = input.size(0);
auto k = input.size(1);
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
auto n_experts = input_global_scale.size(0);
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(output.size(0) == m_topk);
TORCH_CHECK(output.size(1) == k / 2);
int scales_k = k / BLOCK_SIZE;
// 4 means the swizzle requirement by nvidia nvfp4.
int padded_k = (scales_k + (4 - 1)) / 4 * 4;
// 4 means 4 fp8 values are packed into one int32
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
auto in_dtype = input.dtype();
at::cuda::CUDAGuard device_guard{(char)input.get_device()};
const cudaStream_t stream =
at::cuda::getCurrentCUDAStream(input.get_device());
if (in_dtype == at::ScalarType::Half) {
quant_impl<half>(output.data_ptr(), output_scale.data_ptr(),
input.data_ptr(), input_global_scale.data_ptr(),
input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(), m_topk, k,
n_experts, stream);
} else if (in_dtype == at::ScalarType::BFloat16) {
quant_impl<__nv_bfloat16>(output.data_ptr(), output_scale.data_ptr(),
input.data_ptr(), input_global_scale.data_ptr(),
input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(), m_topk,
k, n_experts, stream);
} else {
TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
}
}
\ No newline at end of file
...@@ -23,10 +23,32 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output, ...@@ -23,10 +23,32 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output,
torch::Tensor const& input_sf); torch::Tensor const& input_sf);
#endif #endif
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
void scaled_fp4_experts_quant_sm100a(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
torch::Tensor& output_sf, torch::Tensor const& input_sf) { torch::Tensor& output_sf, torch::Tensor const& input_sf) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4 #if defined ENABLE_NVFP4 && ENABLE_NVFP4
return scaled_fp4_quant_sm100a(output, input, output_sf, input_sf); return scaled_fp4_quant_sm100a(output, input, output_sf, input_sf);
#endif #endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization"); TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel");
}
void scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
return scaled_fp4_experts_quant_sm100a(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false,
"No compiled nvfp4 experts quantization kernel");
} }
...@@ -39,8 +39,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel( ...@@ -39,8 +39,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
fp8_type* __restrict__ token_output = &out[offset]; fp8_type* __restrict__ token_output = &out[offset];
// For vectorization, token_input and token_output pointers need to be // For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively. // aligned at 32-byte and 16-byte addresses respectively.
bool const can_vectorize = hidden_size % 4 == 0; bool const can_vectorize = hidden_size % 16 == 0;
float absmax_val = 0.0f; float absmax_val = 0.0f;
if (can_vectorize) { if (can_vectorize) {
...@@ -48,24 +48,24 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel( ...@@ -48,24 +48,24 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
} else { } else {
for (int i = tid; i < hidden_size; i += blockDim.x) { for (int i = tid; i < hidden_size; i += blockDim.x) {
float const x = static_cast<float>(token_input[i]); float const x = static_cast<float>(token_input[i]);
absmax_val = max(absmax_val, fabs(x)); absmax_val = fmaxf(absmax_val, fabsf(x));
} }
} }
using BlockReduce = cub::BlockReduce<float, 1024>; using BlockReduce = cub::BlockReduce<float, 256>;
__shared__ typename BlockReduce::TempStorage reduceStorage; __shared__ typename BlockReduce::TempStorage reduceStorage;
float const block_absmax_val_maybe = float const block_absmax_val_maybe =
BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x); BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
__shared__ float token_scale; __shared__ float token_scale;
if (tid == 0) { if (tid == 0) {
if (scale_ub) { if (scale_ub) {
token_scale = min(block_absmax_val_maybe, *scale_ub); token_scale = fminf(block_absmax_val_maybe, *scale_ub);
} else { } else {
token_scale = block_absmax_val_maybe; token_scale = block_absmax_val_maybe;
} }
// token scale computation // token scale computation
token_scale = max(token_scale / quant_type_max_v<fp8_type>, token_scale = fmaxf(token_scale / quant_type_max_v<fp8_type>,
min_scaling_factor<fp8_type>::val()); min_scaling_factor<fp8_type>::val());
scale[token_idx] = token_scale; scale[token_idx] = token_scale;
} }
__syncthreads(); __syncthreads();
...@@ -88,10 +88,11 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] ...@@ -88,10 +88,11 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d] torch::Tensor const& input, // [..., d]
torch::Tensor const& scale) // [1] torch::Tensor const& scale) // [1]
{ {
int64_t num_tokens = input.numel() / input.size(-1); int const block_size = 256;
int64_t num_elems = input.numel(); int const num_tokens = input.numel() / input.size(-1);
dim3 grid(num_tokens); int const num_elems = input.numel();
dim3 block(1024); dim3 const grid(num_tokens);
dim3 const block(block_size);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
...@@ -110,10 +111,11 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] ...@@ -110,10 +111,11 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d] torch::Tensor const& input, // [..., d]
torch::Tensor& scale) // [1] torch::Tensor& scale) // [1]
{ {
int64_t num_tokens = input.numel() / input.size(-1); int const block_size = 256;
int64_t num_elems = input.numel(); int const num_tokens = input.numel() / input.size(-1);
dim3 grid(num_tokens); int const num_elems = input.numel();
dim3 block(1024); dim3 const grid(num_tokens);
dim3 const block(block_size);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
...@@ -141,8 +143,9 @@ void dynamic_per_token_scaled_fp8_quant( ...@@ -141,8 +143,9 @@ void dynamic_per_token_scaled_fp8_quant(
int const hidden_size = input.size(-1); int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size; int const num_tokens = input.numel() / hidden_size;
int const block_size = 256;
dim3 const grid(num_tokens); dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 1024)); dim3 const block(std::min(hidden_size, block_size));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment