Commit 4c676e3d authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.9.1' into v0.9.1-dev

parents b4c4464d b6553be1
#include "marlin_moe_kernel_ku4b8.h"
namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4b8(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks) {
bool has_zp = false;
if (false) {
}
GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256)
GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 8, 256)
GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 4, 128)
GPTQ_CALL_IF_MOE(vllm::kU4B8, 4, 8, 128)
else {
return false;
}
return true;
}
} // namespace marlin_moe
#pragma once
#include "marlin_moe_kernel.h"
namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4b8(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks);
} // namespace marlin_moe
#include "marlin_moe_kernel_ku8b128.h"
namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku8b128(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks) {
bool has_zp = false;
if (false) {
}
GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256)
GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 8, 256)
GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 4, 128)
GPTQ_CALL_IF_MOE(vllm::kU8B128, 4, 8, 128)
else {
return false;
}
return true;
}
} // namespace marlin_moe
#pragma once
#include "marlin_moe_kernel.h"
namespace marlin_moe {
bool call_marlin_moe_kernel_ku8b128(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks);
}
/*
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#include "core/exception.hpp"
#include "core/scalar_type.hpp"
#include "core/registration.h"
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
#include "marlin_kernels/marlin_moe_kernel_ku4.h"
template <typename T>
inline std::string str(T x) {
return std::to_string(x);
}
namespace marlin_moe {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices.
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr, int size_m,
int size_k, int block_rows) {
int start_row = block_rows * blockIdx.x;
int finish_row = start_row + block_rows;
if (finish_row > size_m) {
finish_row = size_m;
}
int cur_block_rows = finish_row - start_row;
int row_stride = size_k * sizeof(half) / 16;
auto permute_row = [&](int row) {
int iters = size_k / blockDim.x;
int rest = size_k % blockDim.x;
int offset = row * row_stride;
half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);
int base_k = 0;
for (int i = 0; i < iters; i++) {
int cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
base_k += blockDim.x;
}
if (rest) {
if (threadIdx.x < rest) {
int cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
}
}
};
for (int i = 0; i < cur_block_rows; i++) {
int cur_row = start_row + i;
if (cur_row < size_m) {
permute_row(cur_row);
}
}
}
__global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
int* __restrict__ expert_offsets,
int topk_length, int block_size) {
int expert_id = threadIdx.x;
int num_experts = blockDim.x;
int occurrences = 0;
for (int i = 0; i < topk_length; ++i) {
occurrences += (topk_ids[i] == expert_id);
}
expert_offsets[expert_id + 1] = occurrences;
__syncthreads();
if (threadIdx.x == 0) {
int tot_offset = 0;
expert_offsets[0] = 0;
for (int i = 0; i < num_experts; ++i) {
tot_offset += ceildiv(expert_offsets[i + 1], block_size) * block_size;
expert_offsets[i + 1] = tot_offset;
}
}
__syncthreads();
}
#else
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr, int size_m,
int size_k, int block_rows) {
// Marlin is not implemented yet for SM < 8.0
assert(false);
return;
}
__global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
int* __restrict__ expert_offsets,
int topk_length, int block_size) {
// Marlin is not implemented yet for SM < 8.0
assert(false);
return;
}
#endif
typedef struct {
int thread_k;
int thread_n;
int num_threads;
} thread_config_t;
typedef struct {
int max_m_blocks;
thread_config_t tb_cfg;
} exec_config_t;
thread_config_t small_batch_thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{128, 128, 256}, // Default
{128, 64, 128}, // Reduce N 2X, same K
{64, 256, 256}, // Reduce K 2X, increase N 2X
{64, 128, 128}, // Reduce K 2X, same N
{64, 64, 128}, // Reduce both 2X
};
thread_config_t large_batch_thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{64, 256, 256}, // Default
{128, 128, 256}, // Reduce N 2X, increase K 2X
{64, 128, 128}, // Reduce N 2X, same K
{128, 64, 128}, // Reduce N 4X, increase K 2X
{64, 64, 128}, // Reduce N 4X, same K
};
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
int prob_n, int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full) {
bool cache_scales_chunk = has_act_order && !is_k_full;
int tb_n = th_config.thread_n;
int tb_k = th_config.thread_k;
// Get max scale groups per thread-block
int tb_groups;
if (group_size == -1) {
tb_groups = 1;
} else if (group_size == 0) {
tb_groups = ceildiv(tb_k, 32); // Worst case is 32 group size
} else {
tb_groups = ceildiv(tb_k, group_size);
}
if (cache_scales_chunk) {
int load_groups =
tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K
load_groups = max(load_groups, 32); // We load at least 32 scale groups
return load_groups * tb_n * 4;
} else {
int tb_scales = tb_groups * tb_n * 2;
return tb_scales * STAGES;
}
}
bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int scales_cache_size, int max_shared_mem) {
int pack_factor = 32 / num_bits;
// Get B size
int tb_k = th_config.thread_k;
int tb_n = th_config.thread_n;
int b_size = (tb_k * tb_n / pack_factor) * 4;
// Get A size
int m_blocks = ceildiv(prob_m, 16);
int tb_max_m = 16;
while (true) {
if (m_blocks >= max_m_blocks) {
tb_max_m *= max_m_blocks;
break;
}
max_m_blocks--;
if (max_m_blocks == 0) {
TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
}
}
int a_size = (tb_max_m * tb_k) * 2;
float pipe_size = (a_size + b_size) * STAGES;
TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity
return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
}
bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int group_size, bool has_act_order, bool is_k_full,
int max_shared_mem) {
// Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) {
return false;
}
// Verify K/N are divisible by thread K/N
if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
return false;
}
// thread_k can be only 128 or 64 (because it must be less than groupsize
// which is 128)
if (th_config.thread_k != 128 && th_config.thread_k != 64) {
return false;
}
// Verify min for thread K/N
if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
return false;
}
// num_threads must be at least 128 (= 4 warps)
if (th_config.num_threads < 128) {
return false;
}
// Determine cache for scales
int scales_cache_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full);
// Check that pipeline fits into cache
if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
num_bits, scales_cache_size, max_shared_mem)) {
return false;
}
return true;
}
exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
int num_bits, int group_size,
bool has_act_order, bool is_k_full,
int max_shared_mem) {
int max_m_blocks = 4;
while (max_m_blocks > 0) {
if (prob_m <= 16) {
for (auto th_config : small_batch_thread_configs) {
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full,
max_shared_mem)) {
return exec_config_t{max_m_blocks, th_config};
}
}
} else {
for (auto th_config : large_batch_thread_configs) {
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full,
max_shared_mem)) {
return exec_config_t{max_m_blocks, th_config};
}
}
}
max_m_blocks--; // Process less M blocks per invocation to reduce cache
// usage
}
return exec_config_t{0, {-1, -1, -1}};
}
#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \
else if (KERNEL_FUNCTION( \
q_type, thread_n_blocks, thread_k_blocks, has_act_order, \
group_blocks, num_threads, blocks, max_shared_mem, stream, \
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
replicate_input, apply_weights, m_block, max_par, \
exec_cfg.max_m_blocks)) { \
}
void marlin_mm_moe(const void* A, const void* B, void* C,
const void* sorted_ids, const void* topk_weights,
const void* topk_ids, const void* s, void* zp,
const void* g_idx, const void* perm, void* a_tmp,
void* expert_offsets, int prob_m, int prob_n, int prob_k,
void* workspace, vllm::ScalarType const& q_type,
bool has_act_order, bool is_k_full, bool has_zp,
int num_groups, int group_size, int num_experts, int topk,
int moe_block_size, int dev, cudaStream_t stream,
int thread_k, int thread_n, int sms, int max_par,
bool replicate_input, bool apply_weights) {
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]");
if (sms == -1) {
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
}
int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
int num_bits = q_type.size_bits();
// Set thread config
exec_config_t exec_cfg;
if (thread_k != -1 && thread_n != -1) {
// User-defined config
exec_cfg =
exec_config_t{4, thread_config_t{thread_k, thread_n, USER_THREADS}};
} else {
// Auto config
exec_cfg =
determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, max_shared_mem);
}
TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&
is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,
prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, max_shared_mem),
"Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
", thread_k = ", exec_cfg.tb_cfg.thread_k,
", thread_n = ", exec_cfg.tb_cfg.thread_n,
", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [",
prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size,
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
", max_shared_mem = ", max_shared_mem);
int num_threads = exec_cfg.tb_cfg.num_threads;
thread_k = exec_cfg.tb_cfg.thread_k;
thread_n = exec_cfg.tb_cfg.thread_n;
int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16;
int blocks = sms;
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
" is not divisible by thread_n = ", thread_n);
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
" is not divisible by thread_k = ", thread_k);
int group_blocks = 0;
if (has_act_order) {
if (is_k_full) {
TORCH_CHECK(group_size != -1);
group_blocks = group_size / 16;
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
" is not divisible by group_blocks = ", group_blocks);
} else {
TORCH_CHECK(group_size == 0);
group_blocks = 0;
}
} else {
if (group_size == -1) {
group_blocks = -1;
} else {
group_blocks = group_size / 16;
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
" is not divisible by group_blocks = ", group_blocks);
}
}
int tot_m = prob_m;
const int* topk_ids_ptr = (const int*)topk_ids;
int* expert_offsets_ptr = (int*)expert_offsets;
compute_expert_offsets<<<1, num_experts, 0, stream>>>(
topk_ids_ptr, expert_offsets_ptr, tot_m * topk, moe_block_size);
bool do_permute_a = has_act_order;
// If we have a full K, then we can run the non-act-order version of Marlin
// (since the weight rows are reordered by increasing group ids, and by
// having a full K, we have full original groups)
if (is_k_full) {
has_act_order = false;
}
int pack_factor = 32 / q_type.size_bits();
for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
const int4* A_ptr = (const int4*)A;
int4* a_tmp_ptr = (int4*)a_tmp;
const int4* B_ptr =
(const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx;
int4* C_ptr = (int4*)C;
const float* topk_weights_ptr = (const float*)topk_weights;
const int* sorted_ids_ptr = (const int*)sorted_ids;
const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx;
const int4* zp_ptr =
(const int4*)zp + num_groups * prob_n / (pack_factor * 4) * expert_idx;
const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
int* locks = (int*)workspace;
if (do_permute_a) {
// Permute A columns
int topk_rows = replicate_input ? tot_m : tot_m * topk;
int block_rows = ceildiv(topk_rows, blocks);
permute_cols_kernel<<<blocks, num_threads, 0, stream>>>(
A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows);
A_ptr = a_tmp_ptr;
}
int tot_m_blocks = ceildiv(tot_m, 16);
for (int m_block = 0; m_block < tot_m_blocks;
m_block += 4 * exec_cfg.max_m_blocks) {
if (false) {
}
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8)
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128)
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4)
else {
TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
str(prob_n) + ", " + str(prob_k) + "]" +
", has_act_order = " + str(has_act_order) +
", num_groups = " + str(num_groups) +
", group_size = " + str(group_size) +
", thread_n_blocks = " + str(thread_n_blocks) +
", thread_k_blocks = " + str(thread_k_blocks));
}
}
}
}
} // namespace marlin_moe
torch::Tensor marlin_gemm_moe(
const torch::Tensor& a, const torch::Tensor& b_q_weights,
const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
torch::Tensor& b_zeros, const torch::Tensor& g_idx,
const torch::Tensor& perm, torch::Tensor& workspace,
vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
int64_t moe_block_size, bool replicate_input, bool apply_weights) {
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
bool has_zp = b_zeros.size(1) != 0;
if (has_zp) {
TORCH_CHECK(
b_q_type == vllm::kU4,
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
} else {
TORCH_CHECK(
b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
"b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type.str());
}
int pack_factor = 32 / b_q_type.size_bits();
int max_par = 4;
int dev = a.get_device();
auto options_dtype =
torch::TensorOptions().dtype(a.dtype()).device(a.device());
auto options_int =
torch::TensorOptions().dtype(torch::kInt).device(a.device());
torch::Tensor c = torch::zeros({size_m, topk, size_n}, options_dtype);
torch::Tensor a_tmp =
replicate_input ? torch::zeros({size_m, size_k}, options_dtype)
: torch::zeros({size_m, topk, size_k}, options_dtype);
torch::Tensor expert_offsets = torch::empty({num_experts + 1}, options_int);
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_k = -1;
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_n = -1;
// sms: number of SMs to use for the kernel (can usually be left as auto -1)
int sms = -1;
// Detect groupsize and act_order
int num_groups = -1;
int group_size = -1;
bool has_act_order = g_idx.size(1) != 0;
int b_rank = b_scales.sizes().size();
TORCH_CHECK(b_rank == 3, "b_scales rank = ", b_rank, " is not 3");
TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2),
" is not size_n = ", size_n);
num_groups = b_scales.size(1);
TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order),
"if is_k_full is false, has_act_order must be true");
if (has_act_order) {
if (is_k_full) {
TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
", is not divisible by num_groups = ", num_groups);
group_size = size_k / num_groups;
} else {
group_size = 0;
}
} else {
if (num_groups > 1) {
TORCH_CHECK(
size_k % num_groups == 0, "size_k = ", size_k,
", is not divisible by b_scales.size(0) = ", b_scales.size(0));
group_size = size_k / num_groups;
} else {
group_size = -1;
}
}
// Verify b_zeros
if (has_zp) {
int rank = b_zeros.sizes().size();
TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
TORCH_CHECK(b_zeros.size(1) == num_groups,
"b_zeros dim 1 = ", b_zeros.size(1),
" is not num_groups = ", num_groups);
TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
"b_zeros dim 2 = ", b_zeros.size(2),
" is not size_n / pack_factor = ", size_n / pack_factor);
}
marlin_moe::marlin_mm_moe(
a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(),
topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(),
b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size,
num_experts, topk, moe_block_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par,
replicate_input, apply_weights);
return c;
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("marlin_gemm_moe", &marlin_gemm_moe);
}
kernel_*.cu
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import glob
import itertools
import os
......@@ -25,15 +26,16 @@ TEMPLATE = ("template __global__ void Marlin<"
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{'true' if has_act_order else 'false'}}, "
"{{'true' if has_zp else 'false'}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );")
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"]
SCALAR_TYPES = [
"vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
"vllm::kFE2M1f"
]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
......@@ -41,7 +43,7 @@ THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# = 0 : act order case
# = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, -1, 2, 4, 8]
GROUP_BLOCKS = [0, -1, 1, 2, 4, 8]
DTYPES = ["fp16", "bf16"]
......@@ -52,21 +54,35 @@ def remove_old_kernels():
def generate_new_kernels():
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
has_zp = "B" not in scalar_type
all_template_str_list = []
for group_blocks, m_blocks, thread_configs in itertools.product(
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
has_act_order = group_blocks == 0
if has_zp and has_act_order:
# act order case only support gptq-int4 and gptq-int8
if group_blocks == 0 and scalar_type not in [
"vllm::kU4B8", "vllm::kU8B128"
]:
continue
if thread_configs[2] == 256:
# for small batch (m_blocks == 1), we only need (128, 128, 256)
# for large batch (m_blocks > 1), we only need (64, 256, 256)
if m_blocks <= 1 and thread_configs[0] != 128:
continue
if m_blocks > 1 and thread_configs[0] != 64:
continue
# we only support channelwise quantization and group_size == 128
# for fp8
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
continue
# nvfp4 only supports group_size == 16
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
continue
# other quantization methods don't support group_size = 16
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
continue
k_blocks = thread_configs[0] // 16
n_blocks = thread_configs[1] // 16
threads = thread_configs[2]
......@@ -82,8 +98,6 @@ def generate_new_kernels():
thread_k_blocks=k_blocks,
m_block_size_8=m_blocks == 0.5,
stages="pipe_stages",
has_act_order=has_act_order,
has_zp=has_zp,
group_blocks=group_blocks,
is_zp_float=False,
)
......
......@@ -7,18 +7,19 @@
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
bool use_fp32_reduce
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
bool use_fp32_reduce, int max_shared_mem
namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16
......@@ -33,11 +34,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(MARLIN_KERNEL_PARAMS);
......
......@@ -25,6 +25,7 @@
#include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "quantization/gptq_marlin/dequant.h"
#include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
......@@ -48,11 +49,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
......@@ -77,8 +76,8 @@ __global__ void Marlin(
int prob_k, // reduction dimension k
int* locks, // extra global storage for barrier synchronization
bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce // whether to use fp32 global reduce
) {}
bool use_fp32_reduce, // whether to use fp32 global reduce
int max_shared_mem) {}
} // namespace MARLIN_NAMESPACE_NAME
......@@ -166,144 +165,6 @@ __device__ inline void ldsm(typename ScalarType<scalar_t>::FragA& frag_a,
}
}
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res)
: "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
uint32_t res;
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
: "=r"(res)
: "r"(a), "n"(start_byte), "n"(mask));
return res;
}
template <typename scalar_t, int bit>
__device__ inline typename ScalarType<scalar_t>::FragB dequant(
int q, typename ScalarType<scalar_t>::FragB& frag_b);
//
// Efficiently dequantize 4bit values packed in an int32 value into a full
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
// with some small changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
template <>
__device__ inline typename ScalarType<half>::FragB dequant<half, 4>(
int q, typename ScalarType<half>::FragB& frag_b) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0x64086408;
const int MUL = 0x2c002c00;
const int ADD = 0xd480d480;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
return frag_b;
}
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant<nv_bfloat16, 4>(int q,
typename ScalarType<nv_bfloat16>::FragB& frag_b) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
q >>= 4;
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC308C308;
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
return frag_b;
}
//
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
template <>
__device__ inline typename ScalarType<half>::FragB dequant<half, 8>(
int q, typename ScalarType<half>::FragB& frag_b) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
return frag_b;
}
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant<nv_bfloat16, 8>(int q,
typename ScalarType<nv_bfloat16>::FragB& frag_b) {
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted =
reinterpret_cast<uint32_t*>(fp32_intermediates);
static constexpr uint32_t fp32_base = 0x4B000000;
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
fp32_intermediates[0] -= 8388736.f;
fp32_intermediates[1] -= 8388736.f;
fp32_intermediates[2] -= 8388736.f;
fp32_intermediates[3] -= 8388736.f;
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
fp32_intermediates_casted[3], 0x7632);
return frag_b;
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
template <typename scalar_t>
......@@ -429,11 +290,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
......@@ -442,9 +301,11 @@ __global__ void Marlin(
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
// only)
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids
const int32_t* __restrict__ expert_ids_ptr, // moe expert ids
const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens
......@@ -458,8 +319,8 @@ __global__ void Marlin(
int prob_k, // reduction dimension k
int* locks, // extra global storage for barrier synchronization
bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce // whether to use fp32 global reduce
) {
bool use_fp32_reduce, // whether to use fp32 global reduce
int max_shared_mem) {
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
......@@ -481,13 +342,26 @@ __global__ void Marlin(
extern __shared__ int4 sh[];
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
// see comments of dequant.h for more details
constexpr bool dequant_skip_flop =
!is_int_type ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(w_type == vllm::kU8);
scalar_t2 global_scale;
constexpr bool has_act_order = group_blocks == 0;
constexpr int pack_factor = 32 / w_type.size_bits();
static_assert(thread_m_blocks == 1 || !m_block_size_8);
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
const int group_size =
(!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
const int scales_expert_stride = prob_n * prob_k / group_size / 8;
const int scales_expert_stride =
prob_n * prob_k / group_size / (w_type == vllm::kFE2M1f ? 16 : 8);
const int zp_expert_stride =
is_zp_float ? prob_n * prob_k / group_size / 8
: prob_n * prob_k / group_size / (pack_factor * 4);
......@@ -534,13 +408,20 @@ __global__ void Marlin(
int64_t B_expert_off = 0;
int4* sh_block_sorted_ids_int4 = sh;
int4* sh_rd_block_sorted_ids_int4 =
sh_block_sorted_ids_int4 + moe_block_size / 4;
int4* sh_block_topk_weights_int4 =
sh_rd_block_sorted_ids_int4 + moe_block_size / 4;
// sh_block_topk_weights_int4 only need (moe_block_size / 4);
// but we pad to align to 256 bytes
int4* sh_new =
sh_block_topk_weights_int4 + moe_block_size / 2 + moe_block_size;
int32_t* sh_block_sorted_ids =
reinterpret_cast<int*>(sh_block_sorted_ids_int4);
int4* sh_block_topk_weights_int4 =
sh_block_sorted_ids_int4 + moe_block_size / 4;
int32_t* sh_rd_block_sorted_ids =
reinterpret_cast<int*>(sh_rd_block_sorted_ids_int4);
scalar_t2* sh_block_topk_weights =
reinterpret_cast<scalar_t2*>(sh_block_topk_weights_int4);
int4* sh_new = sh_block_topk_weights_int4 + moe_block_size / 4;
int32_t block_num_valid_tokens = 0;
int32_t locks_off = 0;
......@@ -584,12 +465,24 @@ __global__ void Marlin(
sh_block_sorted_ids_int4[tid4] = reinterpret_cast<const int4*>(
sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4];
#pragma unroll
for (int i = 0; i < 4; i++)
sh_rd_block_sorted_ids[tid4 * 4 + i] =
sh_block_sorted_ids[tid4 * 4 + i] / top_k;
if (mul_topk_weights) {
#pragma unroll
for (int i = 0; i < 4; i++) {
sh_block_topk_weights[tid4 * 4 + i] =
Dtype::num2num2(Dtype::float2num(
topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]));
int idx = tid4 * 4 + i;
idx = idx < block_num_valid_tokens ? idx : 0;
if constexpr (w_type == vllm::kFE2M1f) {
sh_block_topk_weights[idx] = __hmul2(
global_scale, Dtype::num2num2(Dtype::float2num(
topk_weights_ptr[sh_block_sorted_ids[idx]])));
} else {
sh_block_topk_weights[idx] = Dtype::num2num2(
Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]]));
}
}
}
}
......@@ -620,6 +513,11 @@ __global__ void Marlin(
expert_id = expert_ids_ptr[block_id];
}
if constexpr (w_type == vllm::kFE2M1f) {
uint16_t val = scale2_ptr[expert_id];
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
}
B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);
scales_ptr += (expert_id - old_expert_id) * scales_expert_stride;
if constexpr (has_zp) {
......@@ -733,7 +631,7 @@ __global__ void Marlin(
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks
? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1)
: 1;
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride;
......@@ -743,6 +641,7 @@ __global__ void Marlin(
constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;
// constexpr int act_s_row_stride = 1;
// int act_s_col_stride = act_s_row_stride * num_groups;
constexpr int act_s_max_num_groups = 32;
int act_s_col_stride = 1;
int act_s_col_warp_stride = act_s_col_stride * 8;
int tb_n_warps = thread_n_blocks / 4;
......@@ -758,9 +657,9 @@ __global__ void Marlin(
int zp_gl_rd_delta = zp_gl_stride;
// Global A read index of current thread.
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
a_gl_rd += a_gl_rd_delta_o * slice_row;
int a_gl_rd_row = threadIdx.x / a_gl_rd_delta_o;
int a_gl_rd_col = a_gl_rd_delta_o * slice_row + threadIdx.x % a_gl_rd_delta_o;
// Shared write index of current thread.
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
......@@ -774,8 +673,8 @@ __global__ void Marlin(
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
b_gl_rd += b_sh_stride * slice_col;
b_gl_rd += b_gl_rd_delta_o * slice_row;
int b_sh_wr = threadIdx.x * b_thread_vecs;
int b_sh_rd = threadIdx.x * b_thread_vecs;
auto b_sh_wr = threadIdx.x * b_thread_vecs;
auto b_sh_rd = threadIdx.x * b_thread_vecs;
// For act_order
constexpr int k_iter_size = tb_k / b_sh_wr_iters;
......@@ -790,11 +689,12 @@ __global__ void Marlin(
if constexpr (group_blocks == -1) {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
} else {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) /
(w_type == vllm::kFE2M1f ? 2 : 1) +
s_sh_stride * slice_col + threadIdx.x;
}
}
int s_sh_wr = threadIdx.x;
auto s_sh_wr = threadIdx.x;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
// Zero-points
......@@ -807,17 +707,27 @@ __global__ void Marlin(
zp_sh_stride * slice_col + threadIdx.x;
}
}
int zp_sh_wr = threadIdx.x;
auto zp_sh_wr = threadIdx.x;
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int s_sh_rd;
if constexpr (group_blocks != -1)
if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) {
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp))
s_sh_rd = s_sh_rd * 2 + warp_row % 2;
} else if constexpr (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
else if constexpr (group_blocks == -1 &&
(m_block_size_8 || (has_zp && !dequant_skip_flop)))
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 8;
else
......@@ -851,7 +761,7 @@ __global__ void Marlin(
// each warp must also write a consecutive memory segment?
auto transform_a = [&](int i) {
int row = i / a_gl_rd_delta_o;
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8);
};
// Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute
......@@ -879,12 +789,28 @@ __global__ void Marlin(
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
// Shared memory storage for global fetch pipelines.
int4* sh_a = sh_new;
int4* sh_b = sh_a + (stages * a_sh_stage);
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks;
constexpr int sh_b_size = stages * b_sh_stage;
int4* sh_b = sh_new;
int4* sh_red = sh_new;
int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
: (stages * s_sh_stage);
int4* sh_s = sh_zp + (stages * zp_sh_stage);
int4* sh_red = sh_b;
// shared memory reused by reduction should be smaller than
// shared memory used by weight.
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
stages * b_sh_stage);
int4* sh_a = sh_s + sh_s_size;
constexpr int shm_size_used =
moe_block_size + stages * (g_idx_stage + zp_sh_stage) + sh_s_size +
(sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
// all remaining shared memory is used to cache A (input)
// sh_a_max_row is at least ` stages * 16 * thread_m_blocks `
int sh_a_max_row =
((max_shared_mem - 1024) / 16 - shm_size_used) / (thread_k_blocks * 2);
// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks];
......@@ -905,15 +831,14 @@ __global__ void Marlin(
int sh_first_group_id = -1;
int sh_num_groups = -1;
constexpr int sh_max_num_groups = 32;
auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id,
int last_group_id) {
sh_first_group_id = first_group_id;
sh_num_groups = last_group_id - first_group_id + 1;
if (sh_num_groups < sh_max_num_groups) {
sh_num_groups = sh_max_num_groups;
if (sh_num_groups > act_s_max_num_groups) {
sh_num_groups = act_s_max_num_groups;
}
if (sh_first_group_id + sh_num_groups > num_groups) {
......@@ -940,27 +865,31 @@ __global__ void Marlin(
}
}
};
// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
int a_remaining_load_count_in_slice = stages;
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
bool should_load_a = true;
int max_num_stage_groups =
((sh_a_max_row - moe_block_size) / moe_block_size + 1) / stages;
max_num_stage_groups = max(max_num_stage_groups, 1);
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true,
int pipe_a = 0) {
if (pred) {
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
if (prob_k > thread_k_blocks * 16 * stages || slice_col == 0 ||
a_remaining_load_count_in_slice > 0) {
a_remaining_load_count_in_slice--;
if (should_load_a) {
int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a;
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) {
int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off;
int row = a_idx / a_gl_stride;
int row = a_gl_rd_delta_i / a_gl_stride * i + a_gl_rd_row;
int64_t sorted_row = 0;
if (!m_block_size_8 || row < 8)
sorted_row = sh_block_sorted_ids[row] / top_k;
int64_t true_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride;
sorted_row = sh_rd_block_sorted_ids[row];
int64_t true_idx =
sorted_row * a_gl_stride + a_gl_rd_col + a_gl_rd_delta_o * a_off;
cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx],
row < block_num_valid_tokens);
}
}
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
......@@ -1063,8 +992,8 @@ __global__ void Marlin(
// Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer.
auto fetch_to_registers = [&](int k, int pipe) {
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
auto fetch_to_registers = [&](int k, int pipe, int pipe_a = 0) {
int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a;
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++)
ldsm<m_block_size_8 ? 2 : 4, scalar_t>(
......@@ -1109,12 +1038,17 @@ __global__ void Marlin(
}
} else if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
if (k % b_sh_wr_iters == 0) {
int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
} else {
reinterpret_cast<int4*>(&frag_s[1])[0] =
reinterpret_cast<int4*>(&frag_s[0])[0];
}
} else {
int warp_id = threadIdx.x / 32;
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
......@@ -1123,12 +1057,19 @@ __global__ void Marlin(
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16;
int cur_group_id = k_blocks / group_blocks;
int cur_group_id =
k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1));
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
if constexpr (w_type_id != vllm::kFE2M1f.id()) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
}
}
}
......@@ -1152,7 +1093,7 @@ __global__ void Marlin(
// Determine "position" inside the thread-block (based on warp and
// thread-id)
int warp_id = threadIdx.x / 32;
auto warp_id = threadIdx.x / 32;
int n_warps =
thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
......@@ -1161,7 +1102,7 @@ __global__ void Marlin(
cur_k += warp_row * 16;
int th_id = threadIdx.x % 32;
auto th_id = threadIdx.x % 32;
cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
int s_col_shift =
......@@ -1222,15 +1163,18 @@ __global__ void Marlin(
}
} else if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
if (k % b_sh_wr_iters == 0) {
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
#pragma unroll
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
}
} else {
int warp_id = threadIdx.x / 32;
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
......@@ -1251,6 +1195,7 @@ __global__ void Marlin(
sh_zp_stage += cur_group_id * zp_sh_stride;
#pragma unroll
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
......@@ -1263,12 +1208,16 @@ __global__ void Marlin(
if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];
if (k % b_sh_wr_iters == 0) {
int4* sh_zp_stage =
sh_zp +
zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] =
sh_zp_stage[zp_sh_rd];
}
} else {
int warp_id = threadIdx.x / 32;
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
......@@ -1292,6 +1241,10 @@ __global__ void Marlin(
}
};
auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) {
dequant<scalar_t2, w_type_id, dequant_skip_flop>(q, frag_b_ptr);
};
// Execute the actual tensor core matmul of a sub-tile.
bool is_first_matmul_in_slice = true;
auto matmul = [&](int k) {
......@@ -1315,15 +1268,27 @@ __global__ void Marlin(
zp_quant_1 = frag_qzp[k2][1];
}
dequant<scalar_t, w_type.size_bits()>(zp_quant_0, frag_zp_0);
dequant<scalar_t, w_type.size_bits()>(zp_quant_1, frag_zp_1);
frag_zp[0] = frag_zp_0[0];
frag_zp[1] = frag_zp_0[1];
frag_zp[2] = frag_zp_1[0];
frag_zp[3] = frag_zp_1[1];
dequant_data(zp_quant_0, reinterpret_cast<scalar_t2*>(&frag_zp));
dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2);
}
}
if constexpr (!dequant_skip_flop && has_zp && is_zp_float) {
if (is_new_zp) {
reinterpret_cast<int4*>(&frag_zp)[0] =
reinterpret_cast<int4*>(&frag_zpf[k2])[0];
}
}
if constexpr (w_type == vllm::kFE2M1f) {
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
dequant_fp8_scales<scalar_t2>(s_quant_0,
reinterpret_cast<scalar_t2*>(&frag_s[k2]));
dequant_fp8_scales<scalar_t2>(
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
}
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
......@@ -1332,7 +1297,10 @@ __global__ void Marlin(
FragB frag_b1;
int b_quant_0, b_quant_1;
if constexpr (w_type.size_bits() == 4) {
if constexpr (w_type_id == vllm::kFE2M1f.id()) {
b_quant_1 = frag_b_quant[k2][0][j];
b_quant_0 = b_quant_1 << 8;
} else if constexpr (w_type.size_bits() == 4) {
b_quant_0 = frag_b_quant[k2][0][j];
b_quant_1 = b_quant_0 >> 8;
} else {
......@@ -1342,8 +1310,13 @@ __global__ void Marlin(
b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
}
dequant<scalar_t, w_type.size_bits()>(b_quant_0, frag_b0);
dequant<scalar_t, w_type.size_bits()>(b_quant_1, frag_b1);
dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0));
dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1));
if constexpr (dequant_skip_flop && has_zp && !is_zp_float) {
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
}
// Apply scale to frag_b0
if constexpr (has_act_order) {
......@@ -1351,9 +1324,9 @@ __global__ void Marlin(
scale4<scalar_t>(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k][2][j], act_frag_s[k2][3][j], 1);
} else if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
} else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float &&
group_blocks == -1) {
int idx = (threadIdx.x / 4) % 2;
scalar_t2 s2 = Dtype::nums2num2(
reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx],
......@@ -1361,18 +1334,12 @@ __global__ void Marlin(
if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);
} else if constexpr (has_zp && !is_zp_float && group_blocks != -1) {
} else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) {
if (is_new_zp)
frag_zp[j] = __hmul2(frag_zp[j],
*reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
scale_and_sub<scalar_t>(frag_b0, frag_s[k % 2][j][0].x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, frag_s[k % 2][j][0].y, frag_zp[j].y);
} else if constexpr (has_zp && is_zp_float && group_blocks != -1) {
if (is_new_zp)
frag_zpf[k2][j] = __hmul2(
frag_zpf[k2][j], *reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
scale_and_sub<scalar_t>(frag_b0, frag_s[k2][j].x, frag_zpf[k2][j].x);
scale_and_sub<scalar_t>(frag_b1, frag_s[k2][j].y, frag_zpf[k2][j].y);
scale_and_sub<scalar_t>(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y);
} else if constexpr (group_blocks != -1) {
scale<scalar_t>(frag_b0, frag_s[k2][j], 0);
scale<scalar_t>(frag_b1, frag_s[k2][j], 1);
......@@ -1397,7 +1364,7 @@ __global__ void Marlin(
auto thread_block_reduce = [&]() {
constexpr int red_off = threads / b_sh_stride_threads / 2;
if (red_off >= 1) {
int red_idx = threadIdx.x / b_sh_stride_threads;
auto red_idx = threadIdx.x / b_sh_stride_threads;
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
constexpr int red_sh_delta = b_sh_stride_threads;
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
......@@ -1634,10 +1601,17 @@ __global__ void Marlin(
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 4 && !has_zp) {
w_type.size_bits() == 4 &&
(has_zp && dequant_skip_flop || !has_zp)) {
res = __hmul2(res, s[0]);
}
if constexpr (w_type == vllm::kFE2M1f) {
if (!mul_topk_weights) {
res = __hmul2(res, global_scale);
}
}
if constexpr (m_block_size_8) {
((scalar_t*)sh_red)[idx] = res.x;
((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;
......@@ -1728,10 +1702,12 @@ __global__ void Marlin(
if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
if (i == 0) {
fetch_col_zp_to_shared();
fetch_col_scale_to_shared();
if constexpr (!dequant_skip_flop) {
fetch_col_scale_to_shared();
}
}
}
fetch_to_shared(i, i, i < slice_iters);
fetch_to_shared(i, i, i < slice_iters, i);
}
zero_accums();
......@@ -1740,8 +1716,10 @@ __global__ void Marlin(
fetch_to_registers(0, 0);
fetch_scales_to_registers(0, 0);
fetch_zp_to_registers(0, 0);
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
slice_k_start_shared_fetch += tb_k * (stages - 1);
a_gl_rd_col += a_gl_rd_delta_o * (stages - 1);
if constexpr (has_act_order) {
slice_k_start_shared_fetch += tb_k * (stages - 1);
}
};
if (slice_iters) {
start_pipes();
......@@ -1754,43 +1732,59 @@ __global__ void Marlin(
// have even length meaning that the next iteration will always start at
// index 0.
for (int stage_group_id = 0; stage_group_id < max_num_stage_groups;
stage_group_id++) {
#pragma unroll
for (int pipe = 0; pipe < stages;) {
for (int pipe = 0; pipe < stages;) {
#pragma unroll
for (int k = 0; k < b_sh_wr_iters; k++) {
fetch_to_registers(k + 1, pipe % stages);
fetch_scales_to_registers(k + 1, pipe);
fetch_zp_to_registers(k + 1, pipe);
if (k == b_sh_wr_iters - 2) {
fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages);
pipe++;
wait_for_stage();
init_same_group(pipe % stages);
for (int k = 0; k < b_sh_wr_iters; k++) {
int idx =
(pipe >= stages && stage_group_id == max_num_stage_groups - 1)
? (pipe - stages)
: (pipe + stage_group_id * stages);
fetch_to_registers(k + 1, pipe % stages, idx);
fetch_scales_to_registers(k + 1, pipe);
fetch_zp_to_registers(k + 1, pipe);
if (k == b_sh_wr_iters - 2) {
int idx = (pipe >= 1 && stage_group_id == max_num_stage_groups - 1)
? (pipe - 1)
: (pipe + (stage_group_id + 1) * stages - 1);
fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages, idx);
pipe++;
wait_for_stage();
init_same_group(pipe % stages);
}
matmul(k);
}
slice_iters--;
if (slice_iters == 0) {
break;
}
matmul(k);
}
slice_iters--;
if (slice_iters == 0) {
break;
}
}
a_remaining_load_count_in_slice = 0;
a_gl_rd += a_gl_rd_delta_o * stages;
slice_k_start += tb_k * stages;
slice_k_start_shared_fetch += tb_k * stages;
a_gl_rd_col += a_gl_rd_delta_o * stages;
if constexpr (has_act_order) {
int first_group_id = g_idx[slice_k_start];
int last_g_idx = slice_k_start + stages * tb_k * 2;
if (last_g_idx >= prob_k) {
last_g_idx = prob_k - 1;
if constexpr (has_act_order) {
slice_k_start += tb_k * stages;
if (slice_k_start < prob_k) {
slice_k_start_shared_fetch += tb_k * stages;
int first_group_id = g_idx[slice_k_start];
int last_g_idx = slice_k_start + stages * tb_k * 2;
if (last_g_idx >= prob_k) {
last_g_idx = prob_k - 1;
}
int last_group_id = g_idx[last_g_idx];
if (last_group_id >= sh_first_group_id + sh_num_groups) {
fetch_act_order_scales_to_shared(false, first_group_id,
last_group_id);
__syncthreads();
}
}
}
int last_group_id = g_idx[last_g_idx];
if (last_group_id >= sh_first_group_id + sh_num_groups) {
fetch_act_order_scales_to_shared(false, first_group_id, last_group_id);
__syncthreads();
if (slice_iters == 0) {
break;
}
}
......@@ -1802,7 +1796,8 @@ __global__ void Marlin(
bool last = slice_idx == slice_count - 1;
// For per-column scales, we only fetch them here in the final step before
// write-out
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
......@@ -1812,7 +1807,8 @@ __global__ void Marlin(
}
thread_block_reduce();
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
cp_async_wait<0>();
__syncthreads();
......@@ -1836,7 +1832,8 @@ __global__ void Marlin(
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 8 && !has_zp) {
w_type.size_bits() == 8 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
......@@ -1877,15 +1874,30 @@ __global__ void Marlin(
if (last || use_atomic_add)
// only the last block in a slice actually writes the result
write_result();
if (slice_row) a_remaining_load_count_in_slice = stages;
int old_slice_row = slice_row;
slice_row = 0;
slice_col_par++;
slice_col++;
is_first_matmul_in_slice = true;
init_slice();
// Should we load A matrix in next slice?
// `slice_col == 0`: when move to a new moe block
// `old_slice_row > 0`:
// when the last slice is not starting from k_index == 0
// (only happen when it is the first slice of a threadblock)
// `prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups`:
// when the required shared memory size is larger than
// the remaining shared memory
if (slice_col == 0 || old_slice_row ||
prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups) {
should_load_a = true;
} else {
should_load_a = false;
}
if (slice_iters) {
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
a_gl_rd_col = (threadIdx.x % a_gl_rd_delta_o);
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
......@@ -1900,12 +1912,10 @@ __global__ void Marlin(
slice_k_finish = slice_k_start + tb_k * slice_iters;
slice_k_start_shared_fetch = slice_k_start;
slice_n_offset = act_s_col_tb_stride * slice_col;
} else {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
}
start_pipes();
}
}
......
......@@ -116,7 +116,7 @@ __global__ void permute_cols_kernel(
int base_k = 0;
for (int i = 0; i < iters; i++) {
int cur_k = base_k + threadIdx.x;
auto cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
......@@ -126,7 +126,7 @@ __global__ void permute_cols_kernel(
if (rest) {
if (threadIdx.x < rest) {
int cur_k = base_k + threadIdx.x;
auto cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
......@@ -195,7 +195,6 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
load_groups = max(load_groups, 32); // We load at least 32 scale groups
return load_groups * tb_n * 2;
} else {
int tb_scales = tb_groups * tb_n * 2;
......@@ -203,22 +202,24 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
}
}
int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int group_size, bool has_act_order, bool is_k_full,
int has_zp, int is_zp_float) {
int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n,
int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full, int has_zp,
int is_zp_float) {
int pack_factor = 32 / num_bits;
// Get B size
int tb_k = th_config.thread_k;
int tb_n = th_config.thread_n;
int tb_m = thread_m_blocks * 16;
int tb_m = thread_m_blocks * (m_block_size_8 ? 8 : 16);
// shm size for block_sorted_ids/block_topk_weights
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
int sh_block_meta_size = tb_m * 4 * 2;
int sh_block_meta_size = tb_m * 4;
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
int sh_red_size = tb_m * (tb_n + 8) * 2;
int sh_s_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full);
......@@ -233,16 +234,17 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
sh_zp_size = sh_s_size / 2;
}
int total_size = sh_a_size + sh_b_size + sh_s_size + sh_zp_size +
sh_g_idx_size + sh_block_meta_size;
int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size +
sh_zp_size + sh_g_idx_size + sh_block_meta_size;
return total_size;
}
bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int group_size, bool has_act_order, bool is_k_full,
int has_zp, int is_zp_float, int max_shared_mem) {
bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n, int prob_k,
int num_bits, int group_size, bool has_act_order,
bool is_k_full, int has_zp, int is_zp_float,
int max_shared_mem) {
// Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) {
......@@ -266,143 +268,129 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
// Check that pipeline fits into cache
int cache_size = get_kernel_cache_size(
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float);
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);
return cache_size <= max_shared_mem;
}
#define __GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
NUM_THREADS, IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
m_block_size_8 == M_BLOCK_SIZE_8 && \
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
pipe_stages, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
IS_ZP_FLOAT>; \
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
m_block_size_8 == M_BLOCK_SIZE_8 && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
}
#define GPTQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, true, false, 0, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 8, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false)
#define GPTQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false)
#define AWQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 2, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 8, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false)
#define AWQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false)
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
// this is the most common cases
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
// FZP: cases for float-zero-point (is_zp_float = true)
// ACT: cases for act order case (group_blocks == 0)
// FP4: cases for nvfp4(e2m1) (group_blocks == 1)
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF(W_TYPE) \
COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \
COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \
COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \
COMMON_GET_IF_M234(W_TYPE, 8, 4, 128)
#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF(W_TYPE) \
FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
FP4_GET_IF_M234(W_TYPE, 8, 4, 128)
#define BIGGROUP_GET_IF(W_TYPE) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)
// We currently have 4-bit models only with group_blocks == 4
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF(W_TYPE) \
FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \
FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \
FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \
FZP_GET_IF_M234(W_TYPE, 8, 4, 128)
// We currently have 4-bit models only with group_blocks == 4
#define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
true) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true)
#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF(W_TYPE) \
ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \
ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \
ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \
ACT_GET_IF_M234(W_TYPE, 8, 4, 128)
template <typename scalar_t>
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
......@@ -415,23 +403,17 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
auto kernel = MarlinDefault;
if (false) {
}
GPTQ_GET_IF_M1(vllm::kU4B8, 8, 8, 256)
GPTQ_GET_IF_M1(vllm::kU4B8, 8, 4, 128)
GPTQ_GET_IF_M234(vllm::kU4B8, 16, 4, 256)
GPTQ_GET_IF_M234(vllm::kU4B8, 8, 4, 128)
GPTQ_GET_IF_M1(vllm::kU8B128, 8, 8, 256)
GPTQ_GET_IF_M1(vllm::kU8B128, 8, 4, 128)
COMMON_GET_IF(vllm::kU4)
COMMON_GET_IF(vllm::kU4B8)
COMMON_GET_IF(vllm::kU8B128)
GPTQ_GET_IF_M234(vllm::kU8B128, 16, 4, 256)
GPTQ_GET_IF_M234(vllm::kU8B128, 8, 4, 128)
BIGGROUP_GET_IF(vllm::kFE4M3fn)
AWQ_GET_IF_M1(vllm::kU4, 8, 8, 256)
AWQ_GET_IF_M1(vllm::kU4, 8, 4, 128)
FP4_GET_IF(vllm::kFE2M1f)
AWQ_GET_IF_M234(vllm::kU4, 16, 4, 256)
AWQ_GET_IF_M234(vllm::kU4, 8, 4, 128)
ACT_GET_IF(vllm::kU4B8)
ACT_GET_IF(vllm::kU8B128)
return kernel;
}
......@@ -457,19 +439,19 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
for (int i = 0; i < thread_configs_size; i++) {
thread_config_t th_config = thread_configs[i];
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp,
is_zp_float, max_shared_mem)) {
if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem)) {
continue;
}
int cache_size = get_kernel_cache_size(
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full, has_zp, is_zp_float);
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);
int group_blocks = 0;
if (!has_act_order) {
group_blocks = group_size == -1 ? -1 : group_size / 16;
group_blocks = group_size == -1 ? -1 : (group_size / 16);
}
auto kernel = get_marlin_kernel<scalar_t>(
......@@ -501,7 +483,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
void* zp, void* g_idx, void* perm, void* a_tmp,
void* s2, void* zp, void* g_idx, void* perm, void* a_tmp,
void* sorted_token_ids, void* expert_ids,
void* num_tokens_past_padded, void* topk_weights,
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
......@@ -520,8 +502,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
} else {
TORCH_CHECK(
q_type == vllm::kU4B8 || q_type == vllm::kU8B128,
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f,
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
"has_zp = False. Got = ",
q_type.str());
}
......@@ -555,6 +539,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp;
const int4* s_ptr = (const int4*)s;
const uint16_t* s2_ptr = (const uint16_t*)s2;
const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx;
const int* perm_ptr = (const int*)perm;
......@@ -631,18 +616,18 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16;
TORCH_CHECK(is_valid_config(thread_tfg, thread_m_blocks, prob_m, prob_n,
prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem),
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
", thread_k = ", thread_tfg.thread_k,
", thread_n = ", thread_tfg.thread_n,
", num_threads = ", thread_tfg.num_threads, " for MKN = [",
prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size,
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
", max_shared_mem = ", max_shared_mem);
TORCH_CHECK(
is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem),
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
", thread_k = ", thread_tfg.thread_k,
", thread_n = ", thread_tfg.thread_n,
", num_threads = ", thread_tfg.num_threads, " for MKN = [", prob_m, ", ",
prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size, ", has_act_order = ", has_act_order,
", is_k_full = ", is_k_full, ", has_zp = ", has_zp,
", is_zp_float = ", is_zp_float, ", max_shared_mem = ", max_shared_mem);
auto kernel = get_marlin_kernel<scalar_t>(
q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8,
......@@ -663,10 +648,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
// avoid ">>>" being formatted to "> > >"
// clang-format off
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr,
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce);
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem);
// clang-format on
}
......@@ -675,6 +660,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& global_scale_or_none,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
......@@ -826,6 +812,17 @@ torch::Tensor moe_wna16_marlin_gemm(
}
}
torch::Tensor global_scale;
if (global_scale_or_none.has_value()) {
global_scale = global_scale_or_none.value();
TORCH_CHECK(b_q_type == vllm::kFE2M1f,
"global_scale can only be used for float4_e2m1f.");
} else {
global_scale = torch::empty({0}, options);
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f),
"the global_scale parameter must be passed for float4_e2m1f.");
}
torch::Tensor b_zeros;
if (b_zeros_or_none.has_value()) {
b_zeros = b_zeros_or_none.value();
......@@ -838,13 +835,15 @@ torch::Tensor moe_wna16_marlin_gemm(
if (has_zp) {
TORCH_CHECK(
b_q_type == vllm::kU4,
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
"b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str());
} else {
TORCH_CHECK(
b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
b_q_type.str());
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 ||
b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f,
"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or "
"float4_e2m1f when "
"has_zp = False. Got = ",
b_q_type.str());
}
if (has_zp && is_zp_float) {
......@@ -889,9 +888,16 @@ torch::Tensor moe_wna16_marlin_gemm(
int dev = a.get_device();
if (a.scalar_type() == at::ScalarType::Half) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
} else {
scales_ptr = b_scales.data_ptr<at::Half>();
}
MARLIN_NAMESPACE_NAME::marlin_mm<half>(
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
c_tmp.data_ptr<float>(), scales_ptr, global_scale.data_ptr<at::Half>(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
a_tmp.data_ptr<at::Half>(), sorted_token_ids.data_ptr(),
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
......@@ -901,11 +907,18 @@ torch::Tensor moe_wna16_marlin_gemm(
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
} else {
scales_ptr = b_scales.data_ptr<at::BFloat16>();
}
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), scales_ptr,
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
......
......@@ -399,7 +399,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
}
if (use_global_memory) {
VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
......@@ -424,7 +424,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
cumsum_buffer.data_ptr<int32_t>());
});
} else if (use_i16) {
VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// set dynamic shared mem
auto kernel =
......@@ -439,7 +439,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
topk_ids.numel());
});
} else {
VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
auto kernel =
vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
......@@ -464,7 +464,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
TORCH_CHECK(num_experts == 256,
"sgl_moe_align_block_size kernel only supports deepseek v3.");
VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `cumsum` tensors
auto options_int =
......
......@@ -30,6 +30,12 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
int64_t BLOCK_SIZE_K, int64_t bit);
#endif
bool moe_permute_unpermute_supported();
void shuffle_rows(const torch::Tensor& input_tensor,
const torch::Tensor& dst2src_map,
torch::Tensor& output_tensor);
std::vector<torch::Tensor> moe_fused_gate(
torch::Tensor& input,
torch::Tensor& bias,
......@@ -37,4 +43,4 @@ std::vector<torch::Tensor> moe_fused_gate(
int64_t topk_group,
int64_t topk,
int64_t n_share_experts_fusion,
double routed_scaling_factor);
\ No newline at end of file
double routed_scaling_factor);
#include <c10/core/ScalarType.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "permute_unpermute_kernels/moe_permute_unpermute_kernel.h"
#include "permute_unpermute_kernels/dispatch.h"
#include "core/registration.h"
// moe_permute kernels require at least CUDA 12.0
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
void moe_permute(
const torch::Tensor& input, // [n_token, hidden]
const torch::Tensor& topk_weights, //[n_token, topk]
torch::Tensor& topk_ids, // [n_token, topk]
const torch::Tensor& token_expert_indicies, // [n_token, topk]
const std::optional<torch::Tensor>& expert_map, // [n_expert]
int64_t n_expert, int64_t n_local_expert, int64_t topk,
const std::optional<int64_t>& align_block_size,
torch::Tensor&
permuted_input, // [topk * n_token/align_block_size_m, hidden]
torch::Tensor& expert_first_token_offset, // [n_local_expert + 1]
torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk]
torch::Tensor& m_indices) { // [align_expand_m]
TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float,
"topk_weights must be float32");
TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
"expert_first_token_offset must be int64");
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
"topk_ids must be int32");
TORCH_CHECK(token_expert_indicies.scalar_type() == at::ScalarType::Int,
"token_expert_indicies must be int32");
TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int,
"src_row_id2dst_row_id_map must be int32");
TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1,
"expert_first_token_offset shape != n_local_expert+1")
TORCH_CHECK(
src_row_id2dst_row_id_map.sizes() == token_expert_indicies.sizes(),
"token_expert_indicies shape must be same as src_row_id2dst_row_id_map");
auto n_token = input.sizes()[0];
auto n_hidden = input.sizes()[1];
auto align_block_size_value =
align_block_size.has_value() ? align_block_size.value() : -1;
auto stream = at::cuda::getCurrentCUDAStream().stream();
const long sorter_size =
CubKeyValueSorter::getWorkspaceSize(n_token * topk, n_expert);
auto sort_workspace = torch::empty(
{sorter_size},
torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
auto permuted_experts_id = torch::empty_like(topk_ids);
auto dst_row_id2src_row_id_map = torch::empty_like(src_row_id2dst_row_id_map);
auto align_expert_first_token_offset =
torch::zeros_like(expert_first_token_offset);
CubKeyValueSorter sorter{};
int64_t* valid_num_ptr = nullptr;
// pre-process kernel for expert-parallelism:
// no local expert id plus "n_expert" offset for priority to local expert
// map local expert id [n, .., n+n_local_expert-1] to [0, n_local_expert -1]
// For example, 4 expert with ep_size=2. ep_rank=1 owns global expert id
// [2,3] with expert_map[-1, -1, 0, 1], preprocess_topk_id process topk_ids
// and map global expert id [2, 3] to local_expert id [0, 1] and map global
// expert id [0, 1] ( not in ep rank=1) to [4, 5] by plus n_expert. This map
// operation is to make local expert high priority in following sort topk_ids
// and scan local expert_first_token_offset for each ep rank for next group
// gemm.
if (expert_map.has_value()) {
const int* expert_map_ptr = get_ptr<int>(expert_map.value());
valid_num_ptr =
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
preprocessTopkIdLauncher(get_ptr<int>(topk_ids), n_token * topk,
expert_map_ptr, n_expert, stream);
}
// expert sort topk expert id and scan expert id get expert_first_token_offset
sortAndScanExpert(get_ptr<int>(topk_ids), get_ptr<int>(token_expert_indicies),
get_ptr<int>(permuted_experts_id),
get_ptr<int>(dst_row_id2src_row_id_map),
get_ptr<int64_t>(expert_first_token_offset), n_token,
n_expert, n_local_expert, topk, sorter,
get_ptr<int>(sort_workspace), stream);
// dispatch expandInputRowsKernelLauncher
MOE_DISPATCH(input.scalar_type(), [&] {
expandInputRowsKernelLauncher<scalar_t>(
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
get_ptr<float>(topk_weights), get_ptr<int>(permuted_experts_id),
get_ptr<int>(dst_row_id2src_row_id_map),
get_ptr<int>(src_row_id2dst_row_id_map),
get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr,
n_hidden, topk, n_local_expert, align_block_size_value, stream);
});
// get m_indices and update expert_first_token_offset with align block
getMIndices(get_ptr<int64_t>(expert_first_token_offset),
get_ptr<int64_t>(align_expert_first_token_offset),
get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
stream);
if (align_block_size.has_value()) {
// update align_expert_first_token_offset
expert_first_token_offset.copy_(align_expert_first_token_offset);
}
}
void moe_unpermute(
const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden]
const torch::Tensor& topk_weights, //[n_token, topk]
const torch::Tensor& topk_ids, // [n_token, topk]
const torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk]
const torch::Tensor& expert_first_token_offset, // [n_local_expert+1]
int64_t n_expert, int64_t n_local_expert, int64_t topk,
torch::Tensor& hidden_states // [n_token, hidden]
) {
TORCH_CHECK(src_row_id2dst_row_id_map.sizes() == topk_ids.sizes(),
"topk_ids shape must be same as src_row_id2dst_row_id_map");
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
"topk_ids must be int32");
TORCH_CHECK(
permuted_hidden_states.scalar_type() == hidden_states.scalar_type(),
"topk_ids dtype must be same as src_row_id2dst_row_id_map");
auto n_token = hidden_states.size(0);
auto n_hidden = hidden_states.size(1);
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int64_t* valid_ptr =
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
MOE_DISPATCH(hidden_states.scalar_type(), [&] {
finalizeMoeRoutingKernelLauncher<scalar_t, scalar_t>(
get_ptr<scalar_t>(permuted_hidden_states),
get_ptr<scalar_t>(hidden_states), get_ptr<float>(topk_weights),
get_ptr<int>(src_row_id2dst_row_id_map), get_ptr<int>(topk_ids),
n_token, n_hidden, topk, valid_ptr, stream);
});
}
template <typename T>
__global__ void shuffleInputRowsKernel(const T* input,
const int32_t* dst2src_map, T* output,
int64_t num_src_rows,
int64_t num_dst_rows, int64_t num_cols) {
int64_t dest_row_idx = blockIdx.x;
int64_t const source_row_idx = dst2src_map[dest_row_idx];
if (blockIdx.x < num_dst_rows) {
// Load 128-bits per thread
constexpr int64_t ELEM_PER_THREAD = 128 / sizeof(T) / 8;
using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;
// Duplicate and permute rows
auto const* source_row_ptr =
reinterpret_cast<DataElem const*>(input + source_row_idx * num_cols);
auto* dest_row_ptr =
reinterpret_cast<DataElem*>(output + dest_row_idx * num_cols);
int64_t const start_offset = threadIdx.x;
int64_t const stride = blockDim.x;
int64_t const num_elems_in_col = num_cols / ELEM_PER_THREAD;
for (int elem_index = start_offset; elem_index < num_elems_in_col;
elem_index += stride) {
dest_row_ptr[elem_index] = source_row_ptr[elem_index];
}
}
}
void shuffle_rows(const torch::Tensor& input_tensor,
const torch::Tensor& dst2src_map,
torch::Tensor& output_tensor) {
TORCH_CHECK(input_tensor.scalar_type() == output_tensor.scalar_type(),
"Input and output tensors must have the same data type");
auto stream = at::cuda::getCurrentCUDAStream().stream();
int64_t const blocks = output_tensor.size(0);
int64_t const threads = 256;
int64_t const num_dest_rows = output_tensor.size(0);
int64_t const num_src_rows = input_tensor.size(0);
int64_t const num_cols = input_tensor.size(1);
TORCH_CHECK(!(num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)),
"num_cols must be divisible by 128 / "
"sizeof(input_tensor.scalar_type()) / 8");
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
dst2src_map.data_ptr<int32_t>(),
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
num_dest_rows, num_cols);
});
}
#else
void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
torch::Tensor& topk_ids,
const torch::Tensor& token_expert_indicies,
const std::optional<torch::Tensor>& expert_map,
int64_t n_expert, int64_t n_local_expert, int64_t topk,
const std::optional<int64_t>& align_block_size,
torch::Tensor& permuted_input,
torch::Tensor& expert_first_token_offset,
torch::Tensor& src_row_id2dst_row_id_map,
torch::Tensor& m_indices) {
TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
}
void moe_unpermute(const torch::Tensor& input,
const torch::Tensor& topk_weights, torch::Tensor& topk_ids,
const torch::Tensor& token_expert_indicies,
const std::optional<torch::Tensor>& expert_map,
int64_t n_expert, int64_t n_local_expert, int64_t topk,
const std::optional<int64_t>& align_block_size,
torch::Tensor& permuted_input,
torch::Tensor& expert_first_token_offset,
torch::Tensor& src_row_id2dst_row_id_map,
torch::Tensor& m_indices) {
TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
}
#endif
bool moe_permute_unpermute_supported() {
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
return true;
#else
return false;
#endif
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("moe_permute", &moe_permute);
m.impl("moe_unpermute", &moe_unpermute);
}
......@@ -108,11 +108,11 @@ __device__ inline void dequant<half2, 4>(int q, half2* res) {
const int MUL = 0x2c002c00;
const int ADD = 0xd400d400;
int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
q >>= 8;
int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
res[0] = __hsub2(*reinterpret_cast<half2*>(&lo0),
*reinterpret_cast<const half2*>(&SUB));
......@@ -149,13 +149,13 @@ __device__ inline void dequant<nv_bfloat162, 4>(int q, nv_bfloat162* res) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
int hi0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
int hi1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC300C300;
......
#pragma once
#include <cuda_fp8.h>
#define MOE_SWITCH(TYPE, ...) \
at::ScalarType _st = ::detail::scalar_type(TYPE); \
switch (_st) { \
__VA_ARGS__ \
default: \
TORCH_CHECK(false, "[moe permute]data type dispatch fail!") \
}
#define MOE_DISPATCH_CASE(enum_type, ...) \
case enum_type: { \
using scalar_t = ScalarType2CudaType<enum_type>::type; \
__VA_ARGS__(); \
break; \
}
#define MOE_DISPATCH_FLOAT_CASE(...) \
MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
#define MOE_DISPATCH(TYPE, ...) \
MOE_SWITCH(TYPE, MOE_DISPATCH_FLOAT_CASE(__VA_ARGS__))
template <at::ScalarType type>
struct ScalarType2CudaType;
template <>
struct ScalarType2CudaType<at::ScalarType::Float> {
using type = float;
};
template <>
struct ScalarType2CudaType<at::ScalarType::Half> {
using type = half;
};
template <>
struct ScalarType2CudaType<at::ScalarType::BFloat16> {
using type = __nv_bfloat16;
};
// uint8 for packed fp4
template <>
struct ScalarType2CudaType<at::ScalarType::Byte> {
using type = uint8_t;
};
// #if __CUDA_ARCH__ >= 890
// fp8
template <>
struct ScalarType2CudaType<at::ScalarType::Float8_e5m2> {
using type = __nv_fp8_e5m2;
};
template <>
struct ScalarType2CudaType<at::ScalarType::Float8_e4m3fn> {
using type = __nv_fp8_e4m3;
};
// #endif
\ No newline at end of file
#include "moe_permute_unpermute_kernel.h"
// moe_permute kernels require at least CUDA 12.0
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
// CubKeyValueSorter definition begin
CubKeyValueSorter::CubKeyValueSorter()
: num_experts_(0), num_bits_(sizeof(int) * 8) {}
int CubKeyValueSorter::expertsToBits(int num_experts) {
// Max value we represent is V = num_experts + (num_experts - 1) = 2 *
// num_experts - 1 The maximum number of bits is therefore floor(log2(V)) + 1
return static_cast<int>(log2(2 * num_experts - 1)) + 1;
}
CubKeyValueSorter::CubKeyValueSorter(int const num_experts)
: num_experts_(num_experts), num_bits_(expertsToBits(num_experts)) {}
void CubKeyValueSorter::updateNumExperts(int const num_experts) {
num_experts_ = num_experts;
num_bits_ = expertsToBits(num_experts);
}
size_t CubKeyValueSorter::getWorkspaceSize(size_t const num_key_value_pairs,
int const num_experts) {
int num_bits = expertsToBits(num_experts);
size_t required_storage = 0;
int* null_int = nullptr;
cub::DeviceRadixSort::SortPairs(nullptr, required_storage, null_int, null_int,
null_int, null_int, num_key_value_pairs, 0,
num_bits);
// when num_key_value_pairs, num_experts, num_bits, required_storage = 64,
// 4, 3, 0 The required_storage seems to vary between 0 and 1 for the same
// inputs
if (required_storage == 0) {
required_storage = 1;
}
return required_storage;
}
void CubKeyValueSorter::run(void* workspace, size_t const workspace_size,
int const* keys_in, int* keys_out,
int const* values_in, int* values_out,
size_t const num_key_value_pairs,
cudaStream_t stream) {
size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs, num_experts_);
size_t actual_ws_size = workspace_size;
TORCH_CHECK(expected_ws_size <= workspace_size,
"[CubKeyValueSorter::run] The allocated workspace is too small "
"to run this problem.");
cub::DeviceRadixSort::SortPairs(workspace, actual_ws_size, keys_in, keys_out,
values_in, values_out, num_key_value_pairs, 0,
num_bits_, stream);
}
// CubKeyValueSorter definition end
static inline size_t pad_to_multiple_of_16(size_t const& input) {
static constexpr int ALIGNMENT = 16;
return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
}
template <class T>
__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
int64_t const arr_length,
T const target) {
int64_t low = 0, high = arr_length - 1, target_location = -1;
while (low <= high) {
int64_t mid = (low + high) / 2;
if (sorted_indices[mid] >= target) {
high = mid - 1;
} else {
low = mid + 1;
target_location = mid;
}
}
return target_location + 1;
}
// Calculates the start offset of the tokens for a given expert. The last
// element is the total number of valid tokens
__global__ void computeExpertFirstTokenOffsetKernel(
int const* sorted_experts, int64_t const sorted_experts_len,
int const num_experts, int64_t* expert_first_token_offset) {
// First, compute the global tid. We only need 1 thread per expert.
int const expert = blockIdx.x * blockDim.x + threadIdx.x;
// Note that expert goes [0, num_experts] (inclusive) because we want a count
// for the total number of active tokens at the end of the scan.
if (expert >= num_experts + 1) {
return;
}
expert_first_token_offset[expert] =
findTotalEltsLessThanTarget(sorted_experts, sorted_experts_len, expert);
}
void computeExpertFirstTokenOffset(int const* sorted_indices,
int const total_indices,
int const num_experts,
int64_t* expert_first_token_offset,
cudaStream_t stream) {
int const num_entries = num_experts + 1;
int const threads = std::min(1024, num_entries);
int const blocks = (num_entries + threads - 1) / threads;
computeExpertFirstTokenOffsetKernel<<<blocks, threads, 0, stream>>>(
sorted_indices, total_indices, num_experts, expert_first_token_offset);
}
void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
int* permuted_experts, int* permuted_rows,
int64_t* expert_first_token_offset, int num_rows,
int num_experts, int num_experts_per_node, int k,
CubKeyValueSorter& sorter, void* sorter_ws,
cudaStream_t stream) {
int64_t const expanded_num_rows = static_cast<int64_t>(k) * num_rows;
// We need to use the full num_experts because that is the sentinel value used
// by topk for disabled experts
sorter.updateNumExperts(num_experts);
size_t const sorter_ws_size_bytes = pad_to_multiple_of_16(
sorter.getWorkspaceSize(expanded_num_rows, num_experts));
sorter.run((void*)sorter_ws, sorter_ws_size_bytes, expert_for_source_row,
permuted_experts, source_rows, permuted_rows, expanded_num_rows,
stream);
computeExpertFirstTokenOffset(permuted_experts, expanded_num_rows,
num_experts_per_node, expert_first_token_offset,
stream);
}
__global__ void preprocessTopkIdKernel(int* topk_id_ptr, int size,
const int* expert_map_ptr,
int num_experts) {
auto tidx = threadIdx.x;
auto bidx = blockIdx.x;
auto offset = bidx * blockDim.x;
auto bound = min(offset + blockDim.x, size);
extern __shared__ int smem_expert_map[];
// store expert_map in smem
for (int i = tidx; i < num_experts; i += blockDim.x) {
smem_expert_map[i] = expert_map_ptr[i];
}
__syncthreads();
// query global expert id in expert map.
// if global expert id = -1 in exert map, plus n_expert
// else set global expert id = exert map[global expert id]
if (offset + tidx < bound) {
auto topk_id = topk_id_ptr[offset + tidx];
auto local_expert_idx = smem_expert_map[topk_id];
if (local_expert_idx == -1) {
topk_id += num_experts;
} else {
topk_id = local_expert_idx;
}
__syncwarp();
topk_id_ptr[offset + tidx] = topk_id;
}
}
void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
const int* expert_map_ptr, int num_experts,
cudaStream_t stream) {
int block = std::min(size, 1024);
int grid = (size + block - 1) / block;
int smem_size = (num_experts) * sizeof(int);
preprocessTopkIdKernel<<<grid, block, smem_size, stream>>>(
topk_id_ptr, size, expert_map_ptr, num_experts);
}
template <bool ALIGN_BLOCK_SIZE>
__global__ void getMIndicesKernel(int64_t* expert_first_token_offset,
int64_t* align_expert_first_token_offset,
int* m_indices, const int num_local_expert,
const int align_block_size) {
int eidx = blockIdx.x;
int tidx = threadIdx.x;
extern __shared__ int64_t smem_expert_first_token_offset[];
for (int i = tidx; i <= num_local_expert; i += blockDim.x) {
smem_expert_first_token_offset[tidx] = __ldg(expert_first_token_offset + i);
}
__syncthreads();
auto last_token_offset = smem_expert_first_token_offset[eidx + 1];
auto first_token_offset = smem_expert_first_token_offset[eidx];
int n_token_in_expert = last_token_offset - first_token_offset;
if constexpr (ALIGN_BLOCK_SIZE) {
n_token_in_expert = (n_token_in_expert + align_block_size - 1) /
align_block_size * align_block_size;
// round up to ALIGN_BLOCK_SIZE
int64_t accumulate_align_offset = 0;
for (int i = 1; i <= eidx + 1; i++) {
int n_token = smem_expert_first_token_offset[i] -
smem_expert_first_token_offset[i - 1];
accumulate_align_offset =
accumulate_align_offset + (n_token + align_block_size - 1) /
align_block_size * align_block_size;
if (i == eidx) {
first_token_offset = accumulate_align_offset;
}
// last block store align_expert_first_token_offset
if (eidx == num_local_expert - 1 && threadIdx.x == 0) {
align_expert_first_token_offset[i] = accumulate_align_offset;
}
}
}
for (int idx = tidx; idx < n_token_in_expert; idx += blockDim.x) {
// update m_indice with expert id
m_indices[first_token_offset + idx] = eidx;
}
}
void getMIndices(int64_t* expert_first_token_offset,
int64_t* align_expert_first_token_offset, int* m_indices,
int num_local_expert, const int align_block_size,
cudaStream_t stream) {
int block = 256;
int grid = num_local_expert;
int smem_size = sizeof(int64_t) * (num_local_expert + 1);
if (align_block_size == -1) {
getMIndicesKernel<false><<<grid, block, smem_size, stream>>>(
expert_first_token_offset, align_expert_first_token_offset, m_indices,
num_local_expert, align_block_size);
} else {
getMIndicesKernel<true><<<grid, block, smem_size, stream>>>(
expert_first_token_offset, align_expert_first_token_offset, m_indices,
num_local_expert, align_block_size);
}
}
#endif
#pragma once
// reference from tensorrt_llm moe kernel implementation archive in
// https://github.com/BBuf/tensorrt-llm-moe/tree/master
#include <c10/core/ScalarType.h>
#include <torch/all.h>
#include "dispatch.h"
#include <cub/cub.cuh>
#include <cub/device/device_radix_sort.cuh>
#include <cub/util_type.cuh>
#include "cutlass/numeric_size.h"
#include "cutlass/array.h"
template <typename T>
inline T* get_ptr(torch::Tensor& t) {
return reinterpret_cast<T*>(t.data_ptr());
}
template <typename T>
inline const T* get_ptr(const torch::Tensor& t) {
return reinterpret_cast<const T*>(t.data_ptr());
}
class CubKeyValueSorter {
public:
CubKeyValueSorter();
CubKeyValueSorter(int const num_experts);
void updateNumExperts(int const num_experts);
static size_t getWorkspaceSize(size_t const num_key_value_pairs,
int const num_experts);
void run(void* workspace, size_t const workspace_size, int const* keys_in,
int* keys_out, int const* values_in, int* values_out,
size_t const num_key_value_pairs, cudaStream_t stream);
private:
static int expertsToBits(int experts);
int num_experts_;
int num_bits_;
};
void computeExpertFirstTokenOffset(int const* sorted_indices,
int const total_indices,
int const num_experts,
int64_t* expert_first_token_offset,
cudaStream_t stream);
void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
int* permuted_experts, int* permuted_rows,
int64_t* expert_first_token_offset, int num_rows,
int num_experts, int num_experts_per_node, int k,
CubKeyValueSorter& sorter, void* sorter_ws,
cudaStream_t stream);
template <typename T>
void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output,
const float* unpermuted_scales, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row,
int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int num_local_experts, const int& align_block_size, cudaStream_t stream);
// Final kernel to unpermute and scale
// This kernel unpermutes the original data, does the k-way reduction and
// performs the final skip connection.
template <typename T, typename OutputType, bool CHECK_SKIPPED>
__global__ void finalizeMoeRoutingKernel(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const orig_cols, int64_t const k,
int64_t const* num_valid_ptr);
template <class T, class OutputType>
void finalizeMoeRoutingKernelLauncher(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const num_rows,
int64_t const cols, int64_t const k, int64_t const* num_valid_ptr,
cudaStream_t stream);
void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
const int* expert_map_ptr, int num_experts,
cudaStream_t stream);
void getMIndices(int64_t* expert_first_token_offset,
int64_t* align_expert_first_token_offset, int* m_indices,
int num_local_expert, const int align_block_size,
cudaStream_t stream);
#include "moe_permute_unpermute_kernel.inl"
#pragma once
template <typename T, bool CHECK_SKIPPED, bool ALIGN_BLOCK_SIZE>
__global__ void expandInputRowsKernel(
T const* unpermuted_input, T* permuted_output,
const float* unpermuted_scales, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row,
int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_dest_rows, int64_t const cols, int64_t k,
int num_local_experts, int align_block_size) {
// Reverse permutation map.
// I do this so that later, we can use the source -> dest map to do the k-way
// reduction and unpermuting. I need the reverse map for that reduction to
// allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1
// thread block will be responsible for all k summations.
int64_t expanded_dest_row = blockIdx.x;
int64_t const expanded_source_row =
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
int expert_id = sorted_experts[expanded_dest_row];
extern __shared__ int64_t smem_expert_first_token_offset[];
int64_t align_expanded_row_accumulate = 0;
if constexpr (ALIGN_BLOCK_SIZE) {
// load g2s
for (int idx = threadIdx.x; idx < num_local_experts + 1;
idx += blockDim.x) {
smem_expert_first_token_offset[idx] =
__ldg(expert_first_token_offset + idx);
}
__syncthreads();
int lane_idx = threadIdx.x & 31;
if (lane_idx == 0) {
// set token_offset_in_expert = 0 if this expert is not local expert
int token_offset_in_expert =
expert_id >= num_local_experts
? 0
: expanded_dest_row - smem_expert_first_token_offset[expert_id];
int64_t accumulate_align_offset = 0;
#pragma unroll 1
for (int eidx = 1; eidx <= min(expert_id, num_local_experts); eidx++) {
auto n_token_in_expert = smem_expert_first_token_offset[eidx] -
smem_expert_first_token_offset[eidx - 1];
accumulate_align_offset += (n_token_in_expert + align_block_size - 1) /
align_block_size * align_block_size;
}
expanded_dest_row = accumulate_align_offset + token_offset_in_expert;
}
// lane0 shuffle broadcast align_expanded_dest_row
expanded_dest_row = __shfl_sync(0xffffffff, expanded_dest_row, 0);
}
if (threadIdx.x == 0) {
assert(expanded_dest_row <= INT32_MAX);
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
static_cast<int>(expanded_dest_row);
}
if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) {
// Load 128-bits per thread
constexpr int64_t ELEM_PER_THREAD = 128 / cutlass::sizeof_bits<T>::value;
using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;
// Duplicate and permute rows
int64_t const source_k_rank = expanded_source_row / num_rows;
int64_t const source_row = expanded_source_row % num_rows;
auto const* source_row_ptr =
reinterpret_cast<DataElem const*>(unpermuted_input + source_row * cols);
auto* dest_row_ptr =
reinterpret_cast<DataElem*>(permuted_output + expanded_dest_row * cols);
int64_t const start_offset = threadIdx.x;
int64_t const stride = blockDim.x;
int64_t const num_elems_in_col = cols / ELEM_PER_THREAD;
for (int elem_index = start_offset; elem_index < num_elems_in_col;
elem_index += stride) {
dest_row_ptr[elem_index] = source_row_ptr[elem_index];
}
}
}
template <typename T>
void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output,
const float* unpermuted_scales, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row,
int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int num_local_experts, const int& align_block_size, cudaStream_t stream) {
int64_t const blocks = num_rows * k;
int64_t const threads = 256;
using FuncPtr = decltype(&expandInputRowsKernel<T, true, true>);
FuncPtr func_map[2][2] = {
{&expandInputRowsKernel<T, false, false>,
&expandInputRowsKernel<T, false, true>},
{&expandInputRowsKernel<T, true, false>,
&expandInputRowsKernel<T, true, true>},
};
bool is_check_skip = num_valid_tokens_ptr != nullptr;
bool is_align_block_size = align_block_size != -1;
auto func = func_map[is_check_skip][is_align_block_size];
int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1);
func<<<blocks, threads, smem_size, stream>>>(
unpermuted_input, permuted_output, unpermuted_scales, sorted_experts,
expanded_dest_row_to_expanded_source_row,
expanded_source_row_to_expanded_dest_row, expert_first_token_offset,
num_rows, num_valid_tokens_ptr, cols, k, num_local_experts,
align_block_size);
}
template <class T, class U>
__host__ __device__ constexpr static U arrayConvert(T const& input) {
using Type = typename U::Element;
static_assert(T::kElements == U::kElements);
U u;
#pragma unroll
for (int i = 0; i < U::kElements; i++) {
u[i] = static_cast<Type>(input[i]);
}
return u;
}
template <typename T, typename OutputType, bool CHECK_SKIPPED>
__global__ void finalizeMoeRoutingKernel(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const orig_cols, int64_t const k,
int64_t const* num_valid_ptr) {
assert(orig_cols % 4 == 0);
int64_t const original_row = blockIdx.x;
int64_t const num_rows = gridDim.x;
auto const offset = original_row * orig_cols;
OutputType* reduced_row_ptr = reduced_unpermuted_output + offset;
int64_t const num_valid = *num_valid_ptr;
// Load 128-bits per thread, according to the smallest data type we read/write
constexpr int64_t FINALIZE_ELEM_PER_THREAD =
128 / std::min(cutlass::sizeof_bits<OutputType>::value,
cutlass::sizeof_bits<T>::value);
int64_t const start_offset = threadIdx.x;
int64_t const stride = blockDim.x;
int64_t const num_elems_in_col = orig_cols / FINALIZE_ELEM_PER_THREAD;
using InputElem = cutlass::Array<T, FINALIZE_ELEM_PER_THREAD>;
using OutputElem = cutlass::Array<OutputType, FINALIZE_ELEM_PER_THREAD>;
using ComputeElem = cutlass::Array<float, FINALIZE_ELEM_PER_THREAD>;
auto const* expanded_permuted_rows_v =
reinterpret_cast<InputElem const*>(expanded_permuted_rows);
auto* reduced_row_ptr_v = reinterpret_cast<OutputElem*>(reduced_row_ptr);
#pragma unroll
for (int elem_index = start_offset; elem_index < num_elems_in_col;
elem_index += stride) {
ComputeElem thread_output;
thread_output.fill(0);
float row_rescale{0.f};
for (int k_idx = 0; k_idx < k; ++k_idx) {
int64_t const expanded_original_row = original_row + k_idx * num_rows;
int64_t const expanded_permuted_row =
expanded_source_row_to_expanded_dest_row[expanded_original_row];
int64_t const k_offset = original_row * k + k_idx;
float const row_scale = scales[k_offset];
// Check after row_rescale has accumulated
if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) {
continue;
}
auto const* expanded_permuted_rows_row_ptr =
expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col;
int64_t const expert_idx = expert_for_source_row[k_offset];
ComputeElem expert_result = arrayConvert<InputElem, ComputeElem>(
expanded_permuted_rows_row_ptr[elem_index]);
thread_output = thread_output + row_scale * (expert_result);
}
OutputElem output_elem =
arrayConvert<ComputeElem, OutputElem>(thread_output);
reduced_row_ptr_v[elem_index] = output_elem;
}
}
template <class T, class OutputType>
void finalizeMoeRoutingKernelLauncher(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const num_rows,
int64_t const cols, int64_t const k, int64_t const* num_valid_ptr,
cudaStream_t stream) {
int64_t const blocks = num_rows;
int64_t const threads = 256;
bool const check_finished = num_valid_ptr != nullptr;
using FuncPtr = decltype(&finalizeMoeRoutingKernel<T, OutputType, false>);
FuncPtr func_map[2] = {&finalizeMoeRoutingKernel<T, OutputType, false>,
&finalizeMoeRoutingKernel<T, OutputType, true>};
auto* const kernel = func_map[check_finished];
kernel<<<blocks, threads, 0, stream>>>(
expanded_permuted_rows, reduced_unpermuted_output, scales,
expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k,
num_valid_ptr);
}
......@@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__
}
}
template <int TPB>
__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
template <int TPB, typename IndType>
__launch_bounds__(TPB) __global__ void moeTopK(
const float* inputs_after_softmax,
const bool* finished,
float* output,
IndType* indices,
int* source_rows,
const int num_experts,
const int k,
const int start_expert,
const int end_expert)
{
using cub_kvp = cub::KeyValuePair<int, float>;
......@@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax
2) This implementation assumes k is small, but will work for any k.
*/
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, typename IndType>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
int* source_rows, const int k, const int start_expert, const int end_expert)
{
// We begin by enforcing compile time assertions and setting up compile time constants.
......@@ -397,8 +405,8 @@ struct TopkConstants
};
} // namespace detail
template <int EXPERTS, int WARPS_PER_TB>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
template <int EXPERTS, int WARPS_PER_TB, typename IndType>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
{
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
......@@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
token_expert_indices, num_tokens, topk, 0, num_experts, \
stream);
template <typename IndType>
void topkGatingSoftmaxKernelLauncher(
const float* gating_output,
float* topk_weights,
int* topk_indicies,
IndType* topk_indicies,
int* token_expert_indices,
float* softmax_workspace,
const int num_tokens,
......@@ -493,14 +502,44 @@ void topk_softmax(
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
if(topk_indices.scalar_type() == at::ScalarType::Int)
{
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
}
else if (topk_indices.scalar_type() == at::ScalarType::UInt32)
{
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<uint32_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
}
else {
assert(topk_indices.scalar_type() == at::ScalarType::Int64);
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int64_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
}
}
......@@ -10,7 +10,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Calculate the result of moe by summing up the partial results
// from all selected experts.
m.def("moe_sum(Tensor! input, Tensor output) -> ()");
m.def("moe_sum(Tensor input, Tensor! output) -> ()");
m.impl("moe_sum", torch::kCUDA, &moe_sum);
// Aligning the number of tokens to be processed by each expert such
......@@ -50,7 +50,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale, Tensor? "
"b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
......@@ -59,8 +60,38 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"int size_m, int size_n, int size_k,"
"bool is_full_k, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
"int b_q_type, SymInt size_m, "
"SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
"topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor");
m.def(
"moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids,"
"Tensor token_expert_indicies, Tensor? expert_map, int n_expert,"
"int n_local_expert,"
"int topk, int? align_block_size,Tensor! permuted_input, Tensor! "
"expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! "
"m_indices)->()");
// conditionally compiled so impl registration is in source file
m.def(
"moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights,"
"Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor "
"expert_first_token_offset, int n_expert, int n_local_expert,int "
"topk, Tensor! hidden_states)->()");
m.def("moe_permute_unpermute_supported() -> bool");
m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported);
// Row shuffle for MoE
m.def(
"shuffle_rows(Tensor input_tensor, Tensor dst2src_map, Tensor! "
"output_tensor) -> ()");
m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows);
#endif
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment