"vscode:/vscode.git/clone" did not exist on "a0f44bb6169dcd6225d2efc0a59dd343a8d4a38e"
Unverified Commit 98229db2 authored by Elvir Crnčević's avatar Elvir Crnčević Committed by GitHub
Browse files

[Kernels][DP/EP] Optimize Silu Kernel for R1 (#24054)


Signed-off-by: default avatarelvircrn <elvircrn@gmail.com>
parent dbeee384
......@@ -133,6 +133,12 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& input_global_scale);
#endif
void silu_mul_fp8_quant_deep_gemm_cuda(
const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& counts, // (E)
at::Tensor& y_q, // (E, T, H) [OUT]
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens);
void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
......
......@@ -9,6 +9,26 @@
#include "quantization/fp8/common.cuh"
#include <c10/util/Float8_e4m3fn.h>
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_fp8.h>
typedef __hip_bfloat162 __nv_bfloat162;
typedef __hip_bfloat16 __nv_bfloat16;
typedef __hip_bfloat16_raw __nv_bfloat16_raw;
typedef __hip_fp8_e4m3 __nv_fp8_e4m3;
typedef __hip_fp8x4_e4m3 __nv_fp8x4_e4m3;
#endif
#include "core/registration.h"
namespace vllm {
template <typename T>
......@@ -87,6 +107,337 @@ __global__ void act_and_mul_quant_kernel(
}
}
}
__device__ __forceinline__ float silu(float x) {
return (__fdividef(x, (1.f + expf(-x))));
}
__device__ __forceinline__ float2 silu2(float2 x) {
return make_float2(silu(x.x), silu(x.y));
}
#ifndef USE_ROCM
__device__ __forceinline__ float warp_max(float v) {
static constexpr unsigned FULL_MASK = 0xffffffffu;
for (int offset = 1; offset < WARP_SIZE; offset *= 2) {
v = fmaxf(v, __shfl_xor_sync(FULL_MASK, v, offset));
}
return v;
}
__device__ __forceinline__ __nv_bfloat16 warp_max(__nv_bfloat16 v) {
static constexpr unsigned FULL_MASK = 0xffffffffu;
for (int offset = 1; offset < WARP_SIZE; offset *= 2) {
v = __hmax(v, __shfl_xor_sync(FULL_MASK, v, offset));
}
return v;
}
#endif
template <typename T, typename U>
__device__ __forceinline__ void cp_async4(T* _smem_ptr, const U* _glob_ptr) {
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
auto smem_ptr = reinterpret_cast<void*>(_smem_ptr);
auto glob_ptr = reinterpret_cast<const void*>(_glob_ptr);
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
#else
_smem_ptr[0] = _glob_ptr[0];
#endif
}
__device__ __forceinline__ void cp_async_fence() {
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
asm volatile("cp.async.commit_group;\n" ::);
#else
#endif
}
template <int N>
__device__ __forceinline__ void cp_async_wait() {
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
#else
#endif
}
template <>
__device__ __forceinline__ void cp_async_wait<0>() {
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
asm volatile("cp.async.wait_all;\n" ::);
#else
#endif
}
__device__ __forceinline__ float clip(float v, float mmin, float mmax) {
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
return fminf(mmax, fmaxf(v, mmin));
#else
#endif
}
__device__ __forceinline__ __nv_bfloat16 clip(__nv_bfloat16 v,
__nv_bfloat16 mmin,
__nv_bfloat16 mmax) {
return __hmin(mmax, __hmax(v, mmin));
}
__device__ __forceinline__ __nv_bfloat162 clip(__nv_bfloat162 v,
__nv_bfloat162 mmin,
__nv_bfloat162 mmax) {
return __hmin2(mmax, __hmax2(v, mmin));
}
// We use the following values for fp8 min/max:
// __nv_fp8_e4m3 = (-448, +448)
// __nv_fp8_e4m3uz = (-240.0, +240.0)
// It is currently assumed that only
template <class T>
constexpr __nv_bfloat16 get_fp8_max() {
static_assert(std::is_same_v<T, c10::Float8_e4m3fn> ||
std::is_same_v<T, c10::Float8_e4m3fnuz>);
if constexpr (std::is_same_v<T, c10::Float8_e4m3fn>) {
return __nv_bfloat16(__nv_bfloat16_raw{.x = 17376});
} else {
return __nv_bfloat16(__nv_bfloat16_raw{.x = 17264});
}
}
template <class T>
constexpr __nv_bfloat16 get_fp8_min() {
static_assert(std::is_same_v<T, c10::Float8_e4m3fn> ||
std::is_same_v<T, c10::Float8_e4m3fnuz>);
if constexpr (std::is_same_v<T, c10::Float8_e4m3fn>) {
return __nv_bfloat16(__nv_bfloat16_raw{.x = 50144});
} else {
return __nv_bfloat16(__nv_bfloat16_raw{.x = 50032});
}
}
#ifndef USE_ROCM
template <typename fp8_type, int32_t NUM_WARPS, typename Idx_t,
int NUM_PARALLEL_TOKENS, bool USE_UE8M0, int GROUP_SIZE = 128,
int NUM_STAGES = 3>
__global__ void silu_mul_fp8_quant_deep_gemm_kernel(
const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q,
float* __restrict__ _y_s, const int32_t* __restrict__ counts,
// sizes
int H, int G,
// strides (in elements)
Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e,
Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t,
Idx_t stride_ys_g, Idx_t stride_counts_e) {
static constexpr __nv_bfloat16 fp8_min = get_fp8_min<fp8_type>();
static constexpr __nv_bfloat16 fp8_max = get_fp8_max<fp8_type>();
// We assign EPS with its 16-bit unsigned counterpart to allow constexpr.
static constexpr __nv_bfloat16 EPS = (__nv_bfloat16_raw{.x = 11996});
// We pack 8 16-bit bfloat16 values into a 128-bit __int128_t.
static constexpr int32_t BFLOAT16_PER_GROUP = 8;
// We split the shared memory in half, corresponding to gate and up matrices:
// [...gate_i, ...up_i] where 0 <= i < stages.
static constexpr int32_t S_NUM_128 =
2u * (GROUP_SIZE / BFLOAT16_PER_GROUP) * NUM_WARPS * NUM_STAGES;
static constexpr auto THREAD_COUNT = NUM_WARPS * WARP_SIZE;
static constexpr int HALF_THREAD_COUNT = THREAD_COUNT / 2;
static constexpr int32_t S_NUM_64 = S_NUM_128 * 2;
__shared__ __int128_t __align__(16) s_buff_128[S_NUM_128];
const int32_t tid = threadIdx.x;
const int32_t warp_id = tid / WARP_SIZE;
const int32_t lane_id = tid % WARP_SIZE;
auto s_buff_compute_32 = reinterpret_cast<__nv_bfloat162*>(s_buff_128);
// block handles one (expert e, group g)
int32_t pid = blockIdx.x;
int32_t e = pid / G;
int32_t g = pid % G;
const int32_t n_tokens = counts[e * stride_counts_e];
if (!n_tokens) {
return; // Exit ASAP.
}
const Idx_t stride_i_t_128 = stride_i_t / 8u;
int32_t n_tokens_lower, n_tokens_upper;
// Each block i iterates over tokens of a slice of n_tokens =
// expert_counts[i], with the size of chunk being
// (n_tokens / NUM_PARALLEL_TOKENS) + residual, instead of
// updiv(n_tokens, NUM_PARALLEL_TOKENS) for better scheduling.
if (n_tokens < NUM_PARALLEL_TOKENS && blockIdx.y < n_tokens) {
// Specialize this, but can be likely fused.
if (blockIdx.y >= NUM_PARALLEL_TOKENS) {
return;
}
n_tokens_lower = blockIdx.y;
n_tokens_upper = blockIdx.y + 1;
} else {
auto chunk_size = n_tokens / NUM_PARALLEL_TOKENS;
auto residual = n_tokens - chunk_size * NUM_PARALLEL_TOKENS;
auto calc_id = [&](int32_t id) {
if (id < residual) {
return min(n_tokens, id * (chunk_size + 1));
} else {
return min(n_tokens, id * chunk_size + residual);
}
};
n_tokens_lower = calc_id(blockIdx.y);
n_tokens_upper = calc_id(blockIdx.y + 1);
}
if (n_tokens_lower >= n_tokens_upper) {
return;
}
// We do calculations here, using constexpr wherever possible.
const Idx_t base_i = e * stride_i_e + NUM_WARPS * g * GROUP_SIZE * stride_i_h;
const Idx_t base_ys = e * stride_ys_e + NUM_WARPS * g * stride_ys_g;
const Idx_t base_yq =
e * stride_yq_e + NUM_WARPS * g * GROUP_SIZE * stride_yq_h;
Idx_t gate_off_128 = (base_i / static_cast<Idx_t>(8u));
auto input_128_ptr = reinterpret_cast<const __int128_t*>(_input);
auto gate_128_ptr = input_128_ptr + gate_off_128 + (tid % HALF_THREAD_COUNT) +
stride_i_t_128 * n_tokens_lower;
auto up_128_ptr = gate_128_ptr + (H * stride_i_h) / 8u;
auto y_s_ptr =
_y_s + base_ys + warp_id * stride_ys_g + n_tokens_lower * stride_ys_t;
auto y_q_ptr = _y_q + base_yq + warp_id * GROUP_SIZE +
stride_yq_t * n_tokens_lower + 4 * lane_id;
int32_t t_load = n_tokens_lower, load_stage_id = 0;
auto s_buff_gate_load_128 = s_buff_128 + (tid % HALF_THREAD_COUNT);
auto s_buff_up_load_128 = s_buff_gate_load_128 + S_NUM_128 / 2u;
int32_t stage_offset{};
static constexpr int32_t LOAD_STAGE_SIZE = (NUM_WARPS * WARP_SIZE / 2);
static constexpr int32_t LOAD_STAGE_MOD =
NUM_STAGES * (NUM_WARPS * WARP_SIZE / 2);
// Two halves of all threads in a block conduct global loads for gate and up,
// repsectively.
auto load_and_advance_y_pred = [&] {
if (t_load < n_tokens_upper) {
auto s_gate_stage_128_staged_ptr = s_buff_gate_load_128 + stage_offset;
auto s_up_stage_128_staged_ptr = s_buff_up_load_128 + stage_offset;
// It is very important that LOAD_STAGE_SIZE is constexpr to avoid
// unnecessary ALU ops.
stage_offset += LOAD_STAGE_SIZE;
stage_offset %= LOAD_STAGE_MOD;
if (tid < HALF_THREAD_COUNT) {
cp_async4(s_gate_stage_128_staged_ptr, gate_128_ptr);
gate_128_ptr += stride_i_t_128;
} else {
cp_async4(s_up_stage_128_staged_ptr, up_128_ptr);
up_128_ptr += stride_i_t_128;
}
++t_load;
++load_stage_id;
}
// We fence even if there is nothing to load to simplify pipelining.
cp_async_fence();
};
#pragma unroll
for (int i = 0; i < NUM_STAGES - 1; i++) {
load_and_advance_y_pred();
}
__int64_t* s_gate_ptr = reinterpret_cast<__int64_t*>(
s_buff_compute_32 + warp_id * (GROUP_SIZE / 2)) +
lane_id;
__int64_t* s_up_ptr = s_gate_ptr + S_NUM_64 / 2;
static constexpr int32_t STAGE_SIZE = (GROUP_SIZE * NUM_WARPS) / 4u;
static constexpr int32_t STAGE_MOD = STAGE_SIZE * NUM_STAGES;
int32_t compute_pipeline_offset_64 = 0;
for (int32_t t = n_tokens_lower; t < n_tokens_upper; ++t) {
__nv_bfloat16 y_max_bf16 = EPS;
__nv_bfloat162 results_bf162[2];
cp_async_wait<NUM_STAGES - 2>();
__syncthreads();
// We double-buffer pipelined loads so that the next load will
// concurrently run with compute without overwrites.
load_and_advance_y_pred();
auto s_gate_compute_64 = s_gate_ptr + compute_pipeline_offset_64;
auto s_up_compute_64 = s_up_ptr + compute_pipeline_offset_64;
// STAGE_SIZE must also be constexpr!
compute_pipeline_offset_64 += STAGE_SIZE;
compute_pipeline_offset_64 %= STAGE_MOD;
// Each thread loads (gate/up) 2X 4X bfloat16 values into registers.
__int64_t gate64 = *s_gate_compute_64;
__nv_bfloat162* s_gate_compute_32 =
reinterpret_cast<__nv_bfloat162*>(&gate64);
__int64_t up64 = *s_up_compute_64;
__nv_bfloat162* s_up_compute_32 = reinterpret_cast<__nv_bfloat162*>(&up64);
#pragma unroll
for (int i = 0; i < 2; i++) {
// For silu, we make sure that div is emitted.
float2 gate = silu2(__bfloat1622float2(s_gate_compute_32[i]));
results_bf162[i] = __float22bfloat162_rn(gate);
}
#pragma unroll
for (int i = 0; i < 2; i++) {
results_bf162[i] = __hmul2(results_bf162[i], s_up_compute_32[i]);
}
auto _y_max2 =
__hmax2(__habs2(results_bf162[0]), __habs2(results_bf162[1]));
y_max_bf16 = __hmax(_y_max2.x, _y_max2.y);
// An entire group is assigned to a single warp, so a simple warp reduce
// is used.
__nv_bfloat16 y_s = warp_max(y_max_bf16) / fp8_max;
if constexpr (USE_UE8M0) {
y_s = hexp2(hceil(hlog2(y_s)));
}
auto inv_y = __float2bfloat16_rn(1.f) / y_s;
auto y_s2 = make_bfloat162(inv_y, inv_y);
#pragma unroll
for (int32_t i = 0; i < 2; ++i) {
results_bf162[i] =
clip(__hmul2(results_bf162[i], y_s2), __bfloat162bfloat162(fp8_min),
__bfloat162bfloat162(fp8_max));
}
auto fp8x4 = __nv_fp8x4_e4m3(results_bf162[0], results_bf162[1]);
*reinterpret_cast<__nv_fp8x4_e4m3*>(y_q_ptr) = fp8x4;
y_q_ptr += stride_yq_t;
if (lane_id == 0) {
*y_s_ptr = y_s;
y_s_ptr += stride_ys_t;
}
}
}
#endif
} // namespace vllm
// Launch activation, gating, and quantize kernel.
......@@ -119,3 +470,117 @@ void silu_and_mul_quant(torch::Tensor& out, // [..., d]
TORCH_CHECK(input.size(-1) % 2 == 0);
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
void silu_mul_fp8_quant_deep_gemm_cuda(
const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& counts, // (E)
at::Tensor& y_q, // (E, T, H) [OUT]
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens) {
#ifndef USE_ROCM
// This kernel relies heavily on cp.async and fp8 support.
// This kernel currently only supports H % 128 == 0 and assumes a
// fixed GROUP_SIZE of 128.
TORCH_CHECK(input.dtype() == torch::kBFloat16);
TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
y_q.dtype() == torch::kFloat8_e4m3fnuz);
TORCH_CHECK(y_s.dtype() == torch::kFloat32);
TORCH_CHECK(input.size(-1) % 256 == 0);
// Check that num_parallel_tokens is of power of 2 and between 1 and 64.
TORCH_CHECK(1 <= num_parallel_tokens && num_parallel_tokens <= 64);
TORCH_CHECK(!(num_parallel_tokens & (num_parallel_tokens - 1)));
using Idx_t = int64_t;
Idx_t E = input.size(0);
Idx_t T = input.size(1);
Idx_t H = input.size(2) / 2;
Idx_t stride_i_e = input.stride(0);
Idx_t stride_i_t = input.stride(1);
Idx_t stride_i_h = input.stride(2);
Idx_t stride_yq_e = y_q.stride(0);
Idx_t stride_yq_t = y_q.stride(1);
Idx_t stride_yq_h = y_q.stride(2);
Idx_t stride_ys_e = y_s.stride(0);
Idx_t stride_ys_t = y_s.stride(1);
Idx_t stride_ys_g = y_s.stride(2);
Idx_t stride_counts_e = counts.stride(0);
static constexpr int GROUP_SIZE = 128;
#define KERNEL_FN \
if (use_ue8m0) { \
vllm::silu_mul_fp8_quant_deep_gemm_kernel<fp8_t, NUM_WARPS, Idx_t, \
NUM_PARALLEL_TOKENS, true> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
(fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \
reinterpret_cast<int32_t*>(counts.data_ptr<int>()), H, G, \
stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \
stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \
stride_counts_e); \
} else { \
vllm::silu_mul_fp8_quant_deep_gemm_kernel<fp8_t, NUM_WARPS, Idx_t, \
NUM_PARALLEL_TOKENS, false> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
(fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \
reinterpret_cast<int32_t*>(counts.data_ptr<int>()), H, G, \
stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \
stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \
stride_counts_e); \
}
#define KERNEL_CALL_H \
if (H % (4 * GROUP_SIZE) == 0) { \
static constexpr int NUM_WARPS = 4; \
populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \
KERNEL_FN \
} else { \
static constexpr int NUM_WARPS = 1; \
populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \
KERNEL_FN \
}
#define KERNEL_CALL_TOP_LEVEL \
if (num_parallel_tokens == 1) { \
static constexpr int NUM_PARALLEL_TOKENS = 1; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 2) { \
static constexpr int NUM_PARALLEL_TOKENS = 2; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 4) { \
static constexpr int NUM_PARALLEL_TOKENS = 4; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 8) { \
static constexpr int NUM_PARALLEL_TOKENS = 8; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 16) { \
static constexpr int NUM_PARALLEL_TOKENS = 16; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 32) { \
static constexpr int NUM_PARALLEL_TOKENS = 32; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 64) { \
static constexpr int NUM_PARALLEL_TOKENS = 64; \
KERNEL_CALL_H \
}
Idx_t G;
dim3 block, grid;
auto populate_launch_params = [&](int num_warps, int _num_parallel_tokens) {
G = H / Idx_t(group_size * num_warps);
grid = dim3(E * G, _num_parallel_tokens);
block = dim3(num_warps * WARP_SIZE);
};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
VLLM_DISPATCH_FP8_TYPES(y_q.scalar_type(),
"silu_mul_fp8_quant_deep_gemm_kernel",
[&] { KERNEL_CALL_TOP_LEVEL });
#endif
}
......@@ -32,6 +32,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#define stride_tag
#endif
ops.def(
"silu_mul_fp8_quant_deep_gemm_cuda(Tensor input, Tensor counts, Tensor! "
"y_q, Tensor! y_s, int group_size, "
"bool use_ue8m0, int num_parallel_tokens) -> ()");
ops.impl("silu_mul_fp8_quant_deep_gemm_cuda", torch::kCUDA,
&silu_mul_fp8_quant_deep_gemm_cuda);
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
......
......@@ -5,28 +5,52 @@ import pytest
import torch
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
silu_mul_fp8_quant_deep_gemm)
silu_mul_fp8_quant_deep_gemm_cuda)
from vllm.platforms import current_platform
from vllm.utils import cdiv
fp8_dtype = torch.float8_e4m3fn
# (E, T, H, group_size, seed)
CASES = [
(1, 1, 128, 64, 0),
(1, 4, 128, 128, 0),
(2, 4, 256, 128, 0),
(32, 64, 256, 128, 0),
(17, 31, 768, 128, 0),
(1, 1, 128, fp8_dtype),
(1, 4, 128, fp8_dtype),
(2, 4, 256, fp8_dtype),
(32, 64, 256, fp8_dtype),
(17, 31, 768, fp8_dtype),
(1, 1, 128 * 1, fp8_dtype),
(1, 1, 128 * 2, fp8_dtype),
(1, 1, 128 * 3, fp8_dtype),
(1, 1, 128 * 4, fp8_dtype),
(8, 16, 128 * 1, fp8_dtype),
(8, 16, 128 * 2, fp8_dtype),
(8, 16, 128 * 3, fp8_dtype),
(8, 16, 128 * 4, fp8_dtype),
(8, 64, 7168, fp8_dtype),
(8, 128, 7168, fp8_dtype),
(8, 256, 7168, fp8_dtype),
(8, 512, 7168, fp8_dtype),
(8, 1024, 7168, fp8_dtype),
(256, 8, 7168, fp8_dtype),
(256, 16, 7168, fp8_dtype),
(256, 32, 7168, fp8_dtype),
(256, 64, 7168, fp8_dtype),
# Only add a few fnuz tests to help with long CI times.
(8, 512, 7168, torch.float8_e4m3fnuz),
(8, 1024, 7168, torch.float8_e4m3fnuz),
]
@pytest.mark.parametrize("E,T,H,group_size,seed", CASES)
@pytest.mark.parametrize("E,T,H,fp8_type", CASES)
@torch.inference_mode()
def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed):
current_platform.seed_everything(seed)
def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
group_size = 128
current_platform.seed_everything(42)
# Input tensor of shape (E, T, 2*H)
y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
tokens_per_expert = torch.randint(
low=0,
low=T // 2,
high=T,
size=(E, ),
dtype=torch.int32,
......@@ -34,45 +58,59 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed):
)
# Run the Triton kernel
y_q, y_s = silu_mul_fp8_quant_deep_gemm(y,
tokens_per_expert,
group_size=group_size,
eps=1e-10)
y_q, y_s = silu_mul_fp8_quant_deep_gemm_cuda(y,
tokens_per_expert,
group_size=group_size)
# Reference implementation
fp8_info = torch.finfo(torch.float8_e4m3fn)
torch.cuda.synchronize()
fp8_info = torch.finfo(fp8_dtype)
fp8_max = fp8_info.max
fp8_min = fp8_info.min
eps = 1e-10
# Compute silu activation and elementwise multiplication
y1 = y[..., :H]
y1 = y[..., :H].float()
y2 = y[..., H:]
silu_x = y1 * torch.sigmoid(y1)
merged = silu_x * y2
# Compute reference scales and quantized output, skipping padded tokens
for e in range(E):
nt = tokens_per_expert[e].item()
ref_s = torch.empty((T, H // group_size),
ref_s = torch.empty((T, cdiv(H, group_size)),
dtype=torch.float32,
device="cuda")
ref_q = torch.empty((T, H), dtype=torch.float8_e4m3fn, device="cuda")
ref_q = torch.empty((T, H), dtype=fp8_dtype, device="cuda")
for t in range(nt):
data = merged[e, t]
data_grp = data.view(H // group_size, group_size)
amax = data_grp.abs().amax(dim=1).clamp(min=eps)
scale = amax / fp8_max
data = merged[e, t].float()
ref_q_row = torch.empty_like(data)
# process full groups
n_full_groups = H // group_size
if n_full_groups > 0:
data_grp = data[:n_full_groups * group_size].view(
n_full_groups, group_size)
amax = data_grp.abs().amax(dim=1).clamp(min=eps)
scale = amax / fp8_max
scaled = data[:n_full_groups *
group_size] / scale.repeat_interleave(group_size)
ref_q_row[:n_full_groups * group_size] = scaled.clamp(
fp8_min, fp8_max).to(fp8_dtype)
ref_s[t, :n_full_groups] = scale
scaled = data / scale.repeat_interleave(group_size)
clamped = scaled.clamp(fp8_min, fp8_max)
q = clamped.to(torch.float8_e4m3fn)
# process remainder group
rem = H % group_size
if rem > 0:
data_rem = data[-rem:]
amax = data_rem.abs().amax().clamp(min=eps)
scale = amax / fp8_max
scaled = data_rem / scale
ref_q_row[-rem:] = scaled.clamp(fp8_min, fp8_max).to(fp8_dtype)
ref_s[t, -1] = scale
ref_s[t] = scale
ref_q[t] = q
ref_q[t] = ref_q_row
y_se = y_s[e]
y_qe = y_q[e]
y_se = y_s[e].float()
y_qe = y_q[e].float()
torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2)
torch.testing.assert_close(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from math import log2
from typing import Optional
import torch
......@@ -10,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked,
is_deep_gemm_e8m0_used)
......@@ -24,35 +26,28 @@ def _silu_mul_fp8_quant_deep_gemm(
y_q_ptr, # fp8 quantized activations (E, T, H)
y_s_ptr, # 16-bit scales (E, T, G)
counts_ptr, # int32 num tokens per expert (E)
# Sizes ---------------------------------------------------------------
H: tl.constexpr, # hidden dimension (per output)
GROUP_SIZE: tl.constexpr, # elements per group (usually 128)
# Strides for input (elements) ---------------------------------------
stride_i_e,
stride_i_t,
stride_i_h,
# Strides for y_q (elements) -----------------------------------------
stride_yq_e,
stride_yq_t,
stride_yq_h,
# Strides for y_s (elements) -----------------------------------------
stride_ys_e,
stride_ys_t,
stride_ys_g,
# Stride for counts (elements)
stride_counts_e,
# Numeric params ------------------------------------------------------
eps: tl.constexpr,
fp8_min: tl.constexpr,
fp8_max: tl.constexpr,
use_ue8m0: tl.constexpr,
# Meta ---------------------------------------------------------------
BLOCK: tl.constexpr,
NUM_STAGES: tl.constexpr,
......@@ -101,17 +96,15 @@ def _silu_mul_fp8_quant_deep_gemm(
tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)
def silu_mul_fp8_quant_deep_gemm(
def silu_mul_fp8_quant_deep_gemm_cuda(
y: torch.Tensor, # (E, T, 2*H)
tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert
num_parallel_tokens=16,
group_size: int = 128,
eps: float = 1e-10,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
y has shape (E, T, 2*H). The first half of the last dimension is
y has shape (E, T, 2*H). The first half of the last dimension is
silu-activated, multiplied by the second half, then quantized into FP8.
Returns `(y_q, y_s)` where
* `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
* `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
......@@ -120,22 +113,17 @@ def silu_mul_fp8_quant_deep_gemm(
E, T, H2 = y.shape
assert H2 % 2 == 0, "last dim of y must be even (2*H)"
H = H2 // 2
G = H // group_size
assert H % group_size == 0, "H must be divisible by group_size"
assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, \
"tokens_per_expert must be shape (E,)"
G = (H + group_size - 1) // group_size
assert H % 8 == 0, "H must be divisible by 8"
assert group_size == 128, "H must be divisible by 8"
assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E
tokens_per_expert = tokens_per_expert.to(device=y.device,
dtype=torch.int32)
# allocate outputs
fp8_dtype = torch.float8_e4m3fn
y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
# strides (elements)
stride_i_e, stride_i_t, stride_i_h = y.stride()
stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()
# desired scale strides (elements): (T*G, 1, T)
stride_ys_e = T * G
stride_ys_t = 1
stride_ys_g = T
......@@ -144,47 +132,86 @@ def silu_mul_fp8_quant_deep_gemm(
dtype=torch.float32,
device=y.device)
stride_cnt_e = tokens_per_expert.stride()[0]
# Static grid over experts and H-groups.
# A loop inside the kernel handles the token dim
grid = (E * G, )
f_info = torch.finfo(fp8_dtype)
fp8_max = f_info.max
fp8_min = f_info.min
_silu_mul_fp8_quant_deep_gemm[grid](
y,
y_q,
y_s,
tokens_per_expert,
H,
group_size,
stride_i_e,
stride_i_t,
stride_i_h,
stride_yq_e,
stride_yq_t,
stride_yq_h,
stride_ys_e,
stride_ys_t,
stride_ys_g,
stride_cnt_e,
eps,
fp8_min,
fp8_max,
is_deep_gemm_e8m0_used(),
BLOCK=group_size,
NUM_STAGES=4,
num_warps=1,
)
use_ue8m0 = is_deep_gemm_e8m0_used()
if E <= 16:
max_empirical_parallelism = 64
elif E <= 32:
max_empirical_parallelism = 16
else:
max_empirical_parallelism = 4
# We never want to launch more than Tx number of threads
# This computes the clip.
num_parallel_tokens = max(
1,
min(max_empirical_parallelism, 2**int(log2(min(num_parallel_tokens,
T)))))
cuda_arch = current_platform.get_device_capability(
device_id=y.device.index).to_int()
if cuda_arch >= 80:
torch.ops._C.silu_mul_fp8_quant_deep_gemm_cuda(y, tokens_per_expert,
y_q, y_s, group_size,
use_ue8m0,
num_parallel_tokens)
else:
# Default to triton if not on cuda or if arch is too old
y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
stride_cnt_e = tokens_per_expert.stride()[0]
# Static grid over experts and H-groups.
# A loop inside the kernel handles the token dim
grid = (E * G, )
# strides (elements)
stride_i_e, stride_i_t, stride_i_h = y.stride()
stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()
# desired scale strides (elements): (T*G, 1, T)
stride_ys_e = T * G
stride_ys_t = 1
stride_ys_g = T
y_s = torch.empty_strided(
(E, T, G),
(stride_ys_e, stride_ys_t, stride_ys_g),
dtype=torch.float32,
device=y.device,
)
f_info = torch.finfo(fp8_dtype)
fp8_max = f_info.max
fp8_min = f_info.min
eps: float = 1e-10
_silu_mul_fp8_quant_deep_gemm[grid](
y,
y_q,
y_s,
tokens_per_expert,
H,
group_size,
stride_i_e,
stride_i_t,
stride_i_h,
stride_yq_e,
stride_yq_t,
stride_yq_h,
stride_ys_e,
stride_ys_t,
stride_ys_g,
stride_cnt_e,
eps,
fp8_min,
fp8_max,
is_deep_gemm_e8m0_used(),
BLOCK=group_size,
NUM_STAGES=4,
num_warps=1,
)
return y_q, y_s
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# The Deep Gemm kernels only support block size of 128
DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128]
......@@ -297,8 +324,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale),
workspace1, expert_num_tokens, expected_m)
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
expert_num_tokens)
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm_cuda(
workspace1, expert_num_tokens)
fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output,
expert_num_tokens, expected_m)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment