Commit 3fb4b5fa authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.18.0' into v0.18.0-ori

parents bcf25339 89138b21
......@@ -263,12 +263,10 @@ void get_cutlass_moe_mm_data_caller(
}
template <bool SWAP_AB>
__global__ void compute_pplx_data(int32_t* expert_offsets,
int32_t* problem_sizes1,
int32_t* problem_sizes2,
const int32_t* __restrict__ expert_num_tokens,
const int padded_m, const int n,
const int k) {
__global__ void compute_batched_moe_data(
int32_t* expert_offsets, int32_t* problem_sizes1, int32_t* problem_sizes2,
const int32_t* __restrict__ expert_num_tokens, const int padded_m,
const int n, const int k) {
int expert_idx = threadIdx.x;
expert_offsets[expert_idx] = expert_idx * padded_m;
......@@ -289,24 +287,22 @@ __global__ void compute_pplx_data(int32_t* expert_offsets,
}
}
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m,
const int64_t n, const int64_t k) {
void get_cutlass_batched_moe_mm_data_caller(
torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k) {
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) {
compute_pplx_data<false><<<1, num_local_experts, 0, stream>>>(
compute_batched_moe_data<false><<<1, num_local_experts, 0, stream>>>(
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
k);
} else {
compute_pplx_data<true><<<1, num_local_experts, 0, stream>>>(
compute_batched_moe_data<true><<<1, num_local_experts, 0, stream>>>(
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
......
......@@ -82,13 +82,11 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
const int64_t n, const int64_t k, const bool swap_ab);
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m,
const int64_t n, const int64_t k);
void get_cutlass_batched_moe_mm_data_caller(
torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k);
#endif
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
......@@ -319,29 +317,30 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
version_num, ". Required capability: 90, 100, or 120");
}
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m, const int64_t n,
const int64_t k) {
void get_cutlass_batched_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m, const int64_t n,
const int64_t k) {
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens,
num_local_experts, padded_m, n, k);
get_cutlass_batched_moe_mm_data_caller(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens,
num_local_experts, padded_m, n, k);
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
"for CUDA device capability: ",
version_num, ". Required capability: 90, 100, or 120");
TORCH_CHECK_NOT_IMPLEMENTED(false,
"No compiled get_cutlass_batched_moe_mm_data: no "
"cutlass_scaled_mm kernel "
"for CUDA device capability: ",
version_num,
". Required capability: 90, 100, or 120");
}
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
......
......@@ -379,7 +379,9 @@ void per_token_group_quant_8bit_packed(const torch::Tensor& input,
void per_token_group_quant_fp8(const torch::Tensor& input,
torch::Tensor& output_q, torch::Tensor& output_s,
int64_t group_size, double eps, double fp8_min,
double fp8_max, bool scale_ue8m0) {
double fp8_max, bool scale_ue8m0,
bool dummy_is_scale_transposed = false,
bool dummy_is_tma_aligned = false) {
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
fp8_min, fp8_max, scale_ue8m0);
}
\ No newline at end of file
......@@ -12,6 +12,7 @@
#include "../cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/w8a8/fp8/common.cuh"
#include "core/batch_invariant.hpp"
// TODO(rasmith): The kernels in this file are susceptible to integer overflow
// issues, do not take strides, and are unable to handle PyTorch tensors that
......@@ -304,8 +305,9 @@ __device__ inline unsigned int min__(uint32_t a, uint32_t b) {
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_sml_(const int K, const int M, const int Bx, const int By,
const scalar_t* B, const scalar_t* __restrict__ A,
wvSplitK_hf_sml_(const int K, const int Kbp, const int Kap, const int M,
const int Bx, const int By, const scalar_t* B,
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
constexpr int max_lds_len = LDS_SIZE / 2;
......@@ -314,7 +316,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else
constexpr bool use_mfma = false;
#endif
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
using half4 =
......@@ -346,13 +347,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// - Then the WG will move to another 8 K elements
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for (uint32_t k = 0; k < min__(K * N, max_lds_len);
k += THRDS * WvPrGrp * A_CHUNK) {
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
if (k_in >= min__(K * N, max_lds_len)) break;
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
#if defined(__gfx950__)
__builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
#else
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
#endif
}
__syncthreads();
......@@ -360,9 +361,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
float sum[N][YTILE];
scalar8 sum4[N][YTILE];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
// There are 16 waves per WG, and hence, each WG is
......@@ -386,44 +384,20 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// YTILE represents how many column of weight matrix
// are being worked on by each wave.
//----------------------------------------------------
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++)
if constexpr (!use_mfma)
sum[n][i] = 0;
else
sum4[n][i] = {0, 0, 0, 0};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
//----------------------------------------------------
// Fetch weight matrix B in interleaved K-split!
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements (1024B)
// - YTILE represents the number of column being serviced
// by wave
// - Loop for fetching weight matrix (B) are unrolled
//
// Fetch activation matrix A from LDS
// - Loop for fetching activation matrix (A) are unrolled
//
// Finally, do the matrix multiplication in an unrolled
// fashion. This provides lot of food for compiler
// scheduling.
//
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
// for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
float sum[N][YTILE] = {};
scalar8 sum4[N][YTILE] = {};
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
bigType bigA[N][UNRL] = {};
bigType bigB[YTILE][UNRL];
// Fetch the weight matrix from memory!
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
const scalar_t* B_ = &B[(m + 0) * K + k_];
const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)];
for (int y = 0; y < YTILE; y++)
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * K])));
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
......@@ -432,33 +406,20 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
// Fetch A activation matrix in interleaved fashion from LDS or memory
for (int n = 0; n < N; n++) {
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n])));
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for (uint32_t n = 0; n < N; n++) {
#pragma unroll
for (int y = 0; y < YTILE; y++) {
if constexpr (!use_mfma)
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
}
else
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
......@@ -466,46 +427,44 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
}
__builtin_amdgcn_sched_barrier(0);
//----------------------------------------------------
// Final reduction step using shuffle
//----------------------------------------------------
if constexpr (!use_mfma) {
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x118, 0xf, 0xf,
1); // row_shr8
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x114, 0xf, 0xf,
1); // row_shr4
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x112, 0xf, 0xf,
1); // row_shr2
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf,
1); // row_shr1
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf,
1); // ROW_BCAST15
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf,
1); // ROW_BCAST31
}
}
if (threadIdx.x == 63) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
}
}
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
for (int y = 0; y < YTILE; y++) {
if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS)
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
sum[n][y] += __half2float(biases[n][y]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS)
sum[n][i] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
sum[n][y] += __bfloat162float(biases[n][y]);
}
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y]);
}
}
}
......@@ -514,45 +473,43 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (int n = 0; n < N; n++) {
#pragma unroll
for (int y = 0; y < YTILE; y++) {
// float accm1 = 0;
// for (int i=0; i<64; i++)
// accm1 += __shfl(sum4[n][y][i%4], i);
/*float accm1 = 0;
for (int i=0; i<64; i++)
accm1 += __shfl(sum4[n][y][i%4], i);
sum4[n][y][0] = accm1;*/
float accm = sum4[n][y][0];
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][1], 0x101, 0xf, 0xf,
1); // row_shl1
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][2], 0x102, 0xf, 0xf,
1); // row_shl2
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][3], 0x103, 0xf, 0xf,
1); // row_shl3
accm += __builtin_amdgcn_mov_dpp(accm, 0x104, 0xf, 0xf,
1); // row_shl4
accm += __builtin_amdgcn_mov_dpp(accm, 0x108, 0xf, 0xf,
1); // row_shl8
accm = __builtin_amdgcn_mov_dpp(accm, 0x11f, 0xf, 0xf,
1); // row_shr15
accm += __builtin_amdgcn_mov_dpp(accm, 0x142, 0xf, 0xf,
1); // ROW_BCAST15
accm += __builtin_amdgcn_mov_dpp(accm, 0x143, 0xf, 0xf,
1); // ROW_BCAST31
sum4[n][y][0] = accm;
}
}
if (threadIdx.x == 63) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
}
}
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (BIAS)
sum4[n][i][0] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
for (int y = 0; y < YTILE; y++) {
sum4[n][y][0] += __bfloat162float(biases[n][y]);
C[m + y + n * M] = __float2bfloat16(sum4[n][y][0]);
}
}
}
......@@ -563,8 +520,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_sml_(const int K, const int M, const int Bx,
const int By, const scalar_t* B,
__global__ void wvSplitK_hf_sml_(const int K, const int Kbp, const int Kap,
const int M, const int Bx, const int By,
const scalar_t* B,
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
......@@ -577,8 +535,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const int Bx,
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_(const int K, const int M, const int Bx, const int By,
const scalar_t* B, const scalar_t* __restrict__ A,
wvSplitK_hf_(const int K, const int Kbp, const int Kap, const int M,
const int Bx, const int By, const scalar_t* B,
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
constexpr int max_lds_len = LDS_SIZE / 2;
......@@ -601,13 +560,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
scalar8 h8;
};
//----------------------------------------------------
// Reserving 64 KB of LDS to have 1 WG / CU
// Goal is to bring the activation matrix A to the LDS
// and use it across the lifetime of the work group
// TODO: When activation matrix is larger than 64 KB
// then this is not going to work!
//----------------------------------------------------
__shared__ scalar_t s[max_lds_len];
//----------------------------------------------------
......@@ -618,12 +570,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
commitColumn[i] = 1;
}
//----------------------------------------------------
// Indexing function into the column of weight matrix B
// Algorithm does 64 lane k-splitting / wave and uses
// WG ID and Thread ID to find the index.
//----------------------------------------------------
// int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp);
uint32_t m = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE;
// Check whether there will be fragmentation!
......@@ -636,91 +582,34 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
m = startColumn;
}
//----------------------------------------------------
// Fetch the activation matrix to LDS
// Loop iteration:
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements
// - Each WG will fetch 512 * 16 => 8K elements
// - Then the WG will move to another 8 K elements
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for (uint32_t k = 0; k < min__(K * N, max_lds_len);
k += THRDS * WvPrGrp * A_CHUNK) {
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
if (k_in >= min__(K * N, max_lds_len)) break;
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
#if defined(__gfx950__)
__builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
#else
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
#endif
}
__syncthreads();
if (threadIdx.y >= _WvPrGrp) return;
float sum[N][YTILE];
scalar8 sum4[N][YTILE];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
// There are 16 waves per WG, and hence, each WG is
// working on 16 columns of weight matrix. Moreover,
// we tile in column direction by YTILE, so when YTILE=1
// the above math is right, however, when YTILE=2 then
// each wave will be working on 2 columns and WG will
// be working on 32 columns.
//
// Top level loop that makes WGs persistent!
// - WGs iterates across columns of weight matrix
// - Each wave within WG works on a given column(s)
// - After completing first set of columns, WGs start
// working on the next set of available columns
//----------------------------------------------------
while (m < M) {
//----------------------------------------------------
// 'sum' accumulates the matrix A x B computation
// split across 64 lanes.
//
// YTILE represents how many column of weight matrix
// are being worked on by each wave.
//----------------------------------------------------
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++)
if constexpr (!use_mfma)
sum[n][i] = 0;
else
sum4[n][i] = {0, 0, 0, 0};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
//----------------------------------------------------
// Fetch weight matrix B in interleaved K-split!
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements (1024B)
// - YTILE represents the number of column being serviced
// by wave
// - Loop for fetching weight matrix (B) are unrolled
//
// Fetch activation matrix A from LDS
// - Loop for fetching activation matrix (A) are unrolled
//
// Finally, do the matrix multiplication in an unrolled
// fashion. This provides lot of food for compiler
// scheduling.
//
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
float sum[N][YTILE] = {};
scalar8 sum4[N][YTILE] = {};
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
bigType bigA[N][UNRL] = {};
bigType bigB[YTILE][UNRL];
// Fetch the weight matrix from memory!
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
const scalar_t* B_ = &B[(m + 0) * K + k_];
for (int b = 0; b < YTILE; b++)
bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K])));
const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)];
for (int y = 0; y < YTILE; y++)
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
......@@ -729,36 +618,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
// Fetch A activation matrix in interleaved fashion from LDS or memory
for (int n = 0; n < N; n++) {
if (k_ + K * n < max_lds_len)
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
if (k_ + Kap * n < max_lds_len)
bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n])));
else
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
bigA[n][k2] = *((const bigType*)(&(A[k_ + Kap * n])));
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for (uint32_t n = 0; n < N; n++) {
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for (int y = 0; y < YTILE; y++) {
if constexpr (!use_mfma)
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
}
else
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
......@@ -773,40 +649,38 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if constexpr (!use_mfma) {
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x118, 0xf, 0xf,
1); // row_shr8
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x114, 0xf, 0xf,
1); // row_shr4
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x112, 0xf, 0xf,
1); // row_shr2
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf,
1); // row_shr1
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf,
1); // ROW_BCAST15
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf,
1); // ROW_BCAST31
}
}
if (threadIdx.x == 63) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
}
}
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i]) {
for (int y = 0; y < YTILE; y++) {
if (commitColumn[y]) {
if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS)
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
sum[n][y] += __half2float(biases[n][y]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS)
sum[n][i] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
sum[n][y] += __bfloat162float(biases[n][y]);
}
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y]);
}
}
}
......@@ -819,44 +693,39 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// float accm1 = 0;
// for (int i=0; i<64; i++)
// accm1 += __shfl(sum4[n][y][i%4], i);
float accm = sum4[n][y][0];
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][1], 0x101, 0xf, 0xf,
1); // row_shl1
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][2], 0x102, 0xf, 0xf,
1); // row_shl2
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][3], 0x103, 0xf, 0xf,
1); // row_shl3
accm += __builtin_amdgcn_mov_dpp(accm, 0x104, 0xf, 0xf,
1); // row_shl4
accm += __builtin_amdgcn_mov_dpp(accm, 0x108, 0xf, 0xf,
1); // row_shl8
accm = __builtin_amdgcn_mov_dpp(accm, 0x11f, 0xf, 0xf,
1); // row_shr15
accm += __builtin_amdgcn_mov_dpp(accm, 0x142, 0xf, 0xf,
1); // ROW_BCAST15
accm += __builtin_amdgcn_mov_dpp(accm, 0x143, 0xf, 0xf,
1); // ROW_BCAST31
sum4[n][y][0] = accm;
}
}
if (threadIdx.x == 63) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
}
}
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i]) {
if (BIAS)
sum4[n][i][0] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
for (int y = 0; y < YTILE; y++) {
if (commitColumn[y]) {
sum4[n][y][0] += __bfloat162float(biases[n][y]);
C[m + y + n * M] = __float2bfloat16(sum4[n][y][0]);
}
}
}
......@@ -880,9 +749,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_(const int K, const int M, const int Bx,
const int By, const scalar_t* B,
const scalar_t* __restrict__ A,
__global__ void wvSplitK_hf_(const int K, const int Kbp, const int Kap,
const int M, const int Bx, const int By,
const scalar_t* B, const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
......@@ -894,8 +763,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const int Bx,
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_big_(const int K, const int M, const int Bx, const int By,
const scalar_t* B, const scalar_t* __restrict__ A,
wvSplitK_hf_big_(const int K, const int Kbp, const int Kap, const int M,
const int Bx, const int By, const scalar_t* B,
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
constexpr int max_lds_len = LDS_SIZE / 2;
......@@ -966,13 +836,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
//----------------------------------------------------
#define PCML
#ifndef PCML
for (uint32_t k = 0; k < min__(K * N, max_lds_len);
k += THRDS * WvPrGrp * A_CHUNK) {
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
if (k_in >= min__(K * N, max_lds_len)) break;
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
#if defined(__gfx950__)
__builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
#else
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
#endif
}
__syncthreads();
#endif
......@@ -987,10 +857,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
? kFit
: (kFit - kFit % TUC); // round up to multiple of TUC
// if (kFit == 0) kFit = TUC;
kFit = min__(kFit, K);
float sum[N][YTILE];
scalar8 sum4[N][YTILE];
kFit = min__(kFit, Kap);
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
......@@ -1021,15 +888,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// YTILE represents how many column of weight matrix
// are being worked on by each wave.
//----------------------------------------------------
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++)
if constexpr (!use_mfma)
sum[n][i] = 0;
else
sum4[n][i] = {0, 0, 0, 0};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
float sum[N][YTILE] = {};
scalar8 sum4[N][YTILE] = {};
//----------------------------------------------------
// Fetch weight matrix B in interleaved K-split!
// - Each thread (lane) is fetching 8 elements (A_Chunk)
......@@ -1048,18 +909,26 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
bigType bigA[N][UNRL] = {};
bigType bigB[YTILE][UNRL];
#ifdef PCML
if ((k1 == 0) || (k1 == kBase + kFit)) { // load next chunk of A[] to LDS
if (k1 != 0) kBase += kFit;
__syncthreads();
for (uint32_t k = 0; k < kFit; k += THRDS * _WvPrGrp * A_CHUNK) {
uint32_t kOff = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
if (kBase + kOff >= K) break;
if (kBase + kOff >= Kap) break;
if (kOff >= kFit) break;
for (uint32_t n = 0; n < N; n++) {
uint32_t k_in = kBase + n * K + kOff;
uint32_t k_in = kBase + n * Kap + kOff;
uint32_t k_ot = n * kFit + kOff;
#if defined(__gfx950__)
__builtin_amdgcn_global_load_lds((int*)(&A[k_in]), (int*)(&s[k_ot]),
16, 0, 0);
#else
*((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in]));
#endif
}
}
__syncthreads();
......@@ -1072,11 +941,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
const scalar_t* B_ = &B[(m + 0) * K + k_];
for (int b = 0; b < YTILE; b++)
bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K])));
const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)];
for (int y = 0; y < YTILE; y++)
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
......@@ -1085,17 +952,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
// Fetch A activation matrix in interleaved fashion from LDS or memory
for (int n = 0; n < N; n++) {
#ifdef PCML
bigA[n][k2] = *((const bigType*)(&(s[k_ - kBase + kFit * n])));
#else
if (k_ + K * n < 32 * 1024)
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
if (k_ + Kap * n < max_lds_len)
bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n])));
else
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
bigA[n][k2] = *((const bigType*)(&(A[k_ + Kap * n])));
#endif
}
}
......@@ -1103,22 +967,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// Do the matrix multiplication in interleaved manner
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
#pragma unroll
for (uint32_t n = 0; n < N; n++) {
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for (int y = 0; y < YTILE; y++) {
if constexpr (!use_mfma)
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
}
else
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
......@@ -1141,40 +996,38 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if constexpr (!use_mfma) {
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x118, 0xf, 0xf,
1); // row_shr8
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x114, 0xf, 0xf,
1); // row_shr4
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x112, 0xf, 0xf,
1); // row_shr2
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf,
1); // row_shr1
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf,
1); // ROW_BCAST15
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf,
1); // ROW_BCAST31
}
}
if (threadIdx.x == 63) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
}
}
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i]) {
for (int y = 0; y < YTILE; y++) {
if (commitColumn[y]) {
if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS)
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
sum[n][y] += __half2float(biases[n][y]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS)
sum[n][i] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
sum[n][y] += __bfloat162float(biases[n][y]);
}
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y]);
}
}
}
......@@ -1185,42 +1038,38 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#pragma unroll
for (int y = 0; y < YTILE; y++) {
float accm = sum4[n][y][0];
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][1], 0x101, 0xf, 0xf,
1); // row_shl1
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][2], 0x102, 0xf, 0xf,
1); // row_shl2
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][3], 0x103, 0xf, 0xf,
1); // row_shl3
accm += __builtin_amdgcn_mov_dpp(accm, 0x104, 0xf, 0xf,
1); // row_shl4
accm += __builtin_amdgcn_mov_dpp(accm, 0x108, 0xf, 0xf,
1); // row_shl8
accm = __builtin_amdgcn_mov_dpp(accm, 0x11f, 0xf, 0xf,
1); // row_shr15
accm += __builtin_amdgcn_mov_dpp(accm, 0x142, 0xf, 0xf,
1); // ROW_BCAST15
accm += __builtin_amdgcn_mov_dpp(accm, 0x143, 0xf, 0xf,
1); // ROW_BCAST31
sum4[n][y][0] = accm;
}
}
if (threadIdx.x == 63) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
}
}
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i]) {
if (BIAS)
sum4[n][i][0] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
for (int y = 0; y < YTILE; y++) {
if (commitColumn[y]) {
sum4[n][y][0] += __bfloat162float(biases[n][y]);
C[m + y + n * M] = __float2bfloat16(sum4[n][y][0]);
}
}
}
......@@ -1244,8 +1093,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
const int By, const scalar_t* B,
__global__ void wvSplitK_hf_big_(const int K, const int Kbp, const int Kap,
const int M, const int Bx, const int By,
const scalar_t* B,
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
......@@ -1272,6 +1122,8 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
auto M_in = in_a.size(0);
auto K_in = in_a.size(1);
auto N_in = in_b.size(0);
auto Kap_in = in_a.stride(0);
auto Kbp_in = in_b.stride(0);
auto Bx_in =
(in_bias.has_value() && in_bias->numel() > 0)
? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
......@@ -1296,27 +1148,30 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int max_lds_len = get_lds_size() / 2;
#define WVSPLITK(_YTILE, _UNRL, _N) \
{ \
dim3 block(64, 16); \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
else if (K_in * N_in <= max_lds_len * 1.2) \
wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
else \
wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
#define WVSPLITK(_YTILE, _UNRL, _N) \
{ \
dim3 block(64, 16); \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \
if ((Kbp_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
By_in, af4, bf4, biasf4, c, __wvPrGrp, \
CuCount); \
else if (Kbp_in * N_in <= max_lds_len * 1.2) \
wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
By_in, af4, bf4, biasf4, c, __wvPrGrp, \
CuCount); \
else \
wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
By_in, af4, bf4, biasf4, c, __wvPrGrp, \
CuCount); \
}
#define WVSPLIT_TILE(_sYT, __N) \
{ \
bool fit_lds = (K_in * N_in <= max_lds_len); \
bool fit_lds = (Kbp_in * N_in <= max_lds_len); \
if (_sYT <= 1) \
WVSPLITK(1, 4, __N) \
else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \
......@@ -1370,17 +1225,14 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
#if defined(__gfx950__)
#define WVSPLITKRC_1KPASS
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N, int GrpsShrB, int CHUNKK>
int UNRL, int N, int GrpsShrB, int CHUNKK, int DTRMNSTC>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
__attribute__((amdgpu_waves_per_eu(1, 1)))
wvSplitKrc_(const int actlN, const int K, const int M, const int Bx,
const int By, const scalar_t* __restrict__ B,
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, float* glbl, scalar_t* C,
const int CuCount) {
// Use upper half of glbl buffer for atomic reduce counting
int* cntr = (int*)(&glbl[M * N]);
wvSplitKrc_(const int actlN, const int K, const int Kap, const int M,
const int Bx, const int By, const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
const scalar_t* __restrict__ BIAS, float* glbl, int* cntr,
scalar_t* C, const int CuCount) {
constexpr int NTILE = 16;
constexpr int APAD = 1;
constexpr int ASTRD = 64;
......@@ -1568,15 +1420,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
{
#endif
unsigned int kOff = k + (thrd * A_CHUNK);
unsigned int kOffcp =
k_str + kOff; // min__(K - A_CHUNK, k_str + kOff);
unsigned int kOffcp = min__(K - A_CHUNK, k_str + kOff);
for (unsigned int n = 0; n < N; n += CHUNKK * sprdN) {
__builtin_amdgcn_global_load_lds(
(int*)(&A[min__(
K * actlN - A_CHUNK,
kOffcp + K * (n / CHUNKK +
(N / CHUNKK) * (threadIdx.x / (64 / CHUNKK)) +
(threadIdx.y % sprdN)))]),
(int*)(&A[min__(Kap * actlN - A_CHUNK,
kOffcp + Kap * (n / CHUNKK +
(N / CHUNKK) * (threadIdx.x /
(64 / CHUNKK)) +
(threadIdx.y % sprdN)))]),
(int*)(&s[(k +
kFitPdd * ((n / CHUNKK) + (threadIdx.y % sprdN)))]),
16, 0, 0);
......@@ -1623,7 +1474,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#endif
// B[] staging is cooperative across GrpsShrB, so sync here before reading
// back. This wait is currently inserted by compiler, but not gauranteed.
// back. This wait is currently inserted by compiler, but not guaranteed.
asm volatile("s_waitcnt 0");
__syncthreads();
......@@ -1680,45 +1531,98 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
union flt4 {
scalar8 s8;
float2 f2[2];
float4 f4;
};
if (m + (threadIdx.x % 16) < M) {
int my_cntr;
int mindx = m + (threadIdx.x % 16);
int g_mindx = m * 4 + (threadIdx.x % 64); // coalesced atomic reduction
scalar_t biases[N / NTILE / GrpsShrB][4] = {};
// Atomic add the output, read biases
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++)
for (uint32_t j = 0; j < 4; j++) {
// int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
// (N / GrpsShrB) * (threadIdx.y % GrpsShrB);
// int adr = mindx + M * nindx;
int g_nindx =
j + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
int g_adr = g_mindx + M * g_nindx * 4;
atomicAdd(&glbl[g_adr], sum4[nt][0][j]);
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
int g_nindx =
(nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
int g_adr = g_mindx * 4 + 0 + M * g_nindx * 4;
if (DTRMNSTC) {
flt4 flt4_ = {.s8 = sum4[nt][0]};
__hip_atomic_store((float2*)&glbl[g_adr + M * N * (m0 / Mmod)],
flt4_.f2[0], __ATOMIC_RELAXED,
__HIP_MEMORY_SCOPE_AGENT);
__hip_atomic_store((float2*)&glbl[g_adr + 2 + M * N * (m0 / Mmod)],
flt4_.f2[1], __ATOMIC_RELAXED,
__HIP_MEMORY_SCOPE_AGENT);
} else {
for (uint32_t j = 0; j < 4; j++)
atomicAdd((&glbl[g_adr + j]), sum4[nt][0][j]);
}
}
__atomic_signal_fence(__ATOMIC_SEQ_CST);
asm volatile("s_waitcnt vmcnt(0)" ::: "memory");
__atomic_signal_fence(__ATOMIC_SEQ_CST);
int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE +
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
int adr_ = mindx + M * nindx_ / 4;
// Update the complete counter
my_cntr = atomicAdd(&cntr[adr_], 1);
float vals[N / NTILE / GrpsShrB][4] = {};
// make sure LDS is free for write out staging
if (DTRMNSTC) __syncthreads();
// Update the complete counter
flt4 vals[N / NTILE / GrpsShrB] = {};
// If we're the last k-shard, read back the value and convert...
if (my_cntr + 1 == k_rnd) {
if (BIAS)
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
for (uint32_t j = 0; j < 4; j++) {
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * Bx];
cntr[adr_] = 0; // clear for next round
if constexpr (DTRMNSTC) {
#pragma unroll
for (int ks = 0; ks < k_rnd; ks++) {
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
int g_nindx =
(nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
int g_adr = g_mindx * 4 + 0 + M * g_nindx * 4;
__builtin_amdgcn_global_load_lds(
(float4*)(&glbl[g_adr + M * N * ks]),
&(((float4*)s)[(threadIdx.y * THRDS) + ks * THRDS * 4 +
nt * THRDS * 4 * k_rnd]),
16, 0, 0);
}
}
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
for (uint32_t j = 0; j < 4; j++) {
if (BIAS)
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
for (uint32_t j = 0; j < 4; j++) {
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * Bx];
}
}
asm volatile("s_waitcnt 0");
for (int ks = 0; ks < k_rnd; ks++) {
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
float4 eval = ((float4*)s)[(threadIdx.x + threadIdx.y * THRDS) +
ks * THRDS * 4 + nt * THRDS * 4 * k_rnd];
vals[nt].f4 += eval;
}
}
} else {
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
int g_nindx =
j + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
int g_adr = g_mindx + M * g_nindx * 4;
vals[nt][j] = glbl[g_adr];
(nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
int g_adr = g_mindx * 4 + 0 + M * g_nindx * 4;
vals[nt].f4 = *(float4*)(&glbl[g_adr]);
*(float4*)(&glbl[g_adr]) = {}; // clear out for next round
}
if (BIAS)
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
for (uint32_t j = 0; j < 4; j++) {
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * Bx];
}
}
}
__builtin_amdgcn_sched_barrier(0);
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
......@@ -1728,11 +1632,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (nindx < actlN) {
int adr = mindx + M * nindx;
if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
vals[nt][j] += __bfloat162float(biases[nt][j]);
C[adr] = __float2bfloat16(vals[nt][j]);
vals[nt].s8[j] += __bfloat162float(biases[nt][j]);
C[adr] = __float2bfloat16(vals[nt].s8[j]);
} else {
vals[nt][j] += __half2float(biases[nt][j]);
C[adr] = __float2half(vals[nt][j]);
vals[nt].s8[j] += __half2float(biases[nt][j]);
C[adr] = __float2half(vals[nt].s8[j]);
}
}
}
......@@ -1751,21 +1655,25 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N, int GrpsShrB, int CHUNKK>
__global__ void wvSplitKrc_(const int actlN, const int K, const int M,
const int Bx, const int By, const scalar_t* B,
const scalar_t* __restrict__ A,
int UNRL, int N, int GrpsShrB, int CHUNKK, int DTRMNSTC>
__global__ void wvSplitKrc_(const int actlN, const int K, const int Kap,
const int M, const int Bx, const int By,
const scalar_t* B, const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, float* glbl,
// int* cntr,
scalar_t* C, const int CuCount){UNREACHABLE_CODE}
int* cntr, scalar_t* C,
const int CuCount){UNREACHABLE_CODE}
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
const std::optional<at::Tensor>& in_bias,
const int64_t CuCount) {
auto M_in = in_a.size(0);
auto N_in = in_b.size(0);
auto K_in = in_a.size(1);
int _DTRMNSTC = 1; // vllm::vllm_is_batch_invariant();
auto M_in = in_b.size(0);
auto N_in = in_a.size(0);
auto K_in = in_b.size(1);
auto Kap_in = in_a.stride(0);
auto Bx_in =
(in_bias.has_value() && in_bias->numel() > 0)
? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
......@@ -1782,13 +1690,9 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
auto out_c = torch::empty(
{N_in, M_in},
torch::TensorOptions().dtype(in_b.dtype()).device(in_b.device()));
torch::TensorOptions().dtype(in_a.dtype()).device(in_a.device()));
auto N_p2 = 1U << (32 - __builtin_clz(N_in - 1));
auto axl_glbl = torch::empty(
{N_p2 + N_p2 / 4, M_in + M_in / 4},
torch::TensorOptions().dtype(torch::kFloat32).device(in_b.device()));
axl_glbl.zero_(); // disable for FAST_UNSAFE_RDC_INIT
dim3 grid(CuCount);
......@@ -1796,55 +1700,70 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// const int max_lds_len = get_lds_size() / 2;
// With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
// and each working on a 512-shard of K, how many CUs would we need?
int rndup_cus = ((M_in + 64 - 1) / 64) * ((K_in + 512 - 1) / 512);
// How many of 4 waves in a group can work on same 16 Ms at same time? First
// try to maximize this. This reduces the Ms each group works on, i.e.
// increasing the number of CUs needed.
int GrpsShrB = min(N_p2 / 16, 4);
// Given the above, how many CUs would we need?
int CuNeeded = rndup_cus * GrpsShrB;
if (CuNeeded > CuCount) throw std::runtime_error("Invalid wvSplitKrc size");
// Can we increase SplitK by shrinking the K-shared to 256?
int chunkk = (CuNeeded * 2 <= CuCount) ? 2 : 1;
static torch::Tensor axl_glbl =
torch::zeros(
128 * 1024 * (_DTRMNSTC ? 12 : 1),
torch::TensorOptions().dtype(torch::kFloat32).device(in_a.device()))
.detach();
static torch::Tensor axl_cntr =
torch::zeros(
128 * 1024 * (_DTRMNSTC ? 12 : 1) / 4,
torch::TensorOptions().dtype(torch::kInt).device(in_a.device()))
.detach();
auto glbl = axl_glbl.data_ptr<float>();
auto cntr = axl_cntr.data_ptr<int>();
#define WVSPLITKrc(_N, _GrpsShrB, _CHUNKK) \
{ \
dim3 block(64, 4); \
wvSplitKrc_<fptype, 64, 16, 4, 8, 1, _N, _GrpsShrB, _CHUNKK> \
<<<grid, block, 0, stream>>>(N_in, K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, glbl, c, CuCount); \
if (_DTRMNSTC) \
wvSplitKrc_<fptype, 64, 16, 4, 8, 1, _N, _GrpsShrB, _CHUNKK, 1> \
<<<grid, block, 0, stream>>>(N_in, K_in, Kap_in, M_in, Bx_in, By_in, \
af4, bf4, biasf4, glbl, cntr, c, \
CuCount); \
else \
wvSplitKrc_<fptype, 64, 16, 4, 8, 1, _N, _GrpsShrB, _CHUNKK, 0> \
<<<grid, block, 0, stream>>>(N_in, K_in, Kap_in, M_in, Bx_in, By_in, \
af4, bf4, biasf4, glbl, cntr, c, \
CuCount); \
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitKrc", [&] {
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_a.scalar_type(), "wvSplitKrc", [&] {
using fptype = typename scalar<scalar_t>::type;
fptype* af4 = reinterpret_cast<fptype*>(in_a.data_ptr());
const fptype* af4 = reinterpret_cast<const fptype*>(in_a.data_ptr());
const fptype* bf4 = reinterpret_cast<const fptype*>(in_b.data_ptr());
const fptype* biasf4 =
(in_bias.has_value() && in_bias->numel() > 0)
? reinterpret_cast<const fptype*>(in_bias->data_ptr())
: nullptr;
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
auto glbl = axl_glbl.data_ptr<float>();
// With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
// and each working on a 512-shard of K, how many CUs would we need?
int rndup_cus = ((M_in + 64 - 1) / 64) * ((K_in + 512 - 1) / 512);
// How many of 4 waves in a group can work on same 16 Ms at same time? First
// try to maximize this. This reduces the Ms each group works on, i.e.
// increasing the number of CUs needed.
int GrpsShrB = min(N_p2 / 16, 4);
// Given the above, how many CUs would we need?
int CuNeeded = rndup_cus * GrpsShrB;
if (CuNeeded > CuCount) std::runtime_error("Invalid wvSplitKrc size");
// Can we increase SplitK by shrinking the K-shared to 256?
int chunkk = (CuNeeded * 2 <= CuCount) ? 2 : 1;
switch (N_p2) {
case 16:
WVSPLITKrc(16, 1, 1) break;
case 32:
if (chunkk == 2)
WVSPLITKrc(32, 2, 2) else if (chunkk == 1) WVSPLITKrc(32, 2, 1) break;
if (chunkk == 2) WVSPLITKrc(32, 2, 2) else WVSPLITKrc(32, 2, 1) break;
case 64:
if (chunkk == 2)
WVSPLITKrc(64, 4, 2) else if (chunkk == 1) WVSPLITKrc(64, 4, 1) break;
if (chunkk == 2) WVSPLITKrc(64, 4, 2) else WVSPLITKrc(64, 4, 1) break;
case 128:
if (chunkk == 2)
WVSPLITKrc(128, 4, 2) else if (chunkk == 1)
WVSPLITKrc(128, 4, 1) break;
if (chunkk == 2) WVSPLITKrc(128, 4, 2) else WVSPLITKrc(128, 4, 1) break;
default:
throw std::runtime_error(
"Unsupported N value: " + std::to_string(M_in) + "," +
......@@ -1903,7 +1822,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
float sB = *s_B;
while (m < M) {
floatx16 sum[N][YTILE] = {};
scalar8 sum[N][YTILE] = {};
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
bigType bigA[N][UNRL] = {};
bigType bigB[YTILE][UNRL];
......@@ -1937,7 +1856,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (uint32_t n = 0; n < N; n++) {
for (int i = 0; i < A_CHUNK; i += 8) {
for (int y = 0; y < YTILE; ++y) {
sum[n][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
sum[n][y] = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
bigA[n][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[n][y], 0, 0,
0);
}
......@@ -1950,31 +1869,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
float accm0 = sum[n][y][0];
float accm16 = sum[n][y][8];
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][1], 0x101, 0xf, 0xf,
1); // row_shl1
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][9], 0x101, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
1); // row_shl2
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][10], 0x102, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf,
1); // row_shl3
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][11], 0x103, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][4], 0x108, 0xf, 0xf,
1); // row_shl8
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][12], 0x108, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][5], 0x109, 0xf, 0xf,
1); // row_shl9
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][13], 0x109, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][6], 0x10a, 0xf, 0xf,
1); // row_shl10
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][14], 0x10a, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][7], 0x10b, 0xf, 0xf,
1); // row_shl11
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][15], 0x10b, 0xf, 0xf, 1);
accm0 += __shfl(accm0, 36);
accm16 += __shfl(accm16, 52);
sum[n][y][0] = accm0 + __shfl(accm16, 16);
accm0 += __shfl_down(accm0, 20);
accm0 += __shfl_down(accm0, 40);
sum[n][y][0] = accm0;
}
}
......@@ -2065,7 +1968,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
float sB = *s_B;
while (m < M) {
floatx16 sum[N][YTILE] = {};
scalar8 sum[N][YTILE] = {};
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
bigType bigA[N][UNRL] = {};
bigType bigB[YTILE][UNRL];
......@@ -2101,7 +2004,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (uint32_t n = 0; n < N; n++) {
for (int i = 0; i < A_CHUNK; i += 8) {
for (int y = 0; y < YTILE; ++y) {
sum[n][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
sum[n][y] = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
bigA[n][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[n][y], 0, 0,
0);
}
......@@ -2114,31 +2017,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
float accm0 = sum[n][y][0];
float accm16 = sum[n][y][8];
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][1], 0x101, 0xf, 0xf,
1); // row_shl1
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][9], 0x101, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
1); // row_shl2
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][10], 0x102, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf,
1); // row_shl3
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][11], 0x103, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][4], 0x108, 0xf, 0xf,
1); // row_shl8
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][12], 0x108, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][5], 0x109, 0xf, 0xf,
1); // row_shl9
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][13], 0x109, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][6], 0x10a, 0xf, 0xf,
1); // row_shl10
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][14], 0x10a, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][7], 0x10b, 0xf, 0xf,
1); // row_shl11
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][15], 0x10b, 0xf, 0xf, 1);
accm0 += __shfl(accm0, 36);
accm16 += __shfl(accm16, 52);
sum[n][y][0] = accm0 + __shfl(accm16, 16);
accm0 += __shfl_down(accm0, 20);
accm0 += __shfl_down(accm0, 40);
sum[n][y][0] = accm0;
}
}
......@@ -2243,16 +2130,16 @@ void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a,
: nullptr;
switch (N_in) {
case 1:
WVSPLITKQ(12, 2, 2, 2, 2, 1)
WVSPLITKQ(16, 2, 2, 2, 2, 1)
break;
case 2:
WVSPLITKQ(12, 2, 2, 2, 2, 2)
WVSPLITKQ(16, 2, 2, 2, 2, 2)
break;
case 3:
WVSPLITKQ(8, 2, 2, 1, 1, 3)
WVSPLITKQ(16, 2, 2, 2, 2, 3)
break;
case 4:
WVSPLITKQ(4, 2, 2, 1, 1, 4)
WVSPLITKQ(16, 2, 2, 2, 2, 4)
break;
default:
throw std::runtime_error(
......
......@@ -590,7 +590,7 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
// The range of logits within the row.
int rowStart = 0;
int seq_len = seqLens[rowIdx / next_n];
int rowEnd = seq_len - next_n + (rowIdx % next_n) + 1;
int rowEnd = max(0, seq_len - next_n + (rowIdx % next_n) + 1);
// Local pointers to this block
if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
......@@ -740,4 +740,4 @@ void top_k_per_row_prefill(const torch::Tensor& logits,
static_cast<int>(stride0), static_cast<int>(stride1),
static_cast<int>(topK), kSortingAlgorithmThreshold);
}
}
}
\ No newline at end of file
......@@ -6,11 +6,11 @@
#include "cutlass_extensions/common.hpp"
bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability) {
// sparse CUTLASS kernels need at least
// sparse CUTLASS kernels need exactly hopper and are not forward compatible
// CUDA 12.2 and SM90 (Hopper)
#if defined CUDA_VERSION
return CUDA_VERSION >= 12020 && cuda_device_capability >= 90;
return CUDA_VERSION >= 12020 && cuda_device_capability == 90;
#endif
return false;
......@@ -98,7 +98,7 @@ std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a) {
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_sparse_compress for a compute capability less than "
"No compiled cutlass_sparse_compress for a compute capability equal to "
"CUDA device capability: ",
version_num);
}
// Portions of this file are adapted from SGLang PR:
// https://github.com/sgl-project/sglang/pull/11194
// and
// https://github.com/sgl-project/sglang/pull/17747
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#ifndef USE_ROCM
#include <cub/cub.cuh>
#else
#include <hipcub/hipcub.hpp>
#endif
namespace vllm {
constexpr int TopK = 2048; // DeepSeek V3 sparse attention top-k
constexpr int kThreadsPerBlock = 1024; // Threads per block
// Shared memory budget
#if defined(USE_ROCM)
constexpr size_t kSmem = 48 * 1024; // ROCm default: 48KB
#else
// Reduced from 128KB to 32KB to improve occupancy.
// Each radix pass needs at most ~TopK candidates in the threshold bin,
// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient.
constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes)
#endif
struct FastTopKParams {
const float* __restrict__ input; // [batch, seq_len] Logits
const int32_t* __restrict__ row_starts; // [batch] Offset into each row
// (optional)
int32_t* __restrict__ indices; // [batch, TopK] Output top-k indices
int32_t* __restrict__ lengths; // [batch] Sequence lengths per row
int64_t input_stride; // Stride between rows
};
__device__ __forceinline__ auto convert_to_uint32_v2(float x) -> uint32_t {
uint32_t bits = __float_as_uint(x);
return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}
__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t {
__half h = __float2half_rn(x);
uint16_t bits = __half_as_ushort(h);
uint16_t key = (bits & 0x8000) ? static_cast<uint16_t>(~bits)
: static_cast<uint16_t>(bits | 0x8000);
return static_cast<uint8_t>(key >> 8);
}
__device__ void naive_topk_cuda(const float* __restrict__ logits,
int32_t* __restrict__ output_indices,
int32_t seq_len) {
const int thread_id = threadIdx.x;
for (int i = thread_id; i < TopK; i += kThreadsPerBlock) {
output_indices[i] = (i < seq_len) ? i : -1;
}
}
// Adapted from:
// https://github.com/sgl-project/sglang/blob/v0.5.8/sgl-kernel/csrc/elementwise/topk.cu#L87
// by: DarkSharpness
// which at the same time is an optimized topk kernel copied from tilelang
// kernel
__device__ void fast_topk_cuda_tl(
const float* __restrict__ logits, // Input logits [seq_len]
int* __restrict__ output_indices, // Output top-k indices [TopK]
int logits_offset, // Starting offset in logits array
int seq_len) // Number of valid logits to process
{
constexpr int RADIX = 256;
constexpr int MAX_BUFFERED_ITEMS = kSmem / (2 * sizeof(int));
alignas(128) __shared__ int shared_histogram[2][RADIX + 128];
alignas(128) __shared__ int shared_output_count;
alignas(128) __shared__ int shared_threshold_bin;
alignas(128) __shared__ int shared_buffered_count[2];
extern __shared__ int buffered_indices[][MAX_BUFFERED_ITEMS];
const int thread_id = threadIdx.x;
int remaining_k = TopK;
// Pass 0: Build coarse 8-bit histogram using FP16 high bits
if (thread_id < RADIX + 1) {
shared_histogram[0][thread_id] = 0;
}
__syncthreads();
for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
const auto bin = convert_to_uint8(logits[idx + logits_offset]);
::atomicAdd(&shared_histogram[0][bin], 1);
}
__syncthreads();
// Helper: Compute cumulative sum (suffix sum) over histogram using ping-pong
// buffers
auto compute_cumulative_sum = [&]() {
static_assert(1 << 8 == RADIX,
"Radix must be 256 for 8 unrolled iterations");
#pragma unroll 8
for (int i = 0; i < 8; ++i) {
if (C10_LIKELY(thread_id < RADIX)) {
const int stride = 1 << i;
const int src_buffer = i & 1;
const int dst_buffer = src_buffer ^ 1;
int value = shared_histogram[src_buffer][thread_id];
if (thread_id < RADIX - stride) {
value += shared_histogram[src_buffer][thread_id + stride];
}
shared_histogram[dst_buffer][thread_id] = value;
}
__syncthreads();
}
};
compute_cumulative_sum();
// Find threshold bin where cumsum crosses remaining_k
if (thread_id < RADIX && shared_histogram[0][thread_id] > remaining_k &&
shared_histogram[0][thread_id + 1] <= remaining_k) {
shared_threshold_bin = thread_id;
shared_buffered_count[0] = 0;
shared_output_count = 0;
}
__syncthreads();
const int threshold_bin = shared_threshold_bin;
remaining_k -= shared_histogram[0][threshold_bin + 1];
// Early exit if threshold bin perfectly matches remaining_k
if (remaining_k == 0) {
for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
const int bin = convert_to_uint8(logits[idx + logits_offset]);
if (bin > threshold_bin) {
const int output_pos = ::atomicAdd(&shared_output_count, 1);
output_indices[output_pos] = idx;
}
}
__syncthreads();
return;
}
// Prepare for refinement passes: Process threshold bin
__syncthreads();
if (thread_id < RADIX + 1) {
shared_histogram[0][thread_id] = 0;
}
__syncthreads();
// Scan all elements and:
// 1. Write indices > threshold_bin to output
// 2. Buffer indices == threshold_bin for refinement
// 3. Build histogram for next refinement pass (fused optimization)
for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
const float logit_value = logits[idx + logits_offset];
const int bin = convert_to_uint8(logit_value);
if (bin > threshold_bin) {
// in top-k, write to output
const int output_pos = ::atomicAdd(&shared_output_count, 1);
output_indices[output_pos] = idx;
} else if (bin == threshold_bin) {
// Candidate for top-k, needs refinement
const int buffer_pos = ::atomicAdd(&shared_buffered_count[0], 1);
if (C10_LIKELY(buffer_pos < MAX_BUFFERED_ITEMS)) {
buffered_indices[0][buffer_pos] = idx;
// Fused: Build histogram for next pass
const uint32_t fp32_bits = convert_to_uint32_v2(logit_value);
const int next_bin = (fp32_bits >> 24) & 0xFF;
::atomicAdd(&shared_histogram[0][next_bin], 1);
}
}
}
__syncthreads();
// ============================================================================
// Passes 1-4: Refine using 8-bit passes over FP32 bits
// ============================================================================
// FP32 bits [31:0] split into 4 bytes processed MSB-first:
// Pass 1: bits [31:24], Pass 2: bits [23:16], Pass 3: bits [15:8], Pass 4:
// bits [7:0]
#pragma unroll 4
for (int pass = 0; pass < 4; ++pass) {
__shared__ int shared_final_k; // For final pass: remaining slots to fill
const int src_buffer = pass % 2;
const int dst_buffer = src_buffer ^ 1;
// Clamp buffered count to prevent overflow
const int raw_buffered = shared_buffered_count[src_buffer];
const int num_buffered =
(raw_buffered < MAX_BUFFERED_ITEMS) ? raw_buffered : MAX_BUFFERED_ITEMS;
compute_cumulative_sum();
// Find threshold bin for this pass
if (thread_id < RADIX && shared_histogram[0][thread_id] > remaining_k &&
shared_histogram[0][thread_id + 1] <= remaining_k) {
shared_threshold_bin = thread_id;
shared_buffered_count[dst_buffer] = 0;
shared_final_k = remaining_k - shared_histogram[0][thread_id + 1];
}
__syncthreads();
const int threshold_bin = shared_threshold_bin;
remaining_k -= shared_histogram[0][threshold_bin + 1];
// Bit offset for this pass: 24, 16, 8, 0
const int bit_offset = 24 - pass * 8;
// Early exit if threshold bin perfectly matches
if (remaining_k == 0) {
for (int i = thread_id; i < num_buffered; i += kThreadsPerBlock) {
const int idx = buffered_indices[src_buffer][i];
const uint32_t fp32_bits =
convert_to_uint32_v2(logits[idx + logits_offset]);
const int bin = (fp32_bits >> bit_offset) & 0xFF;
if (bin > threshold_bin) {
const int output_pos = ::atomicAdd(&shared_output_count, 1);
output_indices[output_pos] = idx;
}
}
__syncthreads();
break;
}
// Continue refinement
__syncthreads();
if (thread_id < RADIX + 1) {
shared_histogram[0][thread_id] = 0;
}
__syncthreads();
for (int i = thread_id; i < num_buffered; i += kThreadsPerBlock) {
const int idx = buffered_indices[src_buffer][i];
const float logit_value = logits[idx + logits_offset];
const uint32_t fp32_bits = convert_to_uint32_v2(logit_value);
const int bin = (fp32_bits >> bit_offset) & 0xFF;
if (bin > threshold_bin) {
// Definitely in top-k
const int output_pos = ::atomicAdd(&shared_output_count, 1);
output_indices[output_pos] = idx;
} else if (bin == threshold_bin) {
if (pass == 3) {
// Final pass (bits [7:0]): No more refinement possible
// Fill remaining slots in reverse order to maintain descending order
const int slot = ::atomicAdd(&shared_final_k, -1);
if (slot > 0) {
output_indices[TopK - slot] = idx;
}
} else {
// Buffer for next pass and build next histogram
const int buffer_pos =
::atomicAdd(&shared_buffered_count[dst_buffer], 1);
if (C10_LIKELY(buffer_pos < MAX_BUFFERED_ITEMS)) {
buffered_indices[dst_buffer][buffer_pos] = idx;
// Fused: Build histogram for next pass
const int next_bit_offset = bit_offset - 8;
const int next_bin = (fp32_bits >> next_bit_offset) & 0xFF;
::atomicAdd(&shared_histogram[0][next_bin], 1);
}
}
}
}
__syncthreads();
}
}
__global__ __launch_bounds__(kThreadsPerBlock) void topk_kernel(
const FastTopKParams params) {
const auto& [input, row_starts, indices, lengths, input_stride] = params;
const uint64_t batch_idx = blockIdx.x;
const int logits_offset = row_starts == nullptr ? 0 : row_starts[batch_idx];
const int seq_len = lengths[batch_idx];
int* output_indices = indices + batch_idx * TopK;
const float* logits = input + batch_idx * input_stride;
if (seq_len <= TopK) {
// Shortcut: All elements are in top-k
return naive_topk_cuda(logits, output_indices, seq_len);
} else {
return fast_topk_cuda_tl(logits, output_indices, logits_offset, seq_len);
}
}
FastTopKParams get_params(
const at::Tensor& score, const at::Tensor& lengths,
std::optional<at::Tensor> row_starts_opt = std::nullopt,
std::optional<at::Tensor> indices_opt = std::nullopt) {
const int64_t batch_size = score.size(0);
TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1,
"score must be 2D with contiguous rows");
TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous() &&
lengths.size(0) == batch_size,
"lengths must be 1D contiguous with size matching batch");
const int32_t* row_starts_ptr = nullptr;
if (row_starts_opt.has_value()) {
const auto& row_starts = *row_starts_opt;
TORCH_CHECK(row_starts.dim() == 1 && row_starts.size(0) == batch_size,
"row_starts must be 1D with size matching batch");
row_starts_ptr = row_starts.data_ptr<int32_t>();
}
int32_t* indices_ptr = nullptr;
if (indices_opt.has_value()) {
const auto& indices = *indices_opt;
TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous() &&
indices.size(0) == batch_size && indices.size(1) == TopK,
"indices must be 2D contiguous [batch, TopK]");
indices_ptr = indices.data_ptr<int32_t>();
}
return FastTopKParams{
.input = score.data_ptr<float>(),
.row_starts = row_starts_ptr,
.indices = indices_ptr,
.lengths = lengths.data_ptr<int32_t>(),
.input_stride = score.stride(0),
};
}
template <auto* kernel_func, size_t smem_bytes>
void setup_kernel_smem_once() {
static const cudaError_t result = []() -> cudaError_t {
#ifdef USE_ROCM
auto func_ptr = reinterpret_cast<const void*>(kernel_func);
#else
auto func_ptr = kernel_func;
#endif
return cudaFuncSetAttribute(
func_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
}();
TORCH_CHECK(
result == cudaSuccess,
"Failed to set kernel shared memory limit: ", cudaGetErrorString(result));
}
} // namespace vllm
void large_context_topk(
const torch::Tensor& logits, torch::Tensor& indices,
const torch::Tensor& seq_lens,
std::optional<torch::Tensor> row_starts = std::nullopt) {
TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor");
TORCH_CHECK(indices.is_cuda(), "indices must be a CUDA tensor");
TORCH_CHECK(seq_lens.is_cuda(), "seq_lens must be a CUDA tensor");
if (row_starts.has_value()) {
TORCH_CHECK(row_starts->is_cuda(), "row_starts must be a CUDA tensor");
}
const auto params = vllm::get_params(logits, seq_lens, row_starts, indices);
const int64_t batch_size = logits.size(0);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const dim3 grid(static_cast<uint32_t>(batch_size));
const dim3 block(vllm::kThreadsPerBlock);
vllm::setup_kernel_smem_once<vllm::topk_kernel, vllm::kSmem>();
vllm::topk_kernel<<<grid, block, vllm::kSmem, stream>>>(params);
const cudaError_t result = cudaGetLastError();
TORCH_CHECK(result == cudaSuccess,
"large_context_topk kernel failed: ", cudaGetErrorString(result));
}
\ No newline at end of file
......@@ -190,6 +190,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int numRows, int stride0, int stride1, int topK) -> ()");
ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
ops.def(
"large_context_topk(Tensor score, Tensor indices, Tensor lengths, "
"Tensor? "
"row_starts_opt) -> ()");
ops.impl("large_context_topk", torch::kCUDA, &large_context_topk);
// Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor.
// ops.def(
......@@ -233,6 +239,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Quantization ops
#ifndef USE_ROCM
// DeepSeek V3 fused A GEMM (SM 9.0+, bf16 only, 1-16 tokens).
ops.def(
"dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
// conditionally compiled so impl registration is in source file
// Quantized GEMM for AWQ.
ops.def(
"awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
......@@ -415,6 +426,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
// conditionally compiled so impl registration is in source file
// Expert-specialization mxfp8 blockscaled grouped quantization (SM100+).
ops.def(
"mxfp8_experts_quant("
" Tensor input, Tensor problem_sizes, Tensor expert_offsets,"
" Tensor blockscale_offsets, Tensor! quant_output, Tensor! scale_factor)"
" -> ()");
// conditionally compiled so impl registration is in source file
// Expert-specialization mxfp8 blockscaled grouped GEMM (SM100+).
ops.def(
"cutlass_mxfp8_grouped_mm("
" Tensor a, Tensor b, Tensor sfa, Tensor sfb, Tensor! out,"
" Tensor problem_sizes, Tensor expert_offsets, Tensor blockscale_offsets)"
" -> ()");
// conditionally compiled so impl registration is in source file
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
ops.def(
......@@ -478,19 +505,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
&get_cutlass_moe_mm_problem_sizes_from_expert_offsets);
// A function that computes data required to run fused MoE with w8a8 grouped
// GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs
// GEMM in batched expert format. It takes expert_num_tokens
// as an input, and computes expert_offsets (token start indices of each
// expert). In addition to this, it computes problem sizes for each expert's
// multiplication used by the two mms called from fused MoE operation.
ops.def(
"get_cutlass_pplx_moe_mm_data(Tensor! expert_offsets, "
"get_cutlass_batched_moe_mm_data(Tensor! expert_offsets, "
" Tensor! problem_sizes1, "
" Tensor! problem_sizes2, "
" Tensor expert_num_tokens, "
" int num_local_experts, int padded_m, "
" int n, int k) -> ()");
ops.impl("get_cutlass_pplx_moe_mm_data", torch::kCUDA,
&get_cutlass_pplx_moe_mm_data);
ops.impl("get_cutlass_batched_moe_mm_data", torch::kCUDA,
&get_cutlass_batched_moe_mm_data);
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
ops.def(
......@@ -537,10 +564,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Compute NVFP4 block quantized tensor.
ops.def(
"scaled_fp4_quant(Tensor! output, Tensor input,"
" Tensor! output_scale, Tensor input_scale, bool "
"is_sf_swizzled_layout) -> ()");
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
"scaled_fp4_quant(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout) -> (Tensor, Tensor)");
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant_func);
// Out variant
// TODO: Add {at::Tag::out_variant} tag and update all call sites
// to use the functional variant once vLLM upgrades PyTorch.
// See pytorch/pytorch#176117.
ops.def(
"scaled_fp4_quant.out(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) "
"-> ()");
ops.impl("scaled_fp4_quant.out", torch::kCUDA, &scaled_fp4_quant_out);
// Compute NVFP4 experts quantization.
ops.def(
......@@ -629,7 +667,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int block_size,"
"Tensor? block_idx_first_scheduled_token,"
"Tensor? block_idx_last_scheduled_token,"
"Tensor? initial_state_idx) -> ()");
"Tensor? initial_state_idx,"
"Tensor? cu_chunk_seqlen,"
"Tensor? last_chunk_indices) -> ()");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
// Hadamard transforms
......@@ -637,11 +677,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#ifndef USE_ROCM
// Compute per-token-group FP8 quantized tensor and scaling factor.
// The dummy arguments are here so we can correctly fuse with RMSNorm.
ops.def(
"per_token_group_fp8_quant(Tensor input, Tensor! output_q, Tensor! "
"output_s, "
"int group_size, float eps, float fp8_min, float fp8_max, bool "
"scale_ue8m0) -> ()");
"scale_ue8m0, bool dummy_is_scale_transposed, bool dummy_is_tma_aligned "
") -> ()");
ops.impl("per_token_group_fp8_quant", torch::kCUDA,
&per_token_group_quant_fp8);
......@@ -771,6 +813,10 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA,
&indexer_k_quant_and_cache);
cache_ops.def(
"concat_mla_q(Tensor ql_nope, Tensor q_pe, Tensor! q_out) -> ()");
cache_ops.impl("concat_mla_q", torch::kCUDA, &concat_mla_q);
cache_ops.def(
"cp_gather_indexer_k_quant_cache(Tensor kv_cache, Tensor! dst_k, Tensor! "
"dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()");
......
......@@ -132,8 +132,10 @@ ENV UV_LINK_MODE=copy
# Verify GCC version
RUN gcc --version
# Ensure CUDA compatibility library is loaded
RUN echo "/usr/local/cuda-$(echo "$CUDA_VERSION" | cut -d. -f1,2)/compat/" > /etc/ld.so.conf.d/cuda-compat.conf && ldconfig
# Enable CUDA forward compatibility by setting '-e VLLM_ENABLE_CUDA_COMPATIBILITY=1'
# Only needed for datacenter/professional GPUs with older drivers.
# See: https://docs.nvidia.com/deploy/cuda-compatibility/
ENV VLLM_ENABLE_CUDA_COMPATIBILITY=0
# ============================================================
# SLOW-CHANGING DEPENDENCIES BELOW
......@@ -260,7 +262,9 @@ RUN --mount=type=cache,target=/root/.cache/uv \
# Build the vLLM wheel
# if USE_SCCACHE is set, use sccache to speed up compilation
# AWS credentials mounted at ~/.aws/credentials for sccache S3 auth (optional)
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=secret,id=aws-credentials,target=/root/.aws/credentials,required=false \
if [ "$USE_SCCACHE" = "1" ]; then \
echo "Installing sccache..." \
&& case "${TARGETPLATFORM}" in \
......@@ -306,7 +310,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
#################### CSRC BUILD IMAGE ####################
#################### EXTENSIONS BUILD IMAGE ####################
# Build DeepGEMM, pplx-kernels, DeepEP - runs in PARALLEL with csrc-build
# Build DeepGEMM, DeepEP - runs in PARALLEL with csrc-build
# This stage is independent and doesn't affect csrc cache
FROM base AS extensions-build
ARG CUDA_VERSION
......@@ -333,10 +337,9 @@ RUN --mount=type=cache,target=/root/.cache/uv \
# Ensure the wheel dir exists so COPY won't fail when DeepGEMM is skipped
RUN mkdir -p /tmp/deepgemm/dist && touch /tmp/deepgemm/dist/.deepgemm_skipped
# Build pplx-kernels and DeepEP wheels
# Build DeepEP wheels
COPY tools/ep_kernels/install_python_libraries.sh /tmp/install_python_libraries.sh
# Defaults moved here from tools/ep_kernels/install_python_libraries.sh for centralized version management
ARG PPLX_COMMIT_HASH=12cecfd
ARG DEEPEP_COMMIT_HASH=73b6ea4
ARG NVSHMEM_VER
RUN --mount=type=cache,target=/root/.cache/uv \
......@@ -345,7 +348,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \
/tmp/install_python_libraries.sh \
--workspace /tmp/ep_kernels_workspace \
--mode wheel \
${PPLX_COMMIT_HASH:+--pplx-ref "$PPLX_COMMIT_HASH"} \
${DEEPEP_COMMIT_HASH:+--deepep-ref "$DEEPEP_COMMIT_HASH"} \
${NVSHMEM_VER:+--nvshmem-ver "$NVSHMEM_VER"} && \
find /tmp/ep_kernels_workspace/nvshmem -name '*.a' -delete
......@@ -560,8 +562,10 @@ ENV UV_HTTP_TIMEOUT=500
ENV UV_INDEX_STRATEGY="unsafe-best-match"
ENV UV_LINK_MODE=copy
# Ensure CUDA compatibility library is loaded
RUN echo "/usr/local/cuda-$(echo "$CUDA_VERSION" | cut -d. -f1,2)/compat/" > /etc/ld.so.conf.d/cuda-compat.conf && ldconfig
# Enable CUDA forward compatibility by setting '-e VLLM_ENABLE_CUDA_COMPATIBILITY=1'
# Only needed for datacenter/professional GPUs with older drivers.
# See: https://docs.nvidia.com/deploy/cuda-compatibility/
ENV VLLM_ENABLE_CUDA_COMPATIBILITY=0
# ============================================================
# SLOW-CHANGING DEPENDENCIES BELOW
......@@ -582,7 +586,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
# This is ~1.1GB and only changes when FlashInfer version bumps
# https://docs.flashinfer.ai/installation.html
# From versions.json: .flashinfer.version
ARG FLASHINFER_VERSION=0.6.3
ARG FLASHINFER_VERSION=0.6.6
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system flashinfer-cubin==${FLASHINFER_VERSION} \
&& uv pip install --system flashinfer-jit-cache==${FLASHINFER_VERSION} \
......@@ -616,7 +620,7 @@ RUN set -eux; \
ARG BITSANDBYTES_VERSION_X86=0.46.1
ARG BITSANDBYTES_VERSION_ARM64=0.42.0
ARG TIMM_VERSION=">=1.0.17"
ARG RUNAI_MODEL_STREAMER_VERSION=">=0.15.3"
ARG RUNAI_MODEL_STREAMER_VERSION=">=0.15.7"
RUN --mount=type=cache,target=/root/.cache/uv \
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
BITSANDBYTES_VERSION="${BITSANDBYTES_VERSION_ARM64}"; \
......@@ -624,7 +628,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
BITSANDBYTES_VERSION="${BITSANDBYTES_VERSION_X86}"; \
fi; \
uv pip install --system accelerate hf_transfer modelscope \
"bitsandbytes>=${BITSANDBYTES_VERSION}" "timm${TIMM_VERSION}" "runai-model-streamer[s3,gcs]${RUNAI_MODEL_STREAMER_VERSION}"
"bitsandbytes>=${BITSANDBYTES_VERSION}" "timm${TIMM_VERSION}" "runai-model-streamer[s3,gcs,azure]${RUNAI_MODEL_STREAMER_VERSION}"
# ============================================================
# VLLM INSTALLATION (depends on build stage)
......@@ -672,7 +676,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
# Pytorch now installs NVSHMEM, setting LD_LIBRARY_PATH
ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
# Install EP kernels wheels (pplx-kernels and DeepEP) that have been built in the `build` stage
# Install EP kernels wheels (DeepEP) that have been built in the `build` stage
RUN --mount=type=bind,from=build,src=/tmp/ep_kernels_workspace/dist,target=/vllm-workspace/ep_kernels/dist \
--mount=type=cache,target=/root/.cache/uv \
uv pip install --system ep_kernels/dist/*.whl --verbose \
......
......@@ -9,17 +9,14 @@
#
# Build targets:
# vllm-openai (default): used for serving deployment
# vllm-openai-zen: vLLM from source + zentorch from PyPI via vllm[zen]
# vllm-test: used for CI tests
# vllm-dev: used for development
#
# Build arguments:
# PYTHON_VERSION=3.13|3.12 (default)|3.11|3.10
# VLLM_CPU_DISABLE_AVX512=false (default)|true
# VLLM_CPU_AVX2=false (default)|true (for cross-compilation)
# VLLM_CPU_AVX512=false (default)|true (for cross-compilation)
# VLLM_CPU_AVX512BF16=false (default)|true (for cross-compilation)
# VLLM_CPU_AVX512VNNI=false (default)|true (for cross-compilation)
# VLLM_CPU_AMXBF16=false (default)|true (for cross-compilation)
# VLLM_CPU_X86=false (default)|true (for cross-compilation)
# VLLM_CPU_ARM_BF16=false (default)|true (for cross-compilation)
#
######################### COMMON BASE IMAGE #########################
......@@ -35,7 +32,7 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
apt-get update -y \
&& apt-get install -y --no-install-recommends sudo ccache git curl wget ca-certificates \
gcc-12 g++-12 libtcmalloc-minimal4 libnuma-dev ffmpeg libsm6 libxext6 libgl1 jq lsof \
gcc-12 g++-12 libtcmalloc-minimal4 libnuma-dev ffmpeg libsm6 libxext6 libgl1 jq lsof make xz-utils \
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 \
&& curl -LsSf https://astral.sh/uv/install.sh | sh
......@@ -90,27 +87,25 @@ ARG max_jobs=32
ENV MAX_JOBS=${max_jobs}
ARG GIT_REPO_CHECK=0
# Support for building with non-AVX512 vLLM: docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" ...
ARG VLLM_CPU_DISABLE_AVX512=0
ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512}
# Support for cross-compilation with AVX2 ISA: docker build --build-arg VLLM_CPU_AVX2="1" ...
ARG VLLM_CPU_AVX2=0
ENV VLLM_CPU_AVX2=${VLLM_CPU_AVX2}
# Support for cross-compilation with AVX512 ISA: docker build --build-arg VLLM_CPU_AVX512="1" ...
ARG VLLM_CPU_AVX512=0
ENV VLLM_CPU_AVX512=${VLLM_CPU_AVX512}
# Support for building with AVX512BF16 ISA: docker build --build-arg VLLM_CPU_AVX512BF16="true" ...
ARG VLLM_CPU_AVX512BF16=0
ENV VLLM_CPU_AVX512BF16=${VLLM_CPU_AVX512BF16}
# Support for building with AVX512VNNI ISA: docker build --build-arg VLLM_CPU_AVX512VNNI="true" ...
ARG VLLM_CPU_AVX512VNNI=0
ENV VLLM_CPU_AVX512VNNI=${VLLM_CPU_AVX512VNNI}
# Support for building with AMXBF16 ISA: docker build --build-arg VLLM_CPU_AMXBF16="true" ...
ARG VLLM_CPU_AMXBF16=1
ENV VLLM_CPU_AMXBF16=${VLLM_CPU_AMXBF16}
# Support for cross-compilation with x86 ISA including AVX2 and AVX512: docker build --build-arg VLLM_CPU_X86="true" ...
ARG VLLM_CPU_X86=0
ENV VLLM_CPU_X86=${VLLM_CPU_X86}
# Support for cross-compilation with ARM BF16 ISA: docker build --build-arg VLLM_CPU_ARM_BF16="true" ...
ARG VLLM_CPU_ARM_BF16=0
ENV VLLM_CPU_ARM_BF16=${VLLM_CPU_ARM_BF16}
WORKDIR /vllm-workspace
# Validate build arguments - prevent mixing incompatible ISA flags
RUN if [ "$TARGETARCH" = "arm64" ] && [ "$VLLM_CPU_X86" != "0" ]; then \
echo "ERROR: Cannot use x86-specific ISA flags (AVX2, AVX512, etc.) when building for ARM64 (--platform=linux/arm64)"; \
exit 1; \
fi && \
if [ "$TARGETARCH" = "amd64" ] && [ "$VLLM_CPU_ARM_BF16" != "0" ]; then \
echo "ERROR: Cannot use ARM-specific ISA flags (ARM_BF16) when building for x86_64 (--platform=linux/amd64)"; \
exit 1; \
fi
# Copy build requirements
COPY requirements/cpu-build.txt requirements/build.txt
......@@ -160,7 +155,7 @@ WORKDIR /vllm-workspace
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
apt-get install -y --no-install-recommends vim numactl xz-utils make clangd-14
apt-get install -y --no-install-recommends vim numactl clangd-14
RUN ln -s /usr/bin/clangd-14 /usr/bin/clangd
......@@ -218,21 +213,29 @@ LABEL org.opencontainers.image.source="https://github.com/vllm-project/vllm"
# Build configuration labels
ARG TARGETARCH
ARG VLLM_CPU_DISABLE_AVX512
ARG VLLM_CPU_AVX2
ARG VLLM_CPU_AVX512
ARG VLLM_CPU_AVX512BF16
ARG VLLM_CPU_AVX512VNNI
ARG VLLM_CPU_AMXBF16
ARG VLLM_CPU_X86
ARG VLLM_CPU_ARM_BF16
ARG PYTHON_VERSION
LABEL ai.vllm.build.target-arch="${TARGETARCH}"
LABEL ai.vllm.build.cpu-disable-avx512="${VLLM_CPU_DISABLE_AVX512:-false}"
LABEL ai.vllm.build.cpu-avx2="${VLLM_CPU_AVX2:-false}"
LABEL ai.vllm.build.cpu-avx512="${VLLM_CPU_AVX512:-false}"
LABEL ai.vllm.build.cpu-avx512bf16="${VLLM_CPU_AVX512BF16:-false}"
LABEL ai.vllm.build.cpu-avx512vnni="${VLLM_CPU_AVX512VNNI:-false}"
LABEL ai.vllm.build.cpu-amxbf16="${VLLM_CPU_AMXBF16:-false}"
LABEL ai.vllm.build.cpu-x86="${VLLM_CPU_X86:-false}"
LABEL ai.vllm.build.cpu-arm-bf16="${VLLM_CPU_ARM_BF16:-false}"
LABEL ai.vllm.build.python-version="${PYTHON_VERSION:-3.12}"
ENTRYPOINT ["vllm", "serve"]
######################### ZEN CPU PYPI IMAGE #########################
FROM vllm-openai AS vllm-openai-zen
ARG TARGETARCH
RUN if [ "$TARGETARCH" != "amd64" ]; then \
echo "ERROR: vllm-openai-amd only supports --platform=linux/amd64"; \
exit 1; \
fi
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install "vllm[zen]"
ENTRYPOINT ["vllm", "serve"]
......@@ -217,13 +217,13 @@ RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2.
# build flashinfer for torch nightly from source around 10 mins
# release version: v0.6.3
# release version: v0.6.6
# todo(elainewy): cache flashinfer build result for faster build
ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \
--mount=type=cache,target=/root/.cache/uv \
echo "git clone flashinfer..." \
&& git clone --depth 1 --branch v0.6.3 --recursive https://github.com/flashinfer-ai/flashinfer.git \
&& git clone --depth 1 --branch v0.6.6 --recursive https://github.com/flashinfer-ai/flashinfer.git \
&& cd flashinfer \
&& git submodule update --init --recursive \
&& echo "finish git clone flashinfer..." \
......
......@@ -184,6 +184,34 @@ RUN cd /opt/rixl && mkdir -p /app/install && \
--ucx-plugins-dir ${UCX_HOME}/lib/ucx \
--nixl-plugins-dir ${RIXL_HOME}/lib/x86_64-linux-gnu/plugins
# DeepEP build stage
FROM base AS build_deep
ARG ROCSHMEM_BRANCH="ba0bf0f3"
ARG ROCSHMEM_REPO="https://github.com/ROCm/rocm-systems.git"
ARG DEEPEP_BRANCH="e84464ec"
ARG DEEPEP_REPO="https://github.com/ROCm/DeepEP.git"
ARG DEEPEP_NIC="cx7"
ENV ROCSHMEM_DIR=/opt/rocshmem
RUN git clone ${ROCSHMEM_REPO} \
&& cd rocm-systems \
&& git checkout ${ROCSHMEM_BRANCH} \
&& mkdir -p projects/rocshmem/build \
&& cd projects/rocshmem/build \
&& cmake .. \
-DCMAKE_INSTALL_PREFIX="${ROCSHMEM_DIR}" \
-DROCM_PATH=/opt/rocm \
-DCMAKE_POSITION_INDEPENDENT_CODE=ON \
-DUSE_EXTERNAL_MPI=OFF \
&& make -j \
&& make install
# Build DeepEP wheel.
# DeepEP looks for rocshmem at ROCSHMEM_DIR.
RUN git clone ${DEEPEP_REPO} \
&& cd DeepEP \
&& git checkout ${DEEPEP_BRANCH} \
&& python3 setup.py --variant rocm --nic ${DEEPEP_NIC} bdist_wheel --dist-dir=/app/deep_install
# -----------------------
# vLLM wheel release build stage (for building distributable wheels)
......@@ -305,6 +333,19 @@ RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
RUN --mount=type=bind,from=build_rixl,src=/app/install,target=/rixl_install \
uv pip install --system /rixl_install/*.whl
# Install DeepEP wheel
RUN --mount=type=bind,from=build_deep,src=/app/deep_install,target=/deep_install \
uv pip install --system /deep_install/*.whl
COPY --from=build_deep /opt/rocshmem /opt/rocshmem
# RIXL/MoRIIO runtime dependencies (RDMA userspace libraries)
RUN apt-get update -q -y && apt-get install -q -y \
librdmacm1 \
libibverbs1 \
ibverbs-providers \
ibverbs-utils \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /vllm-workspace
ARG COMMON_WORKDIR
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace
......@@ -330,6 +371,11 @@ RUN bash /tmp/install_torchcodec.sh \
# Copy in the v1 package (for python-only install test group)
COPY --from=export_vllm /vllm_v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1
# Set MIOPEN ENVS to resolve performance regressions in MIOpen 3D convolution kernel
# See: https://github.com/pytorch/pytorch/issues/169857
ENV MIOPEN_DEBUG_CONV_DIRECT=0
ENV MIOPEN_DEBUG_CONV_GEMM=0
# Source code is used in the `python_only_compile.sh` test
# We hide it inside `src/` so that this source code
# will not be imported by other tests
......
......@@ -9,7 +9,7 @@ ARG PYTORCH_AUDIO_BRANCH="v2.9.0"
ARG PYTORCH_AUDIO_REPO="https://github.com/pytorch/audio.git"
ARG FA_BRANCH="0e60e394"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="6af8b687"
ARG AITER_BRANCH="v0.1.10.post2"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
ARG MORI_BRANCH="2d02c6a9"
ARG MORI_REPO="https://github.com/ROCm/mori.git"
......@@ -239,7 +239,7 @@ RUN pip install pyyaml && cd aiter \
export HIP_CLANG_PATH=/opt/sccache-wrappers \
&& sccache --show-stats; \
fi \
&& PREBUILD_KERNELS=1 GPU_ARCHS=${AITER_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist \
&& GPU_ARCHS=${AITER_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist \
&& if [ "$USE_SCCACHE" = "1" ]; then sccache --show-stats; fi \
&& ls /app/aiter/dist/*.whl
RUN mkdir -p /app/install && cp /app/aiter/dist/*.whl /app/install
......
......@@ -6,8 +6,7 @@ ARG PYTHON_VERSION=3.12
ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/xpu"
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \
echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list && \
add-apt-repository -y ppa:kobuk-team/intel-graphics
echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list
RUN apt clean && apt-get update -y && \
apt-get install -y --no-install-recommends --fix-missing \
......@@ -28,9 +27,22 @@ RUN apt clean && apt-get update -y && \
python3-pip
RUN apt update && apt upgrade -y && \
apt install -y libze1 libze-dev libze-intel-gpu1 intel-opencl-icd libze-intel-gpu-raytracing intel-ocloc && \
apt install -y intel-oneapi-compiler-dpcpp-cpp-2025.3
# Install UMD
RUN mkdir neo && \
cd neo && \
wget https://github.com/intel/intel-graphics-compiler/releases/download/v2.24.8/intel-igc-core-2_2.24.8+20344_amd64.deb && \
wget https://github.com/intel/intel-graphics-compiler/releases/download/v2.24.8/intel-igc-opencl-2_2.24.8+20344_amd64.deb && \
wget https://github.com/intel/compute-runtime/releases/download/25.48.36300.8/intel-ocloc_25.48.36300.8-0_amd64.deb && \
wget https://github.com/intel/compute-runtime/releases/download/25.48.36300.8/intel-opencl-icd_25.48.36300.8-0_amd64.deb && \
wget https://github.com/intel/compute-runtime/releases/download/25.48.36300.8/libigdgmm12_22.8.2_amd64.deb && \
wget https://github.com/intel/compute-runtime/releases/download/25.48.36300.8/libze-intel-gpu1_25.48.36300.8-0_amd64.deb && \
wget https://github.com/oneapi-src/level-zero/releases/download/v1.26.0/level-zero_1.26.0+u24.04_amd64.deb && \
dpkg -i *.deb && \
cd .. && \
rm -rf neo
ENV PATH="/root/.local/bin:$PATH"
ENV VIRTUAL_ENV="/opt/venv"
ENV UV_PYTHON_INSTALL_DIR=/opt/uv/python
......@@ -103,9 +115,57 @@ RUN --mount=type=cache,target=/root/.cache/uv \
# install development dependencies (for testing)
RUN uv pip install -e tests/vllm_test_utils
# install nixl from source code
ENV NIXL_VERSION=0.7.0
RUN python /workspace/vllm/tools/install_nixl_from_source_ubuntu.py
# install NIXL and UCX from source code
ARG UCX_VERSION=e5d98879705239d254ede40b4a52891850cb5349
ARG NIXL_VERSION=0.7.0
RUN apt-get update && apt-get install -y \
pciutils \
net-tools \
iproute2 \
hwloc \
numactl \
wget \
curl \
git \
build-essential \
autoconf \
automake \
libtool \
pkg-config \
rdma-core \
libibverbs-dev \
ibverbs-utils \
libibverbs1 \
librdmacm-dev \
librdmacm1 \
libibumad-dev \
libibumad3 \
libibmad-dev \
libibmad5 \
infiniband-diags \
perftest \
ibutils \
libmlx5-1 \
libmlx4-1 \
ibverbs-providers \
librdmacm1t64
ENV PKG_CONFIG_PATH=/tmp/ucx_install/lib/pkgconfig:${PKG_CONFIG_PATH}
ENV LD_LIBRARY_PATH=/tmp/ucx_install/lib:${LD_LIBRARY_PATH}
RUN --mount=type=cache,target=/root/.cache/uv \
git clone https://github.com/openucx/ucx /tmp/ucx_source && \
cd /tmp/ucx_source && git checkout "${UCX_VERSION}" && \
bash autogen.sh && \
./configure --prefix=/tmp/ucx_install --with-ze=yes --enable-examples --enable-mt && \
make CFLAGS="-Wno-error=incompatible-pointer-types" -j8 && make install && \
git clone https://github.com/ai-dynamo/nixl /tmp/nixl_source && \
cd /tmp/nixl_source && git checkout "${NIXL_VERSION}" && \
cd /tmp/nixl_source && \
uv pip install --upgrade meson pybind11 patchelf && \
uv pip install -r requirements.txt && \
uv pip install . && \
rm -rf /tmp/ucx_source /tmp/nixl_source
# FIX triton
RUN --mount=type=cache,target=/root/.cache/uv \
......
......@@ -52,9 +52,6 @@
"DEEPGEMM_GIT_REF": {
"default": "477618cd51baffca09c4b0b87e97c03fe827ef03"
},
"PPLX_COMMIT_HASH": {
"default": "12cecfd"
},
"DEEPEP_COMMIT_HASH": {
"default": "73b6ea4"
},
......@@ -68,7 +65,7 @@
"default": "true"
},
"FLASHINFER_VERSION": {
"default": "0.6.3"
"default": "0.6.6"
},
"GDRCOPY_CUDA_VERSION": {
"default": "12.8"
......@@ -86,7 +83,7 @@
"default": ">=1.0.17"
},
"RUNAI_MODEL_STREAMER_VERSION": {
"default": ">=0.15.3"
"default": ">=0.15.7"
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment