"vllm/model_executor/models/lfm2_moe.py" did not exist on "e97f802b2d74861af77997691a7d1c36498f6dca"
Commit 7a985548 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.9.0' into v0.9.0-ori

parents 45d3785c dc1440cf
#pragma once
#include <cuda_fp8.h>
#define MOE_SWITCH(TYPE, ...) \
at::ScalarType _st = ::detail::scalar_type(TYPE); \
switch (_st) { \
__VA_ARGS__ \
default: \
TORCH_CHECK(false, "[moe permute]data type dispatch fail!") \
}
#define MOE_DISPATCH_CASE(enum_type, ...) \
case enum_type: { \
using scalar_t = ScalarType2CudaType<enum_type>::type; \
__VA_ARGS__(); \
break; \
}
#define MOE_DISPATCH_FLOAT_CASE(...) \
MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
#define MOE_DISPATCH(TYPE, ...) \
MOE_SWITCH(TYPE, MOE_DISPATCH_FLOAT_CASE(__VA_ARGS__))
template <at::ScalarType type>
struct ScalarType2CudaType;
template <>
struct ScalarType2CudaType<at::ScalarType::Float> {
using type = float;
};
template <>
struct ScalarType2CudaType<at::ScalarType::Half> {
using type = half;
};
template <>
struct ScalarType2CudaType<at::ScalarType::BFloat16> {
using type = __nv_bfloat16;
};
// #if __CUDA_ARCH__ >= 890
// fp8
template <>
struct ScalarType2CudaType<at::ScalarType::Float8_e5m2> {
using type = __nv_fp8_e5m2;
};
template <>
struct ScalarType2CudaType<at::ScalarType::Float8_e4m3fn> {
using type = __nv_fp8_e4m3;
};
// #endif
\ No newline at end of file
#include "moe_permute_unpermute_kernel.h"
// CubKeyValueSorter definition begin
CubKeyValueSorter::CubKeyValueSorter()
: num_experts_(0), num_bits_(sizeof(int) * 8) {}
int CubKeyValueSorter::expertsToBits(int num_experts) {
// Max value we represent is V = num_experts + (num_experts - 1) = 2 *
// num_experts - 1 The maximum number of bits is therefore floor(log2(V)) + 1
return static_cast<int>(log2(2 * num_experts - 1)) + 1;
}
CubKeyValueSorter::CubKeyValueSorter(int const num_experts)
: num_experts_(num_experts), num_bits_(expertsToBits(num_experts)) {}
void CubKeyValueSorter::updateNumExperts(int const num_experts) {
num_experts_ = num_experts;
num_bits_ = expertsToBits(num_experts);
}
size_t CubKeyValueSorter::getWorkspaceSize(size_t const num_key_value_pairs,
int const num_experts) {
int num_bits = expertsToBits(num_experts);
size_t required_storage = 0;
int* null_int = nullptr;
cub::DeviceRadixSort::SortPairs(nullptr, required_storage, null_int, null_int,
null_int, null_int, num_key_value_pairs, 0,
num_bits);
// when num_key_value_pairs, num_experts, num_bits, required_storage = 64,
// 4, 3, 0 The required_storage seems to vary between 0 and 1 for the same
// inputs
if (required_storage == 0) {
required_storage = 1;
}
return required_storage;
}
void CubKeyValueSorter::run(void* workspace, size_t const workspace_size,
int const* keys_in, int* keys_out,
int const* values_in, int* values_out,
size_t const num_key_value_pairs,
cudaStream_t stream) {
size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs, num_experts_);
size_t actual_ws_size = workspace_size;
TORCH_CHECK(expected_ws_size <= workspace_size,
"[CubKeyValueSorter::run] The allocated workspace is too small "
"to run this problem.");
cub::DeviceRadixSort::SortPairs(workspace, actual_ws_size, keys_in, keys_out,
values_in, values_out, num_key_value_pairs, 0,
num_bits_, stream);
}
// CubKeyValueSorter definition end
static inline size_t pad_to_multiple_of_16(size_t const& input) {
static constexpr int ALIGNMENT = 16;
return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
}
template <class T>
__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
int64_t const arr_length,
T const target) {
int64_t low = 0, high = arr_length - 1, target_location = -1;
while (low <= high) {
int64_t mid = (low + high) / 2;
if (sorted_indices[mid] >= target) {
high = mid - 1;
} else {
low = mid + 1;
target_location = mid;
}
}
return target_location + 1;
}
// Calculates the start offset of the tokens for a given expert. The last
// element is the total number of valid tokens
__global__ void computeExpertFirstTokenOffsetKernel(
int const* sorted_experts, int64_t const sorted_experts_len,
int const num_experts, int64_t* expert_first_token_offset) {
// First, compute the global tid. We only need 1 thread per expert.
int const expert = blockIdx.x * blockDim.x + threadIdx.x;
// Note that expert goes [0, num_experts] (inclusive) because we want a count
// for the total number of active tokens at the end of the scan.
if (expert >= num_experts + 1) {
return;
}
expert_first_token_offset[expert] =
findTotalEltsLessThanTarget(sorted_experts, sorted_experts_len, expert);
}
void computeExpertFirstTokenOffset(int const* sorted_indices,
int const total_indices,
int const num_experts,
int64_t* expert_first_token_offset,
cudaStream_t stream) {
int const num_entries = num_experts + 1;
int const threads = std::min(1024, num_entries);
int const blocks = (num_entries + threads - 1) / threads;
computeExpertFirstTokenOffsetKernel<<<blocks, threads, 0, stream>>>(
sorted_indices, total_indices, num_experts, expert_first_token_offset);
}
void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
int* permuted_experts, int* permuted_rows,
int64_t* expert_first_token_offset, int num_rows,
int num_experts, int num_experts_per_node, int k,
CubKeyValueSorter& sorter, void* sorter_ws,
cudaStream_t stream) {
int64_t const expanded_num_rows = static_cast<int64_t>(k) * num_rows;
// We need to use the full num_experts because that is the sentinel value used
// by topk for disabled experts
sorter.updateNumExperts(num_experts);
size_t const sorter_ws_size_bytes = pad_to_multiple_of_16(
sorter.getWorkspaceSize(expanded_num_rows, num_experts));
sorter.run((void*)sorter_ws, sorter_ws_size_bytes, expert_for_source_row,
permuted_experts, source_rows, permuted_rows, expanded_num_rows,
stream);
computeExpertFirstTokenOffset(permuted_experts, expanded_num_rows,
num_experts_per_node, expert_first_token_offset,
stream);
}
__global__ void preprocessTopkIdKernel(int* topk_id_ptr, int size,
const int* expert_map_ptr,
int num_experts) {
auto tidx = threadIdx.x;
auto bidx = blockIdx.x;
auto lidx = tidx & 31;
auto widx = tidx >> 5;
auto warp_count = (blockDim.x + 31) >> 5;
auto offset = bidx * blockDim.x;
auto bound = min(offset + blockDim.x, size);
extern __shared__ int smem_expert_map[];
// store expert_map in smem
for (int i = tidx; i < num_experts; i += blockDim.x) {
smem_expert_map[i] = expert_map_ptr[i];
}
__syncthreads();
// query global expert id in expert map.
// if global expert id = -1 in exert map, plus n_expert
// else set global expert id = exert map[global expert id]
if (offset + tidx < bound) {
auto topk_id = topk_id_ptr[offset + tidx];
auto local_expert_idx = smem_expert_map[topk_id];
if (local_expert_idx == -1) {
topk_id += num_experts;
} else {
topk_id = local_expert_idx;
}
__syncwarp();
topk_id_ptr[offset + tidx] = topk_id;
}
}
void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
const int* expert_map_ptr, int num_experts,
cudaStream_t stream) {
int block = std::min(size, 1024);
int grid = (size + block - 1) / block;
int smem_size = (num_experts) * sizeof(int);
preprocessTopkIdKernel<<<grid, block, smem_size, stream>>>(
topk_id_ptr, size, expert_map_ptr, num_experts);
}
template <bool ALIGN_BLOCK_SIZE>
__global__ void getMIndicesKernel(int64_t* expert_first_token_offset,
int64_t* align_expert_first_token_offset,
int* m_indices, const int num_local_expert,
const int align_block_size) {
int eidx = blockIdx.x;
int tidx = threadIdx.x;
extern __shared__ int64_t smem_expert_first_token_offset[];
for (int i = tidx; i <= num_local_expert; i += blockDim.x) {
smem_expert_first_token_offset[tidx] = __ldg(expert_first_token_offset + i);
}
__syncthreads();
auto last_token_offset = smem_expert_first_token_offset[eidx + 1];
auto first_token_offset = smem_expert_first_token_offset[eidx];
int n_token_in_expert = last_token_offset - first_token_offset;
if constexpr (ALIGN_BLOCK_SIZE) {
n_token_in_expert = (n_token_in_expert + align_block_size - 1) /
align_block_size * align_block_size;
// round up to ALIGN_BLOCK_SIZE
int64_t accumulate_align_offset = 0;
for (int i = 1; i <= eidx + 1; i++) {
int n_token = smem_expert_first_token_offset[i] -
smem_expert_first_token_offset[i - 1];
accumulate_align_offset =
accumulate_align_offset + (n_token + align_block_size - 1) /
align_block_size * align_block_size;
if (i == eidx) {
first_token_offset = accumulate_align_offset;
}
// last block store align_expert_first_token_offset
if (eidx == num_local_expert - 1 && threadIdx.x == 0) {
align_expert_first_token_offset[i] = accumulate_align_offset;
}
}
}
for (int idx = tidx; idx < n_token_in_expert; idx += blockDim.x) {
// update m_indice with expert id
m_indices[first_token_offset + idx] = eidx;
}
}
void getMIndices(int64_t* expert_first_token_offset,
int64_t* align_expert_first_token_offset, int* m_indices,
int num_local_expert, const int align_block_size,
cudaStream_t stream) {
int block = 256;
int grid = num_local_expert;
int smem_size = sizeof(int64_t) * (num_local_expert + 1);
if (align_block_size == -1) {
getMIndicesKernel<false><<<grid, block, smem_size, stream>>>(
expert_first_token_offset, align_expert_first_token_offset, m_indices,
num_local_expert, align_block_size);
} else {
getMIndicesKernel<true><<<grid, block, smem_size, stream>>>(
expert_first_token_offset, align_expert_first_token_offset, m_indices,
num_local_expert, align_block_size);
}
}
\ No newline at end of file
#pragma once
// reference from tensorrt_llm moe kernel implementation archive in
// https://github.com/BBuf/tensorrt-llm-moe/tree/master
#include <c10/core/ScalarType.h>
#include <torch/all.h>
#include "dispatch.h"
#include <cub/cub.cuh>
#include <cub/device/device_radix_sort.cuh>
#include <cub/util_type.cuh>
#include "cutlass/numeric_size.h"
#include "cutlass/array.h"
template <typename T>
inline T* get_ptr(torch::Tensor& t) {
return reinterpret_cast<T*>(t.data_ptr());
}
template <typename T>
inline const T* get_ptr(const torch::Tensor& t) {
return reinterpret_cast<const T*>(t.data_ptr());
}
class CubKeyValueSorter {
public:
CubKeyValueSorter();
CubKeyValueSorter(int const num_experts);
void updateNumExperts(int const num_experts);
static size_t getWorkspaceSize(size_t const num_key_value_pairs,
int const num_experts);
void run(void* workspace, size_t const workspace_size, int const* keys_in,
int* keys_out, int const* values_in, int* values_out,
size_t const num_key_value_pairs, cudaStream_t stream);
private:
static int expertsToBits(int experts);
int num_experts_;
int num_bits_;
};
void computeExpertFirstTokenOffset(int const* sorted_indices,
int const total_indices,
int const num_experts,
int64_t* expert_first_token_offset,
cudaStream_t stream);
void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
int* permuted_experts, int* permuted_rows,
int64_t* expert_first_token_offset, int num_rows,
int num_experts, int num_experts_per_node, int k,
CubKeyValueSorter& sorter, void* sorter_ws,
cudaStream_t stream);
template <typename T>
void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output,
const float* unpermuted_scales, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row,
int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int num_local_experts, const int& align_block_size, cudaStream_t stream);
// Final kernel to unpermute and scale
// This kernel unpermutes the original data, does the k-way reduction and
// performs the final skip connection.
template <typename T, typename OutputType, bool CHECK_SKIPPED>
__global__ void finalizeMoeRoutingKernel(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const orig_cols, int64_t const k,
int64_t const* num_valid_ptr);
template <class T, class OutputType>
void finalizeMoeRoutingKernelLauncher(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const num_rows,
int64_t const cols, int64_t const k, int64_t const* num_valid_ptr,
cudaStream_t stream);
void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
const int* expert_map_ptr, int num_experts,
cudaStream_t stream);
void getMIndices(int64_t* expert_first_token_offset,
int64_t* align_expert_first_token_offset, int* m_indices,
int num_local_expert, const int align_block_size,
cudaStream_t stream);
#include "moe_permute_unpermute_kernel.inl"
#pragma once
template <typename T, bool CHECK_SKIPPED, bool ALIGN_BLOCK_SIZE>
__global__ void expandInputRowsKernel(
T const* unpermuted_input, T* permuted_output,
const float* unpermuted_scales, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row,
int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_dest_rows, int64_t const cols, int64_t k,
int num_local_experts, int align_block_size) {
// Reverse permutation map.
// I do this so that later, we can use the source -> dest map to do the k-way
// reduction and unpermuting. I need the reverse map for that reduction to
// allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1
// thread block will be responsible for all k summations.
int64_t expanded_dest_row = blockIdx.x;
int64_t const expanded_source_row =
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
int expert_id = sorted_experts[expanded_dest_row];
extern __shared__ int64_t smem_expert_first_token_offset[];
int64_t align_expanded_row_accumulate = 0;
if constexpr (ALIGN_BLOCK_SIZE) {
// load g2s
for (int idx = threadIdx.x; idx < num_local_experts + 1;
idx += blockDim.x) {
smem_expert_first_token_offset[idx] =
__ldg(expert_first_token_offset + idx);
}
__syncthreads();
int lane_idx = threadIdx.x & 31;
if (lane_idx == 0) {
// set token_offset_in_expert = 0 if this expert is not local expert
int token_offset_in_expert =
expert_id >= num_local_experts
? 0
: expanded_dest_row - smem_expert_first_token_offset[expert_id];
int64_t accumulate_align_offset = 0;
#pragma unroll 1
for (int eidx = 1; eidx <= min(expert_id, num_local_experts); eidx++) {
auto n_token_in_expert = smem_expert_first_token_offset[eidx] -
smem_expert_first_token_offset[eidx - 1];
accumulate_align_offset += (n_token_in_expert + align_block_size - 1) /
align_block_size * align_block_size;
}
expanded_dest_row = accumulate_align_offset + token_offset_in_expert;
}
// lane0 shuffle broadcast align_expanded_dest_row
expanded_dest_row = __shfl_sync(0xffffffff, expanded_dest_row, 0);
}
if (threadIdx.x == 0) {
assert(expanded_dest_row <= INT32_MAX);
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
static_cast<int>(expanded_dest_row);
}
if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) {
// Load 128-bits per thread
constexpr int64_t ELEM_PER_THREAD = 128 / cutlass::sizeof_bits<T>::value;
using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;
// Duplicate and permute rows
int64_t const source_k_rank = expanded_source_row / num_rows;
int64_t const source_row = expanded_source_row % num_rows;
auto const* source_row_ptr =
reinterpret_cast<DataElem const*>(unpermuted_input + source_row * cols);
auto* dest_row_ptr =
reinterpret_cast<DataElem*>(permuted_output + expanded_dest_row * cols);
int64_t const start_offset = threadIdx.x;
int64_t const stride = blockDim.x;
int64_t const num_elems_in_col = cols / ELEM_PER_THREAD;
for (int elem_index = start_offset; elem_index < num_elems_in_col;
elem_index += stride) {
dest_row_ptr[elem_index] = source_row_ptr[elem_index];
}
}
}
template <typename T>
void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output,
const float* unpermuted_scales, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row,
int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int num_local_experts, const int& align_block_size, cudaStream_t stream) {
int64_t const blocks = num_rows * k;
int64_t const threads = 256;
using FuncPtr = decltype(&expandInputRowsKernel<T, true, true>);
FuncPtr func_map[2][2] = {
{&expandInputRowsKernel<T, false, false>,
&expandInputRowsKernel<T, false, true>},
{&expandInputRowsKernel<T, true, false>,
&expandInputRowsKernel<T, true, true>},
};
bool is_check_skip = num_valid_tokens_ptr != nullptr;
bool is_align_block_size = align_block_size != -1;
auto func = func_map[is_check_skip][is_align_block_size];
int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1);
func<<<blocks, threads, smem_size, stream>>>(
unpermuted_input, permuted_output, unpermuted_scales, sorted_experts,
expanded_dest_row_to_expanded_source_row,
expanded_source_row_to_expanded_dest_row, expert_first_token_offset,
num_rows, num_valid_tokens_ptr, cols, k, num_local_experts,
align_block_size);
}
template <class T, class U>
__host__ __device__ constexpr static U arrayConvert(T const& input) {
using Type = typename U::Element;
static_assert(T::kElements == U::kElements);
U u;
#pragma unroll
for (int i = 0; i < U::kElements; i++) {
u[i] = static_cast<Type>(input[i]);
}
return u;
}
template <typename T, typename OutputType, bool CHECK_SKIPPED>
__global__ void finalizeMoeRoutingKernel(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const orig_cols, int64_t const k,
int64_t const* num_valid_ptr) {
assert(orig_cols % 4 == 0);
int64_t const original_row = blockIdx.x;
int64_t const num_rows = gridDim.x;
auto const offset = original_row * orig_cols;
OutputType* reduced_row_ptr = reduced_unpermuted_output + offset;
int64_t const num_valid = *num_valid_ptr;
// Load 128-bits per thread, according to the smallest data type we read/write
constexpr int64_t FINALIZE_ELEM_PER_THREAD =
128 / std::min(cutlass::sizeof_bits<OutputType>::value,
cutlass::sizeof_bits<T>::value);
int64_t const start_offset = threadIdx.x;
int64_t const stride = blockDim.x;
int64_t const num_elems_in_col = orig_cols / FINALIZE_ELEM_PER_THREAD;
using InputElem = cutlass::Array<T, FINALIZE_ELEM_PER_THREAD>;
using OutputElem = cutlass::Array<OutputType, FINALIZE_ELEM_PER_THREAD>;
using ComputeElem = cutlass::Array<float, FINALIZE_ELEM_PER_THREAD>;
auto const* expanded_permuted_rows_v =
reinterpret_cast<InputElem const*>(expanded_permuted_rows);
auto* reduced_row_ptr_v = reinterpret_cast<OutputElem*>(reduced_row_ptr);
#pragma unroll
for (int elem_index = start_offset; elem_index < num_elems_in_col;
elem_index += stride) {
ComputeElem thread_output;
thread_output.fill(0);
float row_rescale{0.f};
for (int k_idx = 0; k_idx < k; ++k_idx) {
int64_t const expanded_original_row = original_row + k_idx * num_rows;
int64_t const expanded_permuted_row =
expanded_source_row_to_expanded_dest_row[expanded_original_row];
int64_t const k_offset = original_row * k + k_idx;
float const row_scale = scales[k_offset];
// Check after row_rescale has accumulated
if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) {
continue;
}
auto const* expanded_permuted_rows_row_ptr =
expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col;
int64_t const expert_idx = expert_for_source_row[k_offset];
ComputeElem expert_result = arrayConvert<InputElem, ComputeElem>(
expanded_permuted_rows_row_ptr[elem_index]);
thread_output = thread_output + row_scale * (expert_result);
}
OutputElem output_elem =
arrayConvert<ComputeElem, OutputElem>(thread_output);
reduced_row_ptr_v[elem_index] = output_elem;
}
}
template <class T, class OutputType>
void finalizeMoeRoutingKernelLauncher(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const num_rows,
int64_t const cols, int64_t const k, int64_t const* num_valid_ptr,
cudaStream_t stream) {
int64_t const blocks = num_rows;
int64_t const threads = 256;
bool const check_finished = num_valid_ptr != nullptr;
using FuncPtr = decltype(&finalizeMoeRoutingKernel<T, OutputType, false>);
FuncPtr func_map[2] = {&finalizeMoeRoutingKernel<T, OutputType, false>,
&finalizeMoeRoutingKernel<T, OutputType, true>};
auto* const kernel = func_map[check_finished];
kernel<<<blocks, threads, 0, stream>>>(
expanded_permuted_rows, reduced_unpermuted_output, scales,
expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k,
num_valid_ptr);
}
...@@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__ ...@@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__
} }
} }
template <int TPB> template <int TPB, typename IndType>
__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output, __launch_bounds__(TPB) __global__ void moeTopK(
int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert) const float* inputs_after_softmax,
const bool* finished,
float* output,
IndType* indices,
int* source_rows,
const int num_experts,
const int k,
const int start_expert,
const int end_expert)
{ {
using cub_kvp = cub::KeyValuePair<int, float>; using cub_kvp = cub::KeyValuePair<int, float>;
...@@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax ...@@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax
2) This implementation assumes k is small, but will work for any k. 2) This implementation assumes k is small, but will work for any k.
*/ */
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG> template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, typename IndType>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices, void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
int* source_rows, const int k, const int start_expert, const int end_expert) int* source_rows, const int k, const int start_expert, const int end_expert)
{ {
// We begin by enforcing compile time assertions and setting up compile time constants. // We begin by enforcing compile time assertions and setting up compile time constants.
...@@ -397,8 +405,8 @@ struct TopkConstants ...@@ -397,8 +405,8 @@ struct TopkConstants
}; };
} // namespace detail } // namespace detail
template <int EXPERTS, int WARPS_PER_TB> template <int EXPERTS, int WARPS_PER_TB, typename IndType>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices, void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream) int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
{ {
static constexpr std::size_t MAX_BYTES_PER_LDG = 16; static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
...@@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f ...@@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
token_expert_indices, num_tokens, topk, 0, num_experts, \ token_expert_indices, num_tokens, topk, 0, num_experts, \
stream); stream);
template <typename IndType>
void topkGatingSoftmaxKernelLauncher( void topkGatingSoftmaxKernelLauncher(
const float* gating_output, const float* gating_output,
float* topk_weights, float* topk_weights,
int* topk_indicies, IndType* topk_indicies,
int* token_expert_indices, int* token_expert_indices,
float* softmax_workspace, float* softmax_workspace,
const int num_tokens, const int num_tokens,
...@@ -493,14 +502,32 @@ void topk_softmax( ...@@ -493,14 +502,32 @@ void topk_softmax(
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output)); const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options()); torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(), if(topk_indices.scalar_type() == at::ScalarType::Int)
topk_weights.data_ptr<float>(), {
topk_indices.data_ptr<int>(), vllm::moe::topkGatingSoftmaxKernelLauncher(
token_expert_indices.data_ptr<int>(), gating_output.data_ptr<float>(),
softmax_workspace.data_ptr<float>(), topk_weights.data_ptr<float>(),
num_tokens, topk_indices.data_ptr<int>(),
num_experts, token_expert_indices.data_ptr<int>(),
topk, softmax_workspace.data_ptr<float>(),
stream); num_tokens,
num_experts,
topk,
stream);
}
else
{
assert(topk_indices.scalar_type() == at::ScalarType::UInt32);
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<uint32_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
}
} }
...@@ -44,7 +44,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -44,7 +44,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def( m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none," "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none," "Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale, Tensor? "
"b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace," "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids," "Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded," "Tensor! expert_ids, Tensor! num_tokens_past_padded,"
...@@ -53,7 +54,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -53,7 +54,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"int size_m, int size_n, int size_k," "int size_m, int size_n, int size_k,"
"bool is_full_k, bool use_atomic_add," "bool is_full_k, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor"); "bool use_fp32_reduce, bool is_zp_float) -> Tensor");
m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
"int b_q_type, SymInt size_m, "
"SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
"topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor");
m.def(
"moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids,"
"Tensor token_expert_indicies, Tensor? expert_map, int n_expert,"
"int n_local_expert,"
"int topk, int? align_block_size,Tensor! permuted_input, Tensor! "
"expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! "
"m_indices)->()");
m.def(
"moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights,"
"Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor "
"expert_first_token_offset, int n_expert, int n_local_expert,int "
"topk, Tensor! hidden_states)->()");
// conditionally compiled so impl registration is in source file // conditionally compiled so impl registration is in source file
#endif #endif
......
...@@ -59,6 +59,31 @@ void merge_attn_states(torch::Tensor& output, ...@@ -59,6 +59,31 @@ void merge_attn_states(torch::Tensor& output,
const torch::Tensor& prefix_lse, const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output, const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse); const torch::Tensor& suffix_lse);
void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int64_t context_size, int64_t block_size_M, int64_t block_size_N,
bool causal);
void convert_vertical_slash_indexes_mergehead(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
torch::Tensor vertical_indices_count, // [N_HEADS, ]
torch::Tensor slash_indices_count, int64_t context_size,
int64_t block_size_M, int64_t block_size_N, bool causal);
#endif #endif
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
...@@ -86,17 +111,20 @@ void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, ...@@ -86,17 +111,20 @@ void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
// std::optional<torch::Tensor> residual); // std::optional<torch::Tensor> residual);
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int64_t head_size, std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox); torch::Tensor& cos_sin_cache, bool is_neox);
void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int64_t head_size, std::optional<torch::Tensor> key,
torch::Tensor& cos_sin_cache, bool is_neox, int64_t head_size, torch::Tensor& cos_sin_cache,
int64_t rot_dim, bool is_neox, int64_t rot_dim,
torch::Tensor& cos_sin_cache_offsets); torch::Tensor& cos_sin_cache_offsets);
void silu_and_mul(torch::Tensor& out, torch::Tensor& input); void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale);
void mul_and_silu(torch::Tensor& out, torch::Tensor& input); void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
...@@ -177,6 +205,10 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W, ...@@ -177,6 +205,10 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
torch::Tensor num_tokens_post_padded, int64_t type, torch::Tensor num_tokens_post_padded, int64_t type,
int64_t row, int64_t top_k, int64_t tokens); int64_t row, int64_t top_k, int64_t tokens);
torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W,
torch::Tensor topk_ids, int64_t top_k,
int64_t type, int64_t row, int64_t tokens);
int64_t ggml_moe_get_block_size(int64_t type); int64_t ggml_moe_get_block_size(int64_t type);
#ifndef USE_ROCM #ifndef USE_ROCM
...@@ -203,6 +235,12 @@ void cutlass_moe_mm( ...@@ -203,6 +235,12 @@ void cutlass_moe_mm(
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides); torch::Tensor const& b_strides, torch::Tensor const& c_strides);
void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);
void get_cutlass_moe_mm_data( void get_cutlass_moe_mm_data(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
...@@ -230,6 +268,12 @@ std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a); ...@@ -230,6 +268,12 @@ std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
torch::Tensor& output_scale, torch::Tensor& output_scale,
torch::Tensor const& input_scale); torch::Tensor const& input_scale);
void scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif #endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
......
...@@ -38,12 +38,14 @@ inline __device__ void apply_rotary_embedding( ...@@ -38,12 +38,14 @@ inline __device__ void apply_rotary_embedding(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads, // head_size] or [num_tokens, num_heads,
// head_size] // head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, scalar_t* __restrict__ key, // nullptr or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads, // head_size] or [num_tokens, num_kv_heads,
// head_size] // head_size]
const scalar_t* cache_ptr, const int head_size, const int num_heads, const scalar_t* cache_ptr, const int head_size, const int num_heads,
const int num_kv_heads, const int rot_dim, const int token_idx, const int num_kv_heads, const int rot_dim, const int token_idx,
const int64_t query_stride, const int64_t key_stride) { const int64_t query_stride, const int64_t key_stride,
const int64_t head_stride) {
const int embed_dim = rot_dim / 2; const int embed_dim = rot_dim / 2;
const scalar_t* cos_ptr = cache_ptr; const scalar_t* cos_ptr = cache_ptr;
const scalar_t* sin_ptr = cache_ptr + embed_dim; const scalar_t* sin_ptr = cache_ptr + embed_dim;
...@@ -51,19 +53,23 @@ inline __device__ void apply_rotary_embedding( ...@@ -51,19 +53,23 @@ inline __device__ void apply_rotary_embedding(
const int nq = num_heads * embed_dim; const int nq = num_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) { for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / embed_dim; const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * query_stride + head_idx * head_size; const int64_t token_head =
token_idx * query_stride + head_idx * head_stride;
const int rot_offset = i % embed_dim; const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>( apply_token_rotary_embedding<scalar_t, IS_NEOX>(
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
} }
const int nk = num_kv_heads * embed_dim; if (key != nullptr) {
for (int i = threadIdx.x; i < nk; i += blockDim.x) { const int nk = num_kv_heads * embed_dim;
const int head_idx = i / embed_dim; for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int head_idx = i / embed_dim;
const int rot_offset = i % embed_dim; const int64_t token_head =
apply_token_rotary_embedding<scalar_t, IS_NEOX>( token_idx * key_stride + head_idx * head_stride;
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
}
} }
} }
...@@ -74,13 +80,15 @@ __global__ void rotary_embedding_kernel( ...@@ -74,13 +80,15 @@ __global__ void rotary_embedding_kernel(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads, // head_size] or [num_tokens, num_heads,
// head_size] // head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, scalar_t* __restrict__ key, // nullptr or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads, // head_size] or [num_tokens, num_kv_heads,
// head_size] // head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
// 2] // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size) { const int64_t head_stride, const int num_heads, const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token. // Each thread block is responsible for one token.
const int token_idx = blockIdx.x; const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
...@@ -88,7 +96,7 @@ __global__ void rotary_embedding_kernel( ...@@ -88,7 +96,7 @@ __global__ void rotary_embedding_kernel(
apply_rotary_embedding<scalar_t, IS_NEOX>( apply_rotary_embedding<scalar_t, IS_NEOX>(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
token_idx, query_stride, key_stride); token_idx, query_stride, key_stride, head_stride);
} }
template <typename scalar_t, bool IS_NEOX> template <typename scalar_t, bool IS_NEOX>
...@@ -98,15 +106,16 @@ __global__ void batched_rotary_embedding_kernel( ...@@ -98,15 +106,16 @@ __global__ void batched_rotary_embedding_kernel(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads, // head_size] or [num_tokens, num_heads,
// head_size] // head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, scalar_t* __restrict__ key, // nullptr or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads, // head_size] or [num_tokens, num_kv_heads,
// head_size] // head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
// 2] // 2]
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len]
// or [num_tokens]
const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size) { const int64_t head_stride, const int num_heads, const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token. // Each thread block is responsible for one token.
const int token_idx = blockIdx.x; const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
...@@ -116,7 +125,7 @@ __global__ void batched_rotary_embedding_kernel( ...@@ -116,7 +125,7 @@ __global__ void batched_rotary_embedding_kernel(
apply_rotary_embedding<scalar_t, IS_NEOX>( apply_rotary_embedding<scalar_t, IS_NEOX>(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
token_idx, query_stride, key_stride); token_idx, query_stride, key_stride, head_stride);
} }
} // namespace vllm } // namespace vllm
...@@ -127,10 +136,12 @@ void rotary_embedding( ...@@ -127,10 +136,12 @@ void rotary_embedding(
// [num_tokens, num_heads * head_size] or // [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or // [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size] // [num_tokens, num_heads, head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or std::optional<torch::Tensor> key,
// [num_tokens, num_kv_heads * head_size] or // null or
// [batch_size, seq_len, num_heads, head_size] or // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_heads, head_size] // [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t head_size, int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim] torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) { bool is_neox) {
...@@ -138,40 +149,46 @@ void rotary_embedding( ...@@ -138,40 +149,46 @@ void rotary_embedding(
int64_t num_tokens = positions.numel(); int64_t num_tokens = positions.numel();
int positions_ndim = positions.dim(); int positions_ndim = positions.dim();
// Make sure num_tokens dim is consistent across positions, query, and key. // Make sure num_tokens dim is consistent across positions, query, and key
TORCH_CHECK( TORCH_CHECK(
positions_ndim == 1 || positions_ndim == 2, positions_ndim == 1 || positions_ndim == 2,
"positions must have shape [num_tokens] or [batch_size, seq_len]"); "positions must have shape [num_tokens] or [batch_size, seq_len]");
if (positions_ndim == 1) { if (positions_ndim == 1) {
TORCH_CHECK( TORCH_CHECK(query.size(0) == positions.size(0) &&
query.size(0) == positions.size(0) && key.size(0) == positions.size(0), (!key.has_value() || key->size(0) == positions.size(0)),
"query, key and positions must have the same number of tokens"); "query, key and positions must have the same number of tokens");
} }
if (positions_ndim == 2) { if (positions_ndim == 2) {
TORCH_CHECK( TORCH_CHECK(
query.size(0) == positions.size(0) && query.size(0) == positions.size(0) &&
key.size(0) == positions.size(0) && (!key.has_value() || key->size(0) == positions.size(0)) &&
query.size(1) == positions.size(1) && query.size(1) == positions.size(1) &&
key.size(1) == positions.size(1), (!key.has_value() || key->size(1) == positions.size(1)),
"query, key and positions must have the same batch_size and seq_len"); "query, key and positions must have the same batch_size and seq_len");
} }
// Make sure head_size is valid for query and key // Make sure head_size is valid for query and key
// hidden_size = num_heads * head_size // hidden_size = num_heads * head_size
int query_hidden_size = query.numel() / num_tokens; int query_hidden_size = query.numel() / num_tokens;
int key_hidden_size = key.numel() / num_tokens; int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
TORCH_CHECK(query_hidden_size % head_size == 0); TORCH_CHECK(query_hidden_size % head_size == 0);
TORCH_CHECK(key_hidden_size % head_size == 0); TORCH_CHECK(key_hidden_size % head_size == 0);
// Make sure query and key have consistent number of heads // Make sure query and key have consistent number of heads
int num_heads = query_hidden_size / head_size; int num_heads = query_hidden_size / head_size;
int num_kv_heads = key_hidden_size / head_size; int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
TORCH_CHECK(num_heads % num_kv_heads == 0); TORCH_CHECK(num_heads % num_kv_heads == 0);
int rot_dim = cos_sin_cache.size(1); int rot_dim = cos_sin_cache.size(1);
int seq_dim_idx = positions_ndim - 1; int seq_dim_idx = positions_ndim - 1;
int64_t query_stride = query.stride(seq_dim_idx); int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.stride(seq_dim_idx); int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
// Determine head stride: for [*, heads, head_size] use stride of last dim;
// for flat [*, heads*head_size], heads blocks are contiguous of size
// head_size
int query_ndim = query.dim();
int64_t head_stride =
(query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512)); dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
...@@ -181,15 +198,16 @@ void rotary_embedding( ...@@ -181,15 +198,16 @@ void rotary_embedding(
if (is_neox) { if (is_neox) {
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>( vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim, key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
query_stride, key_stride, num_heads, num_kv_heads, head_size); cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride,
head_stride, num_heads, num_kv_heads, head_size);
} else { } else {
vllm::rotary_embedding_kernel<scalar_t, false> vllm::rotary_embedding_kernel<scalar_t, false>
<<<grid, block, 0, stream>>>( <<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
rot_dim, query_stride, key_stride, num_heads, num_kv_heads, cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
head_size); key_stride, head_stride, num_heads, num_kv_heads, head_size);
} }
}); });
} }
...@@ -204,10 +222,12 @@ void batched_rotary_embedding( ...@@ -204,10 +222,12 @@ void batched_rotary_embedding(
// [num_tokens, num_heads * head_size] or // [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or // [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size] // [num_tokens, num_heads, head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or std::optional<torch::Tensor>
// [num_tokens, num_kv_heads * head_size] or key, // null or
// [batch_size, seq_len, num_heads, head_size] or // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_heads, head_size] // [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t head_size, int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim] torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox, int64_t rot_dim, bool is_neox, int64_t rot_dim,
...@@ -221,38 +241,44 @@ void batched_rotary_embedding( ...@@ -221,38 +241,44 @@ void batched_rotary_embedding(
"cos_sin_cache_offsets"); "cos_sin_cache_offsets");
int positions_ndim = positions.dim(); int positions_ndim = positions.dim();
// Make sure num_tokens dim is consistent across positions, query, and key. // Make sure num_tokens dim is consistent across positions, query, and key
TORCH_CHECK( TORCH_CHECK(
positions_ndim == 1 || positions_ndim == 2, positions_ndim == 1 || positions_ndim == 2,
"positions must have shape [num_tokens] or [batch_size, seq_len]"); "positions must have shape [num_tokens] or [batch_size, seq_len]");
if (positions_ndim == 1) { if (positions_ndim == 1) {
TORCH_CHECK( TORCH_CHECK(query.size(0) == positions.size(0) &&
query.size(0) == positions.size(0) && key.size(0) == positions.size(0), (!key.has_value() || key->size(0) == positions.size(0)),
"query, key and positions must have the same number of tokens"); "query, key and positions must have the same number of tokens");
} }
if (positions_ndim == 2) { if (positions_ndim == 2) {
TORCH_CHECK( TORCH_CHECK(
query.size(0) == positions.size(0) && query.size(0) == positions.size(0) &&
key.size(0) == positions.size(0) && (!key.has_value() || key->size(0) == positions.size(0)) &&
query.size(1) == positions.size(1) && query.size(1) == positions.size(1) &&
key.size(1) == positions.size(1), (!key.has_value() || key->size(1) == positions.size(1)),
"query, key and positions must have the same batch_size and seq_len"); "query, key and positions must have the same batch_size and seq_len");
} }
// Make sure head_size is valid for query and key // Make sure head_size is valid for query and key
int query_hidden_size = query.numel() / num_tokens; int query_hidden_size = query.numel() / num_tokens;
int key_hidden_size = key.numel() / num_tokens; int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
TORCH_CHECK(query_hidden_size % head_size == 0); TORCH_CHECK(query_hidden_size % head_size == 0);
TORCH_CHECK(key_hidden_size % head_size == 0); TORCH_CHECK(key_hidden_size % head_size == 0);
// Make sure query and key have concistent number of heads // Make sure query and key have concistent number of heads
int num_heads = query_hidden_size / head_size; int num_heads = query_hidden_size / head_size;
int num_kv_heads = key_hidden_size / head_size; int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
TORCH_CHECK(num_heads % num_kv_heads == 0); TORCH_CHECK(num_heads % num_kv_heads == 0);
int seq_dim_idx = positions_ndim - 1; int seq_dim_idx = positions_ndim - 1;
int64_t query_stride = query.stride(seq_dim_idx); int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.stride(seq_dim_idx); int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
// Determine head stride: for [*, heads, head_size] use stride of last dim;
// for flat [*, heads*head_size], heads blocks are contiguous of size
// head_size
int query_ndim = query.dim();
int64_t head_stride =
(query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512)); dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
...@@ -263,16 +289,18 @@ void batched_rotary_embedding( ...@@ -263,16 +289,18 @@ void batched_rotary_embedding(
vllm::batched_rotary_embedding_kernel<scalar_t, true> vllm::batched_rotary_embedding_kernel<scalar_t, true>
<<<grid, block, 0, stream>>>( <<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride, cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
key_stride, num_heads, num_kv_heads, head_size); key_stride, head_stride, num_heads, num_kv_heads, head_size);
} else { } else {
vllm::batched_rotary_embedding_kernel<scalar_t, false> vllm::batched_rotary_embedding_kernel<scalar_t, false>
<<<grid, block, 0, stream>>>( <<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride, cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
key_stride, num_heads, num_kv_heads, head_size); key_stride, head_stride, num_heads, num_kv_heads, head_size);
} }
}); });
} }
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <cmath>
#include "core/math.hpp"
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/fp8/common.cuh"
namespace vllm {
template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x)
return (T)(((float)x) / (1.0f + expf((float)-x)));
}
// Activation and gating kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
typename fp8_type>
__global__ void act_and_mul_quant_kernel(
fp8_type* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const float* scale, const int d) {
const int32_t blocks_per_token = gridDim.y;
const int32_t elems_per_128bit_load = (128 / 8) / sizeof(scalar_t);
// We don't expect the hidden dimension to exceed 32 bits so int32 should
// be safe here.
const int32_t tgt_elems_per_block = div_ceil(d, blocks_per_token);
const int32_t elems_per_block =
round_to_next_multiple_of(tgt_elems_per_block, elems_per_128bit_load);
const int32_t block_start = blockIdx.y * elems_per_block;
int32_t block_end = block_start + elems_per_block;
block_end = block_end > d ? d : block_end;
// token_idx is 64 bit to prevent 32 bit overflow when the number of tokens
// is very large
const int64_t token_idx = blockIdx.x;
const scalar_t* __restrict__ x_ptr = input + token_idx * 2 * d;
const scalar_t* __restrict__ y_ptr = input + token_idx * 2 * d + d;
fp8_type* __restrict__ out_ptr = out + token_idx * d;
// 128-bit vectorized code
const int32_t vec_loop_end =
round_to_previous_multiple_of(elems_per_128bit_load, block_end);
const int32_t vec_end_idx = vec_loop_end / elems_per_128bit_load;
const int32_t vec_start_idx = block_start / elems_per_128bit_load;
const int4* __restrict__ x_128bit_ptr = reinterpret_cast<const int4*>(x_ptr);
const int4* __restrict__ y_128bit_ptr = reinterpret_cast<const int4*>(y_ptr);
int2* __restrict__ out_128bit_ptr = reinterpret_cast<int2*>(out_ptr);
float inverted_scale = 1 / *scale;
#pragma unroll
for (int32_t vec_idx = vec_start_idx + threadIdx.x; vec_idx < vec_end_idx;
vec_idx += blockDim.x) {
const int4 x_128bit = VLLM_LDG(&x_128bit_ptr[vec_idx]);
const int4 y_128bit = VLLM_LDG(&y_128bit_ptr[vec_idx]);
using scalar_128bit_vec_t = std::array<scalar_t, elems_per_128bit_load>;
using scalar_64bit_vec_t = std::array<fp8_type, elems_per_128bit_load>;
scalar_64bit_vec_t out_vec;
const auto x_vec = reinterpret_cast<scalar_128bit_vec_t const&>(x_128bit);
const auto y_vec = reinterpret_cast<scalar_128bit_vec_t const&>(y_128bit);
#pragma unroll
for (int i = 0; i < elems_per_128bit_load; i++) {
out_vec[i] = scaled_fp8_conversion<true, fp8_type>(
ACT_FN(x_vec[i]) * y_vec[i], inverted_scale);
}
out_128bit_ptr[vec_idx] = reinterpret_cast<const int2&>(out_vec);
}
// Scalar cleanup code
if (block_end > vec_loop_end) {
for (int64_t idx = vec_loop_end + threadIdx.x; idx < block_end;
idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
out_ptr[idx] =
scaled_fp8_conversion<true, fp8_type>(ACT_FN(x) * y, inverted_scale);
}
}
}
} // namespace vllm
// Launch activation, gating, and quantize kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens, num_tokens > 16 ? num_tokens > 32 ? 1 : 2 : 4); \
dim3 block(std::min(d, 512)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel", [&] { \
VLLM_DISPATCH_FP8_TYPES( \
out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \
vllm::act_and_mul_quant_kernel<scalar_t, KERNEL<scalar_t>, \
fp8_t> \
<<<grid, block, 0, stream>>>(out.data_ptr<fp8_t>(), \
input.data_ptr<scalar_t>(), \
scale.data_ptr<float>(), d); \
}); \
});
void silu_and_mul_quant(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., 2 * d]
torch::Tensor& scale) {
TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn ||
out.dtype() == torch::kFloat8_e4m3fnuz);
TORCH_CHECK(input.dtype() == torch::kFloat16 ||
input.dtype() == torch::kBFloat16);
TORCH_CHECK(input.size(-1) % 2 == 0);
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
...@@ -26,7 +26,13 @@ static inline __device__ int8_t float_to_int8_rn(float x) { ...@@ -26,7 +26,13 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
float dst = std::nearbyint(x); float dst = std::nearbyint(x);
// saturate // saturate
dst = std::clamp(dst, i8_min, i8_max);
// See https://github.com/pytorch/pytorch/issues/127666
// See https://github.com/llvm/llvm-project/issues/95183
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
// Arch/gcc14. The following replaces std::clamp usage with similar logic
// dst = std::clamp(dst, i8_min, i8_max);
dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst;
return static_cast<int8_t>(dst); return static_cast<int8_t>(dst);
#else #else
// CUDA path // CUDA path
...@@ -79,7 +85,13 @@ static inline __device__ int8_t int32_to_int8(int32_t x) { ...@@ -79,7 +85,13 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
static_cast<int32_t>(std::numeric_limits<int8_t>::max()); static_cast<int32_t>(std::numeric_limits<int8_t>::max());
// saturate // saturate
int32_t dst = std::clamp(x, i8_min, i8_max);
// See https://github.com/pytorch/pytorch/issues/127666
// See https://github.com/llvm/llvm-project/issues/95183
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
// Arch/gcc14. The following replaces std::clamp usage with similar logic
// int32_t dst = std::clamp(x, i8_min, i8_max);
int32_t dst = (x < i8_min) ? i8_min : (x > i8_max) ? i8_max : x;
return static_cast<int8_t>(dst); return static_cast<int8_t>(dst);
#else #else
// CUDA path // CUDA path
......
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(
a.size(0) % 4 == 0,
"Input tensor must have a number of rows that is a multiple of 4. ",
"but got: ", a.size(0), " rows.");
if (out.dtype() == torch::kBFloat16) {
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_gemm_caller.cuh"
namespace vllm {
using namespace cute;
template <typename OutType, typename MmaTileShape, typename ScalesPerTile,
class ClusterShape, typename EpilogueScheduler,
typename MainloopScheduler>
struct cutlass_3x_gemm_fp8_blockwise {
using ElementAB = cutlass::float_e4m3_t;
using ElementA = ElementAB;
using LayoutA = cutlass::layout::RowMajor;
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
using ElementB = ElementAB;
using LayoutB = cutlass::layout::ColumnMajor;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
using ElementC = void;
using ElementD = OutType;
using LayoutD = cutlass::layout::RowMajor;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
using LayoutC = LayoutD;
static constexpr int AlignmentC = AlignmentD;
using ElementAccumulator = float;
using ElementCompute = float;
using ElementBlockScale = float;
// MMA and Cluster Tile Shapes
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster
// Shape %2 == 0 using MmaTileShape_MNK = Shape<_128,_128,_128>;
static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
static constexpr int ScaleGranularityM =
size<0>(MmaTileShape{}) / ScaleMsPerTile;
static constexpr int ScaleGranularityN =
size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
static constexpr int ScaleGranularityK =
size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});
// Shape of the threadblocks in a cluster
using ClusterShape_MNK = ClusterShape;
using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
cute::UMMA::Major::MN, cute::UMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
using ArchTag = cutlass::arch::Sm100;
using OperatorClass = cutlass::arch::OpClassTensorOp;
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using ElementScalar = float;
// clang-format off
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
MmaTileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementCompute,
ElementC,
LayoutC,
AlignmentC,
ElementD,
LayoutD,
AlignmentD,
EpilogueScheduler,
DefaultOperation
>::CollectiveOp;
using StageCountType = cutlass::gemm::collective::StageCountAuto;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementA,
cute::tuple<LayoutA, LayoutSFA>,
AlignmentA,
ElementB,
cute::tuple<LayoutB, LayoutSFB>,
AlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduler
>::CollectiveOp;
// clang-format on
using KernelType = enable_sm100_only<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
struct GemmKernel : public KernelType {};
};
template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideD = typename Gemm::GemmKernel::StrideD;
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutSFA = typename Gemm::LayoutSFA;
using LayoutSFB = typename Gemm::LayoutSFB;
using ScaleConfig = typename Gemm::ScaleConfig;
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
auto prob_shape = cute::make_shape(m, n, k, 1);
StrideA a_stride;
StrideB b_stride;
StrideC c_stride;
a_stride =
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
b_stride =
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
c_stride =
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
LayoutSFA layout_SFA =
ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
LayoutSFB layout_SFB =
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
typename GemmKernel::MainloopArguments mainloop_args{
a_ptr, a_stride, b_ptr, b_stride,
a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB};
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename GemmKernel::EpilogueArguments epilogue_args{
{}, c_ptr, c_stride, c_ptr, c_stride};
c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
epilogue_args);
}
template <typename OutType>
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
auto m = a.size(0);
auto k = a.size(1);
auto n = b.size(1);
int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
auto should_use_2sm = [&sms](int m, int n, int tile1SM = 128) {
return std::ceil(static_cast<float>(m) / tile1SM) *
std::ceil(static_cast<float>(n) / tile1SM) >=
sms;
};
bool use_2sm = should_use_2sm(m, n);
if (use_2sm) {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, Shape<_256, _128, _128>, Shape<_256, _1, _1>,
Shape<_2, _2, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
out, a, b, a_scales, b_scales);
} else {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, Shape<_128, _128, _128>, Shape<_128, _1, _1>,
Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm
#include <torch/all.h>
#include "cuda_utils.h"
#include "cutlass_extensions/common.hpp"
template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias,
Fp8Func fp8_func, Int8Func int8_func,
BlockwiseFunc blockwise_func) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
int M = a.size(0), N = b.size(1), K = a.size(1);
if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
(b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (a.dtype() == torch::kFloat8_e4m3fn) {
fp8_func(c, a, b, a_scales, b_scales, bias);
} else {
TORCH_CHECK(a.dtype() == torch::kInt8);
if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
int8_func(c, a, b, a_scales, b_scales, bias);
} else {
TORCH_CHECK(false, "Int8 not supported for this architecture");
}
}
} else {
TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
int32_t version_num = get_sm_version_num();
if (version_num >= 100) {
TORCH_CHECK(
a.size(0) == a_scales.size(0) &&
cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
"a_scale_group_shape must be [1, 128].");
TORCH_CHECK(
cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
"b_scale_group_shape must be [128, 128].");
} else {
// TODO: Remove this after using cutlass sm90 blockwise scaling gemm
// kernel, or introducing ceil_div to the load_init() of mainloop.
using GroupShape = std::array<int64_t, 2>;
auto make_group_shape = [](torch::Tensor const& x,
torch::Tensor const& s) -> GroupShape {
TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
return {cuda_utils::ceil_div(x.size(0), s.size(0)),
cuda_utils::ceil_div(x.size(1), s.size(1))};
};
GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
b_scale_group_shape == GroupShape{128, 128} &&
a.dtype() == torch::kFloat8_e4m3fn &&
b.dtype() == torch::kFloat8_e4m3fn),
"cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
"a_scale_group_shape must be [1, 128]. Got: [",
a_scale_group_shape[0], ", ", a_scale_group_shape[1],
"]\n"
"b_scale_group_shape must be [128, 128]. Got: [",
b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
}
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
blockwise_func(c, a, b, a_scales, b_scales);
}
}
...@@ -36,4 +36,9 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, ...@@ -36,4 +36,9 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias); std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
} // namespace vllm } // namespace vllm
#include <cudaTypedefs.h> #include "c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp" #include "c3x/scaled_mm_kernels.hpp"
#include "cuda_utils.h"
/* /*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm100 (Blackwell). NVIDIA GPUs with sm100 (Blackwell).
...@@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a, ...@@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) { std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); vllm::cutlass_scaled_mm_sm100_fp8,
nullptr, // int8 not supported on SM100
int M = a.size(0), N = b.size(1), K = a.size(1); vllm::cutlass_scaled_mm_blockwise_sm100_fp8);
TORCH_CHECK(
(a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
(b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
"Currently, block scaled fp8 gemm is not implemented for Blackwell");
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
"Currently, only fp8 gemm is implemented for Blackwell");
vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
} }
#endif #endif
#include <cudaTypedefs.h> #include "c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp" #include "c3x/scaled_mm_kernels.hpp"
#include "cuda_utils.h"
/* /*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper). NVIDIA GPUs with sm90a (Hopper).
...@@ -15,49 +13,10 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, ...@@ -15,49 +13,10 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) { std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); vllm::cutlass_scaled_mm_sm90_fp8,
vllm::cutlass_scaled_mm_sm90_int8,
int M = a.size(0), N = b.size(1), K = a.size(1); vllm::cutlass_scaled_mm_blockwise_sm90_fp8);
if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
(b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (a.dtype() == torch::kFloat8_e4m3fn) {
vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
} else {
TORCH_CHECK(a.dtype() == torch::kInt8);
vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias);
}
} else {
using GroupShape = std::array<int64_t, 2>;
auto make_group_shape = [](torch::Tensor const& x,
torch::Tensor const& s) -> GroupShape {
TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
return {cuda_utils::ceil_div(x.size(0), s.size(0)),
cuda_utils::ceil_div(x.size(1), s.size(1))};
};
GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
b_scale_group_shape == GroupShape{128, 128} &&
a.dtype() == torch::kFloat8_e4m3fn &&
b.dtype() == torch::kFloat8_e4m3fn),
"cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
"a_scale_group_shape must be [1, 128]. Got: [",
a_scale_group_shape[0], ", ", a_scale_group_shape[1],
"]\n"
"b_scale_group_shape must be [128, 128]. Got: [",
b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
}
} }
void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
......
...@@ -29,7 +29,8 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, ...@@ -29,7 +29,8 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias); std::optional<torch::Tensor> const& bias);
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
void cutlass_moe_mm_sm90( void cutlass_moe_mm_sm90(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
...@@ -37,12 +38,6 @@ void cutlass_moe_mm_sm90( ...@@ -37,12 +38,6 @@ void cutlass_moe_mm_sm90(
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides); torch::Tensor const& b_strides, torch::Tensor const& c_strides);
void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k);
#endif #endif
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100 #if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
...@@ -53,6 +48,15 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a, ...@@ -53,6 +48,15 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
std::optional<torch::Tensor> const& bias); std::optional<torch::Tensor> const& bias);
#endif #endif
#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \
defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100
void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k);
#endif
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
...@@ -110,6 +114,8 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) { ...@@ -110,6 +114,8 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
#if defined CUDA_VERSION #if defined CUDA_VERSION
if (cuda_device_capability >= 90 && cuda_device_capability < 100) { if (cuda_device_capability >= 90 && cuda_device_capability < 100) {
return CUDA_VERSION >= 12000; return CUDA_VERSION >= 12000;
} else if (cuda_device_capability >= 100) {
return CUDA_VERSION >= 12080;
} }
#endif #endif
...@@ -222,7 +228,8 @@ void get_cutlass_moe_mm_data( ...@@ -222,7 +228,8 @@ void get_cutlass_moe_mm_data(
// This function currently gets compiled only if we have a valid cutlass moe // This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for. // mm to run it for.
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, input_permutation, problem_sizes2, input_permutation,
output_permutation, num_experts, n, k); output_permutation, num_experts, n, k);
......
#include <torch/all.h>
#include <cutlass/arch/arch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include <cassert>
using namespace cute;
template <typename ElementAB, typename ElementC, typename ElementSF,
typename ElementAccumulator, typename LayoutSFA, typename LayoutSFB,
typename ScaleConfig>
__global__ void __get_group_gemm_starts(
ElementAB** a_offsets, ElementAB** b_offsets, ElementC** out_offsets,
ElementSF** a_scales_offsets, ElementSF** b_scales_offsets,
ElementAccumulator** alpha_offsets, LayoutSFA* layout_sfa_base_as_int,
LayoutSFB* layout_sfb_base_as_int, ElementAB* a_base_as_int,
ElementAB* b_base_as_int, ElementC* out_base_as_int,
ElementSF* a_scales_base_as_int, ElementSF* b_scales_base_as_int,
ElementAccumulator* alphas_base_as_int, const int32_t* expert_offsets,
const int32_t* sf_offsets, const int32_t* problem_sizes_as_shapes,
const int K, const int N) {
int64_t expert_id = threadIdx.x;
if (expert_id >= gridDim.x * blockDim.x) {
return;
}
// Originally int32_t but upcasting to int64_t to avoid overflow
// during offset calculations
int64_t expert_offset = static_cast<int64_t>(expert_offsets[expert_id]);
int64_t sf_offset = static_cast<int64_t>(sf_offsets[expert_id]);
// size for block in block scale.
int64_t group_size = 16;
int64_t m = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3]);
int64_t n = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3 + 1]);
int64_t k = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3 + 2]);
assert((m >= 0 && n == N && k == K && k % 2 == 0) &&
"unexpected problem sizes");
int64_t half_k = static_cast<int64_t>(k / 2);
int64_t group_k = static_cast<int64_t>(k / group_size);
// Shape of A as uint8/byte = [M, K // 2]
// Shape of B as uint8/byte = [E, N, K // 2]
a_offsets[expert_id] = a_base_as_int + expert_offset * half_k;
b_offsets[expert_id] = b_base_as_int + expert_id * n * half_k;
// Shape of C = [M, N]
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
// Shape of a_scale = [sum(sf_sizes), K // group_size]
a_scales_offsets[expert_id] = a_scales_base_as_int + sf_offset * group_k;
assert((reinterpret_cast<uintptr_t>(a_scales_offsets[expert_id]) % 128) ==
0 &&
"TMA requires 128-byte alignment");
// Shape of B scale = [E, N, K // group_size]
b_scales_offsets[expert_id] = b_scales_base_as_int + expert_id * n * group_k;
assert((reinterpret_cast<uintptr_t>(b_scales_offsets[expert_id]) % 128) ==
0 &&
"TMA requires 128-byte alignment");
// Shape of alpha = [E]
alpha_offsets[expert_id] = alphas_base_as_int + expert_id;
LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id;
LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id;
*layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(
static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1));
*layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(
static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1));
}
#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE, \
TENSOR_C_TYPE, C_TYPE, LayoutSFA, \
LayoutSFB, ScaleConfig) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
__get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, \
LayoutSFA, LayoutSFB, ScaleConfig> \
<<<1, num_experts, 0, stream>>>( \
static_cast<ELEMENT_AB_TYPE**>(a_starts.data_ptr()), \
static_cast<ELEMENT_AB_TYPE**>(b_starts.data_ptr()), \
static_cast<C_TYPE**>(out_starts.data_ptr()), \
static_cast<SF_TYPE**>(a_scales_starts.data_ptr()), \
static_cast<SF_TYPE**>(b_scales_starts.data_ptr()), \
static_cast<float**>(alpha_starts.data_ptr()), \
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), \
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()), \
static_cast<ELEMENT_AB_TYPE*>(a_tensors.data_ptr()), \
static_cast<ELEMENT_AB_TYPE*>(b_tensors.data_ptr()), \
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
static_cast<SF_TYPE*>(a_scales.data_ptr()), \
static_cast<SF_TYPE*>(b_scales.data_ptr()), \
static_cast<float*>(alphas.data_ptr()), \
static_cast<int32_t*>(expert_offsets.data_ptr()), \
static_cast<int32_t*>(sf_offsets.data_ptr()), \
static_cast<int32_t*>(problem_sizes.data_ptr()), K, N); \
}
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
void run_get_group_gemm_starts(
const torch::Tensor& a_starts, const torch::Tensor& b_starts,
const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts,
const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts,
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
/*these are used for their base addresses*/
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
torch::Tensor const& out_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& alphas,
torch::Tensor const& expert_offsets, torch::Tensor const& sf_offsets,
torch::Tensor const& problem_sizes, int M, int N, int K) {
int num_experts = (int)expert_offsets.size(0);
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
TORCH_CHECK(out_tensors.size(1) == N,
"Output tensor shape doesn't match expected shape");
TORCH_CHECK(K / 2 == b_tensors.size(2),
"b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
" dimension must match");
if (false) {
}
//(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB,
// ScaleConfig)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(
cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kBFloat16,
cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(cutlass::float_e2m1_t,
cutlass::float_ue4m3_t, torch::kFloat16,
half, LayoutSFA, LayoutSFB, ScaleConfig)
else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
}
template <typename OutType>
void run_fp4_blockwise_scaled_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
int N, int K) {
using ProblemShape =
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
using ElementType = cutlass::float_e2m1_t;
using ElementSFType = cutlass::float_ue4m3_t;
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using ElementC = OutType;
using ElementD = ElementC;
using ElementAccumulator = float;
// Layout definitions
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = LayoutC;
// Alignment constraints
static constexpr int AlignmentA = 32;
static constexpr int AlignmentB = 32;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// Architecture definitions
using ArchTag = cutlass::arch::Sm100;
using EpilogueOperatorClass =
cutlass::arch::OpClassTensorOp; // Epilogue Operator class tag
using MainloopOperatorClass =
cutlass::arch::OpClassBlockScaledTensorOp; // Mainloop Operator class tag
using StageCountType =
cutlass::gemm::collective::StageCountAuto; // Stage count maximized based
// on the tile size
using ClusterShape = Shape<_1, _1, _1>;
struct MMA1SMConfig {
using MmaTileShape = Shape<_128, _128, _128>;
using KernelSchedule = cutlass::gemm::
KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100; // Kernel to launch
using EpilogueSchedule =
cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch
};
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, EpilogueOperatorClass, typename MMA1SMConfig::MmaTileShape,
ClusterShape, Shape<_128, _64>, ElementAccumulator,
ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD,
LayoutC*, AlignmentD,
typename MMA1SMConfig::EpilogueSchedule>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, MainloopOperatorClass, ElementA, LayoutA*, AlignmentA,
ElementB, LayoutB*, AlignmentB, ElementAccumulator,
typename MMA1SMConfig::MmaTileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
typename MMA1SMConfig::KernelSchedule>::CollectiveOp;
using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
CollectiveEpilogue>;
using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using Gemm = Gemm1SM;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
using LayoutSFA =
typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
using LayoutSFB =
typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
using ScaleConfig =
typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
int num_experts = static_cast<int>(expert_offsets.size(0));
auto options_int =
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
torch::Tensor c_strides1 =
torch::full({num_experts}, output.stride(0), options_int);
torch::Tensor a_strides1 =
torch::full({num_experts}, a.stride(0) * 2, options_int);
torch::Tensor b_strides1 =
torch::full({num_experts}, b.stride(1) * 2, options_int);
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas,
expert_offsets, sf_offsets, problem_sizes, M, N, K);
// Create an instance of the GEMM
Gemm gemm_op;
// Initialize problem_sizes_as_shapes correctly
UnderlyingProblemShape* problem_sizes_as_shapes =
static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
// Set the Scheduler info
cutlass::KernelHardwareInfo hw_info;
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::
PersistentTileSchedulerSm100GroupParams<
typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
scheduler.raster_order = RasterOrderOptions::AlongM;
hw_info.device_id = a.get_device();
static std::unordered_map<int, int> cached_sm_counts;
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
cached_sm_counts[hw_info.device_id] =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
hw_info.device_id);
}
hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX);
// Mainloop Arguments
typename GemmKernel::MainloopArguments mainloop_args{
static_cast<const ElementType**>(a_ptrs.data_ptr()),
static_cast<StrideA*>(a_strides1.data_ptr()),
static_cast<const ElementType**>(b_ptrs.data_ptr()),
static_cast<StrideB*>(b_strides1.data_ptr()),
static_cast<const ElementSFType**>(a_scales_ptrs.data_ptr()),
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
static_cast<const ElementSFType**>(b_scales_ptrs.data_ptr()),
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())};
// Epilogue Arguments
typename GemmKernel::EpilogueArguments epilogue_args{
{}, // epilogue.thread
nullptr,
static_cast<StrideC*>(c_strides1.data_ptr()),
static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<StrideC*>(c_strides1.data_ptr())};
auto& fusion_args = epilogue_args.thread;
fusion_args.alpha_ptr_array =
reinterpret_cast<float**>(alpha_ptrs.data_ptr());
fusion_args.dAlpha = {_0{}, _0{}, 1};
// Gemm Arguments
typename GemmKernel::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped,
{num_experts, problem_sizes_as_shapes, nullptr},
mainloop_args,
epilogue_args,
hw_info,
scheduler};
size_t workspace_size = Gemm::get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM");
// Run the GEMM
auto status = gemm_op.initialize(args, workspace.data_ptr());
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
status = gemm_op.run(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
// Input validation
CHECK_INPUT(a, FLOAT4_E2M1X2, "a");
CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale");
CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales");
CHECK_INPUT(alphas, at::ScalarType::Float, "alphas");
TORCH_CHECK(a_blockscale.dim() == 2,
"expected a_blockscale to be of shape [num_experts, rounded_m,"
" k // group_size], observed rank: ",
a_blockscale.dim())
TORCH_CHECK(b_blockscales.dim() == 3,
"expected b_blockscale to be of shape: "
" [num_experts, n, k // group_size], observed rank: ",
b_blockscales.dim())
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor");
TORCH_CHECK(problem_sizes.size(1) == 3,
"problem_sizes must have the shape (num_experts, 3)");
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
"Number of experts in problem_sizes must match expert_offsets");
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
"problem_sizes must be int32.");
int M = static_cast<int>(a.size(0));
int N = static_cast<int>(b.size(1));
int E = static_cast<int>(b.size(0));
int K = static_cast<int>(2 * b.size(2));
if (output.scalar_type() == torch::kBFloat16) {
run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
expert_offsets, sf_offsets, M, N, K);
} else {
run_fp4_blockwise_scaled_group_mm<cutlass::half_t>(
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
expert_offsets, sf_offsets, M, N, K);
}
#else
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_fp4_group_mm kernel, vLLM must "
"be compiled with ENABLE_NVFP4 for SM100+ and CUDA "
"12.8 or above.");
#endif
}
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <cuda_fp8.h>
template <typename T>
struct TypeConverter {
using Type = half2;
}; // keep for generality
template <>
struct TypeConverter<half2> {
using Type = half;
};
template <>
struct TypeConverter<half> {
using Type = half2;
};
template <>
struct TypeConverter<__nv_bfloat162> {
using Type = __nv_bfloat16;
};
template <>
struct TypeConverter<__nv_bfloat16> {
using Type = __nv_bfloat162;
};
#define ELTS_PER_THREAD 8
constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
constexpr int CVT_FP4_SF_VEC_SIZE = 16;
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
"f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
return val;
#else
return 0;
#endif
}
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
"f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
return val;
#else
return 0;
#endif
}
// Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) {
float b;
asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
return b;
}
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
int numCols,
SFType* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
CVT_FP4_NUM_THREADS_PER_SF == 2);
// One pair of threads write one SF to global memory.
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
// SF vector index (16 elements share one SF in the K dimension).
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
int32_t mIdx = rowIdx;
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
int32_t mTileIdx = mIdx / (32 * 4);
// SF vector size 16.
int factor = CVT_FP4_SF_VEC_SIZE * 4;
int32_t numKTiles = (numCols + factor - 1) / factor;
int64_t mTileStride = numKTiles * 32 * 4 * 4;
int32_t kTileIdx = (kIdx / 4);
int64_t kTileStride = 32 * 4 * 4;
// M tile layout [32, 4] is column-major.
int32_t outerMIdx = (mIdx % 32);
int64_t outerMStride = 4 * 4;
int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
int64_t innerMStride = 4;
int32_t innerKIdx = (kIdx % 4);
int64_t innerKStride = 1;
// Compute the global offset.
int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
outerMIdx * outerMStride + innerMIdx * innerMStride +
innerKIdx * innerKStride;
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
}
#endif
return nullptr;
}
// Define a 16 bytes packed data type.
template <class Type>
struct PackedVec {
typename TypeConverter<Type>::Type elts[4];
};
template <>
struct PackedVec<__nv_fp8_e4m3> {
__nv_fp8x2_e4m3 elts[8];
};
// Quantizes the provided PackedVec into the uint32_t output
template <class Type, bool UE8M0_SF = false>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
uint8_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
// Get absolute maximum values among the local 8 values.
auto localMax = __habs2(vec.elts[0]);
// Local maximum value.
#pragma unroll
for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
localMax = __hmax2(localMax, __habs2(vec.elts[i]));
}
// Get the absolute maximum among all 16 values (two threads).
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
// Get the final absolute maximum values.
float vecMax = float(__hmax(localMax.x, localMax.y));
// Get the SF (max value of the vector / max value of e2m1).
// maximum value of e2m1 = 6.0.
// TODO: use half as compute data type.
float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
// 8 bits representation of the SF.
uint8_t fp8SFVal;
// Write the SF to global memory (STG.8).
if constexpr (UE8M0_SF) {
// Extract the 8 exponent bits from float32.
// float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
fp8SFVal = tmp & 0xff;
// Convert back to fp32.
reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
} else {
// Here SFValue is always positive, so E4M3 is the same as UE4M3.
__nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
// Convert back to fp32.
SFValue = float(tmp);
}
// Get the output scale.
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
// reciprocal(SFScaleVal))
float outputScale =
SFValue != 0 ? reciprocal_approximate_ftz(
SFValue * reciprocal_approximate_ftz(SFScaleVal))
: 0.0f;
if (SFout) {
// Write the SF to global memory (STG.8).
*SFout = fp8SFVal;
}
// Convert the input to float.
float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
#pragma unroll
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
if constexpr (std::is_same_v<Type, half>) {
fp2Vals[i] = __half22float2(vec.elts[i]);
} else {
fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
}
fp2Vals[i].x *= outputScale;
fp2Vals[i].y *= outputScale;
}
// Convert to e2m1 values.
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
// Write the e2m1 values to global memory.
return e2m1Vec;
#else
return 0;
#endif
}
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(512, 4) cvt_fp16_to_fp4(
#else
cvt_fp16_to_fp4(
#endif
int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts,
uint32_t* output_scale_offset_by_experts, int n_experts) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using PackedVec = PackedVec<Type>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
"Vec size is not matched.");
// Input tensor row/col loops.
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD;
colIdx += blockDim.x) {
int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t outOffset = inOffset;
auto& out_pos = out[outOffset];
// Find index within the experts.
int rowIdx_in_expert = 0;
int expert_idx = 0;
for (int i = 0; i < n_experts; i++) {
if (rowIdx >= input_offset_by_experts[i] &&
rowIdx < input_offset_by_experts[i + 1]) {
rowIdx_in_expert = rowIdx - input_offset_by_experts[i];
expert_idx = i;
break;
}
}
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
int factor = CVT_FP4_SF_VEC_SIZE * 4;
// The actual output_scales dim is computed from the padded numCols.
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
uint32_t* SFout_in_expert =
SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
auto sf_out =
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
out_pos =
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
}
}
#endif
}
template <typename T>
void quant_impl(void* output, void* output_scale, void* input,
void* input_global_scale, void* input_offset_by_experts,
void* output_scale_offset_by_experts, int m_topk, int k,
int n_experts, cudaStream_t stream) {
// TODO: this multiProcessorCount should be cached.
int device;
cudaGetDevice(&device);
int multiProcessorCount;
cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount,
device);
// Grid, Block size.
// Each thread converts 8 values.
dim3 block(std::min(int(k / ELTS_PER_THREAD), 512));
// Get number of blocks per SM (assume we can fully utilize the SM).
int const numBlocksPerSM = 2048 / block.x;
dim3 grid(std::min(int(m_topk), multiProcessorCount * numBlocksPerSM));
cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
m_topk, k, reinterpret_cast<T*>(input),
reinterpret_cast<float*>(input_global_scale),
reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts), n_experts);
}
/*Quantization entry for fp4 experts quantization*/
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
#define CHECK_INPUT(x, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m);
constexpr auto HALF = at::ScalarType::Half;
constexpr auto BF16 = at::ScalarType::BFloat16;
constexpr auto FLOAT = at::ScalarType::Float;
constexpr auto INT = at::ScalarType::Int;
constexpr auto UINT8 = at::ScalarType::Byte;
void scaled_fp4_experts_quant_sm100a(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
CHECK_INPUT(output, "output must be a CUDA tensor");
CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor");
CHECK_INPUT(input, "input must be a CUDA tensor");
CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor");
CHECK_INPUT(input_offset_by_experts,
"input_offset_by_experts must be a CUDA tensor");
CHECK_INPUT(output_scale_offset_by_experts,
"output_scale_offset_by_experts must be a CUDA tensor");
TORCH_CHECK(output.dim() == 2);
TORCH_CHECK(output_scale.dim() == 2);
TORCH_CHECK(input.dim() == 2);
TORCH_CHECK(input_global_scale.dim() == 1);
TORCH_CHECK(input_offset_by_experts.dim() == 1);
TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
// output is uint8 (two nvfp4 values are packed into one uint8)
// output_scale is int32 (four fp8 values are packed into one int32)
TORCH_CHECK(output.scalar_type() == UINT8);
TORCH_CHECK(output_scale.scalar_type() == INT);
const int BLOCK_SIZE = 16;
auto m_topk = input.size(0);
auto k = input.size(1);
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
auto n_experts = input_global_scale.size(0);
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(output.size(0) == m_topk);
TORCH_CHECK(output.size(1) == k / 2);
int scales_k = k / BLOCK_SIZE;
// 4 means the swizzle requirement by nvidia nvfp4.
int padded_k = (scales_k + (4 - 1)) / 4 * 4;
// 4 means 4 fp8 values are packed into one int32
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
auto in_dtype = input.dtype();
at::cuda::CUDAGuard device_guard{(char)input.get_device()};
const cudaStream_t stream =
at::cuda::getCurrentCUDAStream(input.get_device());
if (in_dtype == at::ScalarType::Half) {
quant_impl<half>(output.data_ptr(), output_scale.data_ptr(),
input.data_ptr(), input_global_scale.data_ptr(),
input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(), m_topk, k,
n_experts, stream);
} else if (in_dtype == at::ScalarType::BFloat16) {
quant_impl<__nv_bfloat16>(output.data_ptr(), output_scale.data_ptr(),
input.data_ptr(), input_global_scale.data_ptr(),
input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(), m_topk,
k, n_experts, stream);
} else {
TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
}
}
\ No newline at end of file
...@@ -23,10 +23,32 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output, ...@@ -23,10 +23,32 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output,
torch::Tensor const& input_sf); torch::Tensor const& input_sf);
#endif #endif
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
void scaled_fp4_experts_quant_sm100a(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
torch::Tensor& output_sf, torch::Tensor const& input_sf) { torch::Tensor& output_sf, torch::Tensor const& input_sf) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4 #if defined ENABLE_NVFP4 && ENABLE_NVFP4
return scaled_fp4_quant_sm100a(output, input, output_sf, input_sf); return scaled_fp4_quant_sm100a(output, input, output_sf, input_sf);
#endif #endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization"); TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel");
}
void scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
return scaled_fp4_experts_quant_sm100a(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false,
"No compiled nvfp4 experts quantization kernel");
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment