Commit dcb5624a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.5' into v0.8.5-dev

parents 55880ca2 ba41cc90
......@@ -269,6 +269,12 @@ void advance_step_flashinfer(
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
// void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
// torch::Tensor const& q_pe,
// torch::Tensor const& kv_c_and_k_pe_cache,
// torch::Tensor const& seq_lens,
// torch::Tensor const& page_table, double scale);
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
#ifndef USE_ROCM
......
......@@ -46,14 +46,26 @@ __global__ void compute_expert_offsets(
}
__global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
const int32_t* __restrict__ expert_offsets,
int32_t* input_permutation,
int32_t* output_permutation,
int32_t* atomic_buffer, const int topk_length,
const int topk) {
int expert_id = blockIdx.x;
int const blk_expert_id = blockIdx.x;
int const num_experts = gridDim.x;
int32_t const num_tokens = expert_offsets[num_experts];
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
if (topk_ids[i] == expert_id) {
int const expert_id = topk_ids[i];
if (expert_id == -1 && blockIdx.x == 0) {
// output_permutation is used to re-order the moe outputs. It is
// used as c2 = c2[c_map], where c2 is a torch.tensor that is the
// output of the cutlass kernels and c_map is the output_permutation.
// c2 is initialized to zeros, therefore by setting the output_permutation
// to num_tokens, we are guaranteed to fill the moe outputs to zero
// for "invalid" topk_ids.
output_permutation[i] = num_tokens;
} else if (expert_id == blk_expert_id) {
int start = atomicAdd(&atomic_buffer[expert_id], 1);
input_permutation[start] = i / topk;
output_permutation[i] = start;
......@@ -83,6 +95,7 @@ void get_cutlass_moe_mm_data_caller(
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<const int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(input_permutation.data_ptr()),
static_cast<int32_t*>(output_permutation.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
......
......@@ -336,7 +336,7 @@ inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
if (mp2 <= 16) {
// M in [1, 16]
......
......@@ -321,7 +321,7 @@ inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
if (mp2 <= 16) {
// M in [1, 16]
......
......@@ -134,7 +134,7 @@ typename T::Gemm::Arguments args_from_options(
using StrideB = typename T::StrideB;
using StrideD = typename T::StrideD;
using Sm100BlkScaledConfig =
typename T::Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
int m = static_cast<int>(M);
int n = static_cast<int>(N);
......
......@@ -9,7 +9,11 @@
#include <cuda_runtime.h>
#include <iostream>
namespace marlin {
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif
namespace MARLIN_NAMESPACE_NAME {
// Marlin params
......@@ -23,6 +27,7 @@ static constexpr int pipe_stages =
static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;
static constexpr int max_thread_n = 256;
static constexpr int tile_size = 16;
static constexpr int max_par = 16;
......@@ -84,4 +89,4 @@ __device__ inline void cp_async_wait() {
#endif
} // namespace marlin
} // namespace MARLIN_NAMESPACE_NAME
......@@ -5,7 +5,11 @@
#include <cuda_fp16.h>
#include <cuda_bf16.h>
namespace marlin {
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif
namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t>
class ScalarType {};
......@@ -54,7 +58,7 @@ class ScalarType<nv_bfloat16> {
using FragS = Vec<nv_bfloat162, 1>;
using FragZP = Vec<nv_bfloat162, 4>;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
static __device__ float inline num2float(const nv_bfloat16 x) {
return __bfloat162float(x);
}
......@@ -74,6 +78,6 @@ class ScalarType<nv_bfloat16> {
#endif
};
} // namespace marlin
} // namespace MARLIN_NAMESPACE_NAME
#endif
......@@ -2,6 +2,15 @@
#include <torch/all.h>
torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
const int64_t rows_per_block);
torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
const int64_t CuCount);
void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount);
void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor& max_logits, torch::Tensor& tmp_out,
torch::Tensor& query, torch::Tensor& key_cache,
......
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <stdexcept>
#include <algorithm>
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/fp8/common.cuh"
#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__))
#define __HIP__MI300_MI250__
#endif
#if defined(__HIPCC__) && defined(__gfx942__)
#define __HIP__MI300__
#endif
#if defined(NDEBUG)
#undef NDEBUG
#include <assert.h>
#define UNREACHABLE_CODE assert(false);
#define NDEBUG
#else
#define UNREACHABLE_CODE assert(false);
#endif
template <typename T>
struct scalar {};
template <typename T>
struct scalar2 {};
template <typename T>
__device__ __forceinline__ float2 __s22float2(T v);
template <typename T>
__device__ __forceinline__ T __float2s(float v);
template <typename T>
__device__ __forceinline__ T __float22s2_rn(float2 v);
// Definitions and cvt functions for fp16
template <>
struct scalar<c10::Half> {
using type = half;
};
template <>
struct scalar2<c10::Half> {
using type = __half2;
};
template <>
__device__ __forceinline__ half __float2s(float v) {
return __float2half(v);
}
template <>
__device__ __forceinline__ float2 __s22float2(__half2 v) {
return __half22float2(v);
}
template <>
__device__ __forceinline__ __half2 __float22s2_rn(float2 v) {
return __float22half2_rn(v);
}
// Definitions and cvt functions for bf16
template <>
struct scalar<c10::BFloat16> {
using type = __hip_bfloat16;
};
template <>
struct scalar2<c10::BFloat16> {
using type = __hip_bfloat162;
};
template <>
__device__ __forceinline__ __hip_bfloat16 __float2s(float v) {
return __float2bfloat16(v);
}
template <>
__device__ __forceinline__ float2 __s22float2(__hip_bfloat162 v) {
return __bfloat1622float2(v);
}
template <>
__device__ __forceinline__ __hip_bfloat162 __float22s2_rn(float2 v) {
return __float22bfloat162_rn(v);
}
template <typename T>
__device__ __forceinline__ T loadnt(T* addr) {
return __builtin_nontemporal_load(addr);
}
__device__ __forceinline__ float4 load_ntmprl(const float4* addr) {
auto addr_alias = reinterpret_cast<const float*>(addr);
auto dat0 = loadnt(addr_alias);
auto dat1 = loadnt(addr_alias + 1);
auto dat2 = loadnt(addr_alias + 2);
auto dat3 = loadnt(addr_alias + 3);
return make_float4(dat0, dat1, dat2, dat3);
}
// TBlock fetches entire rows of A, and entire col of B (K dimension); assume
// N=1 for time being grid is M/A_NUM_ROWS blocks
template <typename scalar_t, int NUM_A_ROWS_PER_BLOCK>
__global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b,
scalar_t* out_c, const int K) {
using scalar2_t = typename scalar2<scalar_t>::type;
auto af4 = reinterpret_cast<const float4*>(in_a);
auto bf4 = reinterpret_cast<const scalar2_t*>(in_b);
auto c = reinterpret_cast<scalar2_t*>(out_c);
__shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE];
const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK * K / 8;
const int threadid = threadIdx.x;
const int warp = threadIdx.x / WARP_SIZE;
const int lane = threadIdx.x % WARP_SIZE;
const int num_warps = blockDim.x / WARP_SIZE;
const int qwarpid = threadid / num_warps;
const int qthreadid = threadid % num_warps;
float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK];
scalar2_t colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w;
float acc[NUM_A_ROWS_PER_BLOCK];
scalar2_t acch2;
scalar2_t oval;
// As we later use warp shuffle operations, we may have more threads in the
// block than the actual available data, hence the if guard here.
if (threadid * 8 < K) {
#pragma unroll
for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) {
// rowA_elem4[i] holds 8 * half numbers seen as a single float4.
rowA_elem4[i] = load_ntmprl(&af4[row_addr + threadid + K / 8 * i]);
}
}
colB_elem4x = bf4[threadid * 4 + 0];
colB_elem4y = bf4[threadid * 4 + 1];
colB_elem4z = bf4[threadid * 4 + 2];
colB_elem4w = bf4[threadid * 4 + 3];
scalar2_t Af2;
[[maybe_unused]] scalar2_t Bf2;
float2 S;
auto Ah2ptr = reinterpret_cast<scalar2_t*>(&rowA_elem4);
scalar2_t* ah2lptr;
#pragma unroll
for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) {
// Multiply-add on 8 scalar_t.
ah2lptr = Ah2ptr + i * 4;
Af2 = *(ah2lptr);
acch2 = __hmul2(Af2, colB_elem4x);
Af2 = *(ah2lptr + 1);
acch2 = __hfma2(Af2, colB_elem4y, acch2);
Af2 = *(ah2lptr + 2);
acch2 = __hfma2(Af2, colB_elem4z, acch2);
Af2 = *(ah2lptr + 3);
acch2 = __hfma2(Af2, colB_elem4w, acch2);
S = __s22float2(acch2);
// See comment above concerning the if guard.
acc[i] = (threadid * 8 < K ? S.x + S.y : 0.f);
}
// all reduce across warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
#pragma unroll
for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) {
acc[i] += __shfl_xor(acc[i], mask);
}
}
// Warp leaders store the data to shared memory.
if (lane < NUM_A_ROWS_PER_BLOCK) {
red_smem[lane][warp] = acc[lane];
}
// Make sure the data is in shared memory.
__syncthreads();
if (qwarpid < NUM_A_ROWS_PER_BLOCK) {
acc[qwarpid] = qthreadid < num_warps ? red_smem[qwarpid][qthreadid] : 0.f;
for (int mask = num_warps / 2; mask >= 1; mask /= 2) {
acc[qwarpid] += __shfl_xor(acc[qwarpid], mask);
}
float oval2 = __shfl_xor(acc[qwarpid], num_warps);
if (lane % (num_warps * 2) == 0) {
oval = __float22s2_rn<scalar2_t>(make_float2(acc[qwarpid], oval2));
c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = oval;
}
}
}
torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
const int64_t rows_per_block) {
auto M = in_a.size(0);
auto K = in_a.size(1);
auto N = in_b.size(0);
TORCH_CHECK(N == 1, "Row number of activation tensor must be 1.");
TORCH_CHECK(in_a.dtype() == in_b.dtype());
TORCH_CHECK(in_b.dtype() == torch::kFloat16 ||
in_b.dtype() == torch::kBFloat16);
auto out_c = torch::empty(
{N, M}, torch::TensorOptions().dtype(in_b.dtype()).device(in_b.device()));
// NUM_TREADS need to be a multiple of WARP_SIZE, as we are using warp shuffle
// operations.
const int NUM_THREADS =
K * 2 / 16 % WARP_SIZE == 0
? K * 2 / 16
: K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE);
int NUM_BLOCKS = M / rows_per_block;
const at::cuda::OptionalCUDAGuard device_guard(device_of(in_b));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// call the kernel function...
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "LLGemm1", [&] {
auto a_ptr = in_a.data_ptr<scalar_t>();
auto b_ptr = in_b.data_ptr<scalar_t>();
auto c_ptr = out_c.data_ptr<scalar_t>();
if (rows_per_block == 2) {
LLGemm1_kernel<scalar_t, 2>
<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(a_ptr, b_ptr, c_ptr, K);
} else if (rows_per_block == 4) {
LLGemm1_kernel<scalar_t, 4>
<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(a_ptr, b_ptr, c_ptr, K);
} else if (rows_per_block == 8) {
LLGemm1_kernel<scalar_t, 8>
<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(a_ptr, b_ptr, c_ptr, K);
} else if (rows_per_block == 16) {
LLGemm1_kernel<scalar_t, 16>
<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(a_ptr, b_ptr, c_ptr, K);
} else {
NUM_BLOCKS = M / 4;
LLGemm1_kernel<scalar_t, 4>
<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(a_ptr, b_ptr, c_ptr, K);
}
});
return out_c;
}
#define DOT2C(V0, V2, V3) \
if constexpr (std::is_same_v<scalar_t, half>) { \
asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(V0) : "0"(V0), "v"(V2), "v"(V3)); \
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) { \
float2 s = __bfloat1622float2(*((__hip_bfloat162*)(&(V2)))) * \
__bfloat1622float2(*((__hip_bfloat162*)(&(V3)))); \
V0 += (s.x + s.y); \
}
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
// This version targets cases where A[] fits LDS capacity
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
union bigType {
scalar_t h[A_CHUNK];
float f[A_CHUNK / 2];
float2 f2[A_CHUNK / 4];
double d[A_CHUNK / 4];
scalar8 h8;
};
//----------------------------------------------------
// Reserving 64 KB of LDS to have 1 WG / CU
// Goal is to bring the activation matrix A to the LDS
// and use it across the lifetime of the work group
// TODO: When activation matrix is larger than 64 KB
// then this is not goint to work!
//----------------------------------------------------
__shared__ scalar_t s[1024 * 32];
//----------------------------------------------------
// Fetch the activation matrix to LDS
// Loop iteration:
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements
// - Each WG will fetch 512 * 16 => 8K elements
// - Then the WG will move to another 8 K elements
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for (uint32_t k = 0; k < min(K * N, 32 * 1024);
k += THRDS * WvPrGrp * A_CHUNK) {
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
if (k_in >= min(K * N, 32 * 1024)) break;
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
}
__syncthreads();
if (threadIdx.y >= _WvPrGrp) return;
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
float sum[N][YTILE];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
// There are 16 waves per WG, and hence, each WG is
// working on 16 columns of weight matrix. Moreover,
// we tile in column direction by YTILE, so when YTILE=1
// the above math is right, however, when YTILE=2 then
// each wave will be working on 2 columns and WG will
// be working on 32 columns.
//
// Top level loop that makes WGs persistent!
// - WGs iterates across columns of weight matrix
// - Each wave within WG works on a given column(s)
// - After completing first set of columns, WGs start
// working on the next set of available columns
//----------------------------------------------------
while (m < M) {
//----------------------------------------------------
// 'sum' accumulates the matrix A x B computation
// split across 64 lanes.
//
// YTILE represents how many column of weight matrix
// are being worked on by each wave.
//----------------------------------------------------
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++) sum[n][i] = 0;
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
//----------------------------------------------------
// Fetch weight matrix B in interleaved K-split!
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements (1024B)
// - YTILE represents the number of column being serviced
// by wave
// - Loop for fetching weight matrix (B) are unrolled
//
// Fetch activation matrix A from LDS
// - Loop for fetching activation matrix (A) are unrolled
//
// Finally, do the matrix multiplication in an unrolled
// fashion. This provides lot of food for compiler
// scheduling.
//
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
// for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
// Fetch the weight matrix from memory!
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
const scalar_t* B_ = &B[(m + 0) * K + k_];
bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K])));
//----------------------------------------------------
// The following code with YTILE > 1 has to be deleted
//----------------------------------------------------
if constexpr (YTILE >= 2)
bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K])));
if constexpr (YTILE >= 3)
bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K])));
if constexpr (YTILE >= 4)
bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K])));
if constexpr (YTILE >= 5)
bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K])));
if constexpr (YTILE >= 6)
bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K])));
if constexpr (YTILE >= 7)
bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K])));
if constexpr (YTILE >= 8)
bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
// Fetch A activation matrix in interleaved fashion from LDS or memory
for (int n = 0; n < N; n++) {
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for (uint32_t n = 0; n < N; n++) {
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b])
//----------------------------------------------------
// The following code with YTILE > 1
//----------------------------------------------------
if constexpr (YTILE >= 2) {
DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]);
}
if constexpr (YTILE >= 3) {
DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]);
}
if constexpr (YTILE >= 4) {
DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]);
}
if constexpr (YTILE >= 5) {
DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]);
}
if constexpr (YTILE >= 6) {
DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]);
}
if constexpr (YTILE >= 7) {
DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]);
}
if constexpr (YTILE >= 8) {
DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]);
}
}
}
}
}
//----------------------------------------------------
// Final reduction step using shuffle
//----------------------------------------------------
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
}
}
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]);
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
}
}
m += CuCount * _WvPrGrp * YTILE;
}
}
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
// This version targets cases where A[] marginally exceeds LDS capacity
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
union bigType {
scalar_t h[A_CHUNK];
float f[A_CHUNK / 2];
float2 f2[A_CHUNK / 4];
double d[A_CHUNK / 4];
scalar8 h8;
};
//----------------------------------------------------
// Reserving 64 KB of LDS to have 1 WG / CU
// Goal is to bring the activation matrix A to the LDS
// and use it across the lifetime of the work group
// TODO: When activation matrix is larger than 64 KB
// then this is not goint to work!
//----------------------------------------------------
__shared__ scalar_t s[1024 * 32];
//----------------------------------------------------
// Computation of columns that need to be committed to memory!
//----------------------------------------------------
uint32_t commitColumn[YTILE];
for (uint32_t i = 0; i < YTILE; i++) {
commitColumn[i] = 1;
}
//----------------------------------------------------
// Indexing function into the column of weight matrix B
// Algorithm does 64 lane k-splitting / wave and uses
// WG ID and Thread ID to find the index.
//----------------------------------------------------
// int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp);
uint32_t m = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE;
// Check whether there will be fragmenation!
// This will happen only for the last wave!
if (m < M && (m + YTILE) >= M) {
uint32_t startColumn = M - YTILE;
for (uint32_t i = 0; i < (m - startColumn); i++) {
commitColumn[i] = 0;
}
m = startColumn;
}
//----------------------------------------------------
// Fetch the activation matrix to LDS
// Loop iteration:
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements
// - Each WG will fetch 512 * 16 => 8K elements
// - Then the WG will move to another 8 K elements
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for (uint32_t k = 0; k < min(K * N, 32 * 1024);
k += THRDS * WvPrGrp * A_CHUNK) {
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
if (k_in >= min(K * N, 32 * 1024)) break;
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
}
__syncthreads();
if (threadIdx.y >= _WvPrGrp) return;
float sum[N][YTILE];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
// There are 16 waves per WG, and hence, each WG is
// working on 16 columns of weight matrix. Moreover,
// we tile in column direction by YTILE, so when YTILE=1
// the above math is right, however, when YTILE=2 then
// each wave will be working on 2 columns and WG will
// be working on 32 columns.
//
// Top level loop that makes WGs persistent!
// - WGs iterates across columns of weight matrix
// - Each wave within WG works on a given column(s)
// - After completing first set of columns, WGs start
// working on the next set of available columns
//----------------------------------------------------
while (m < M) {
//----------------------------------------------------
// 'sum' accumulates the matrix A x B computation
// split across 64 lanes.
//
// YTILE represents how many column of weight matrix
// are being worked on by each wave.
//----------------------------------------------------
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++) sum[n][i] = 0;
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
//----------------------------------------------------
// Fetch weight matrix B in interleaved K-split!
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements (1024B)
// - YTILE represents the number of column being serviced
// by wave
// - Loop for fetching weight matrix (B) are unrolled
//
// Fetch activation matrix A from LDS
// - Loop for fetching activation matrix (A) are unrolled
//
// Finally, do the matrix multiplication in an unrolled
// fashion. This provides lot of food for compiler
// scheduling.
//
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
// Fetch the weight matrix from memory!
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
const scalar_t* B_ = &B[(m + 0) * K + k_];
bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K])));
//----------------------------------------------------
// The following code with YTILE > 1 has to be deleted
//----------------------------------------------------
if constexpr (YTILE >= 2)
bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K])));
if constexpr (YTILE >= 3)
bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K])));
if constexpr (YTILE >= 4)
bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K])));
if constexpr (YTILE >= 5)
bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K])));
if constexpr (YTILE >= 6)
bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K])));
if constexpr (YTILE >= 7)
bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K])));
if constexpr (YTILE >= 8)
bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
// Fetch A activation matrix in interleaved fashion from LDS or memory
for (int n = 0; n < N; n++) {
if (k_ + K * n < 32 * 1024)
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
else
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for (uint32_t n = 0; n < N; n++) {
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]);
//----------------------------------------------------
// The following code with YTILE > 1
//----------------------------------------------------
if constexpr (YTILE >= 2) {
DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]);
}
if constexpr (YTILE >= 3) {
DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]);
}
if constexpr (YTILE >= 4) {
DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]);
}
if constexpr (YTILE >= 5) {
DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]);
}
if constexpr (YTILE >= 6) {
DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]);
}
if constexpr (YTILE >= 7) {
DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]);
}
if constexpr (YTILE >= 8) {
DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]);
}
}
}
}
}
//----------------------------------------------------
// Final reduction step using shuffle
//----------------------------------------------------
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
}
}
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i])
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
}
}
m += CuCount * _WvPrGrp * YTILE;
// Check whether there will be fragmenation!
// This will happen only for the last wave!
if (m < M && (m + YTILE) >= M) {
uint32_t startColumn = M - YTILE;
for (uint32_t i = 0; i < (m - startColumn); i++) {
commitColumn[i] = 0;
}
m = startColumn;
}
}
}
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
// This version targets big A[] cases, where it is much larger than LDS capacity
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
union bigType {
scalar_t h[A_CHUNK];
float f[A_CHUNK / 2];
float2 f2[A_CHUNK / 4];
double d[A_CHUNK / 4];
scalar8 h8;
};
//----------------------------------------------------
// Reserving 64 KB of LDS to have 1 WG / CU
// Goal is to bring the activation matrix A to the LDS
// and use it across the lifetime of the work group
// TODO: When activation matrix is larger than 64 KB
// then this is not goint to work!
//----------------------------------------------------
__shared__ scalar_t s[1024 * 32];
//----------------------------------------------------
// Computation of columns that need to be committed to memory!
//----------------------------------------------------
uint32_t commitColumn[YTILE];
for (uint32_t i = 0; i < YTILE; i++) {
commitColumn[i] = 1;
}
// int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp);
if (threadIdx.y >= _WvPrGrp) return;
//----------------------------------------------------
// Indexing function into the column of weight matrix B
// Algorithm does 64 lane k-splitting / wave and uses
// WG ID and Thread ID to find the index.
//----------------------------------------------------
uint32_t m = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE;
// Check whether there will be fragmenation!
// This will happen only for the last wave!
if (m < M && (m + YTILE) >= M) {
uint32_t startColumn = M - YTILE;
for (uint32_t i = 0; i < (m - startColumn); i++) {
commitColumn[i] = 0;
}
m = startColumn;
}
//----------------------------------------------------
// Fetch the activation matrix to LDS
// Loop iteration:
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements
// - Each WG will fetch 512 * 16 => 8K elements
// - Then the WG will move to another 8 K elements
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
#define PCML
#ifndef PCML
for (uint32_t k = 0; k < min(K * N, 32 * 1024);
k += THRDS * WvPrGrp * A_CHUNK) {
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
if (k_in >= min(K * N, 32 * 1024)) break;
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
}
__syncthreads();
#endif
#define TUC (THRDS * UNRL * A_CHUNK)
uint32_t kBase = 0;
// find biggest k size that fits in LDS
uint32_t kFit = (32 * 1024) / N;
// kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple
// of TUC
kFit = (kFit % TUC == 0)
? kFit
: (kFit - kFit % TUC); // round up to multiple of TUC
// if (kFit == 0) kFit = TUC;
kFit = min(kFit, K);
float sum[N][YTILE];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
// There are 16 waves per WG, and hence, each WG is
// working on 16 columns of weight matrix. Moreover,
// we tile in column direction by YTILE, so when YTILE=1
// the above math is right, however, when YTILE=2 then
// each wave will be working on 2 columns and WG will
// be working on 32 columns.
//
// Top level loop that makes WGs persistent!
// - WGs iterates across columns of weight matrix
// - Each wave within WG works on a given column(s)
// - After completing first set of columns, WGs start
// working on the next set of available columns
//----------------------------------------------------
#ifdef PCML
int YW = (YTILE * _WvPrGrp);
uint32_t Mrndp = (M % YW == 0) ? M : (M - M % YW + YW);
while (m < Mrndp) {
#else
while (m < M) {
#endif
//----------------------------------------------------
// 'sum' accumulates the matrix A x B computation
// split across 64 lanes.
//
// YTILE represents how many column of weight matrix
// are being worked on by each wave.
//----------------------------------------------------
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++) sum[n][i] = 0;
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
//----------------------------------------------------
// Fetch weight matrix B in interleaved K-split!
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements (1024B)
// - YTILE represents the number of column being serviced
// by wave
// - Loop for fetching weight matrix (B) are unrolled
//
// Fetch activation matrix A from LDS
// - Loop for fetching activation matrix (A) are unrolled
//
// Finally, do the matrix multiplication in an unrolled
// fashion. This provides lot of food for compiler
// scheduling.
//
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
#ifdef PCML
if ((k1 == 0) || (k1 == kBase + kFit)) { // load next chunk of A[] to LDS
if (k1 != 0) kBase += kFit;
__syncthreads();
for (uint32_t k = 0; k < kFit; k += THRDS * _WvPrGrp * A_CHUNK) {
uint32_t kOff = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
if (kBase + kOff >= K) break;
if (kOff >= kFit) break;
for (uint32_t n = 0; n < N; n++) {
uint32_t k_in = kBase + n * K + kOff;
uint32_t k_ot = n * kFit + kOff;
*((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in]));
}
}
__syncthreads();
}
if (m >= M) continue;
#endif
// Fetch the weight matrix from memory!
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
const scalar_t* B_ = &B[(m + 0) * K + k_];
bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K])));
//----------------------------------------------------
// The following code with YTILE > 1 has to be deleted
//----------------------------------------------------
if constexpr (YTILE >= 2)
bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K])));
if constexpr (YTILE >= 3)
bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K])));
if constexpr (YTILE >= 4)
bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K])));
if constexpr (YTILE >= 5)
bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K])));
if constexpr (YTILE >= 6)
bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K])));
if constexpr (YTILE >= 7)
bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K])));
if constexpr (YTILE >= 8)
bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
// Fetch A activation matrix in interleaved fashion from LDS or memory
for (int n = 0; n < N; n++) {
#ifdef PCML
bigA[n][k2] = *((const bigType*)(&(s[k_ - kBase + kFit * n])));
#else
if (k_ + K * n < 32 * 1024)
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
else
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
#endif
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
#pragma unroll
for (uint32_t n = 0; n < N; n++) {
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]);
//----------------------------------------------------
// The following code with YTILE > 1
//----------------------------------------------------
if constexpr (YTILE >= 2) {
DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]);
}
if constexpr (YTILE >= 3) {
DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]);
}
if constexpr (YTILE >= 4) {
DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]);
}
if constexpr (YTILE >= 5) {
DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]);
}
if constexpr (YTILE >= 6) {
DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]);
}
if constexpr (YTILE >= 7) {
DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]);
}
if constexpr (YTILE >= 8) {
DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]);
}
}
}
}
}
#ifdef PCML
if (m >= M) {
m += CuCount * _WvPrGrp * YTILE;
kBase = 0;
continue;
}
#endif
//----------------------------------------------------
// Final reduction step using shuffle
//----------------------------------------------------
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
}
}
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i])
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
}
}
m += CuCount * _WvPrGrp * YTILE;
kBase = 0;
// Check whether there will be fragmenation!
// This will happen only for the last wave!
if (m < M && (m + YTILE) >= M) {
uint32_t startColumn = M - YTILE;
for (uint32_t i = 0; i < (m - startColumn); i++) {
commitColumn[i] = 0;
}
m = startColumn;
}
}
}
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
int mindiv(int N, int div1, int div2) {
int nPrRnd = div1 * div2;
int rnds0 = N / nPrRnd;
nPrRnd -= div1 * 3;
int rnds3 = N / nPrRnd;
nPrRnd -= div1;
int rnds4 = N / nPrRnd;
nPrRnd -= div1;
int rnds5 = N / nPrRnd;
nPrRnd -= div1;
int rnds6 = N / nPrRnd;
nPrRnd -= div1;
int rnds7 = N / nPrRnd;
nPrRnd -= div1;
int rnds8 = N / nPrRnd;
nPrRnd -= div1;
int rnds9 = N / nPrRnd;
nPrRnd -= div1;
int rtn = div2;
if (rnds0 == rnds3) rtn = div2 - 3;
if (rnds0 == rnds4) rtn = div2 - 4;
if (rnds0 == rnds5) rtn = div2 - 5;
if (rnds0 == rnds6) rtn = div2 - 6;
if (rnds0 == rnds7) rtn = div2 - 7;
if (rnds0 == rnds8) rtn = div2 - 8;
if (rnds0 == rnds9) rtn = div2 - 9;
return rtn;
}
torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
const int64_t CuCount) {
auto M_in = in_a.size(0);
auto K_in = in_a.size(1);
auto N_in = in_b.size(0);
TORCH_CHECK(in_a.dtype() == in_b.dtype());
TORCH_CHECK(K_in % 8 == 0, "k % 8 == 0");
TORCH_CHECK(in_a.dtype() == torch::kFloat16 ||
in_a.dtype() == torch::kBFloat16);
auto out_c = torch::empty(
{N_in, M_in},
torch::TensorOptions().dtype(in_b.dtype()).device(in_b.device()));
dim3 grid(CuCount);
const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
_N) \
{ \
dim3 block(64, _WvPrGrp); \
if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
CuCount); \
} else if (K_in * N_in <= 32 * 1024 * 1.2) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
CuCount); \
} else { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \
wvSplitK_hf_big_<fptype, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
CuCount); \
} \
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] {
using fptype = typename scalar<scalar_t>::type;
fptype* af4 = reinterpret_cast<fptype*>(in_a.data_ptr());
const fptype* bf4 = reinterpret_cast<const fptype*>(in_b.data_ptr());
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
switch (N_in) {
case 1:
WVSPLITK(16, 2, 2, 2, 2, 2, 2, 1)
break;
case 2:
WVSPLITK(16, 2, 2, 2, 2, 2, 2, 2)
break;
case 3:
WVSPLITK(16, 4, 7, 7, 1, 1, 1, 3)
break;
case 4:
WVSPLITK(16, 4, 7, 7, 1, 1, 1, 4)
break;
default:
throw std::runtime_error(
"Unsupported N value: " + std::to_string(M_in) + "," +
std::to_string(K_in) + "," + std::to_string(N_in));
}
});
return out_c;
}
#if defined(__HIP__MI300__) // TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const fp8_t* B,
const fp8_t* __restrict__ A, scalar_t* C,
const float* __restrict__ s_A,
const float* __restrict__ s_B, const int _WvPrGrp,
const int CuCount) {
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float;
using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int;
using intx4 = __attribute__((__vector_size__(4 * sizeof(int)))) int;
union bigType {
char f8[A_CHUNK];
char2 c2[A_CHUNK / 2];
scalar_t h[A_CHUNK / 2];
float f[A_CHUNK / 4];
int i[A_CHUNK / 4];
long l[A_CHUNK / 8];
intx4 l2[A_CHUNK / 16];
scalar8 h8;
};
__shared__ fp8_t s[1024 * 64];
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) {
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
}
__syncthreads();
if (threadIdx.y >= _WvPrGrp) return;
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
floatx16 sum[N][YTILE];
float sA = *s_A;
float sB = *s_B;
while (m < M) {
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++) sum[n][i] = {0.f};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
#pragma unroll
for (uint32_t n = 0; n < N; ++n) bigA[n][k2].h8 = {0.f};
#pragma unroll
for (uint32_t y = 0; y < YTILE; ++y) bigB[y][k2].h8 = {0.f};
}
// Fetch the weight matrix from memory!
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
const fp8_t* B_ = &B[(m + 0) * Kp + k_];
#pragma unroll
for (uint32_t y = 0; y < YTILE; ++y) {
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * Kp])));
}
}
// Fetch activation matrix from either just LDS or from both LDS / memory
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
for (int n = 0; n < N; n++) {
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
if (k >= K) break;
for (uint32_t n = 0; n < N; n++) {
for (int i = 0; i < A_CHUNK; i += 8) {
for (int y = 0; y < YTILE; ++y) {
sum[n][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bigA[n][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[n][y], 0, 0,
0);
}
}
}
}
}
// Final reduction
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
float accm0 = sum[n][y][0];
float accm16 = sum[n][y][8];
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][1]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][9]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][2]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][10]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][3]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][11]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][4]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][12]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][5]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][13]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][6]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][14]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][7]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][15]), "v"(accm16));
accm0 += __shfl(accm0, 36);
accm16 += __shfl(accm16, 52);
sum[n][y][0] = accm0 + __shfl(accm16, 16);
}
}
if (threadIdx.x == 0) {
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0] * sA * sB);
}
}
}
m += CuCount * _WvPrGrp * YTILE;
}
}
#else // !defined(__HIP__MI300__) TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
const fp8_t* B, const fp8_t* __restrict__ A,
scalar_t* C, const float* __restrict__ s_A,
const float* __restrict__ s_B,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300__) TODO: Add NAVI support
#if defined(__HIP__MI300__) // TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitKQ_hf_(const int K, const int Kp, const int M, const fp8_t* B,
const fp8_t* __restrict__ A, scalar_t* C,
const float* __restrict__ s_A, const float* __restrict__ s_B,
const int _WvPrGrp, const int CuCount) {
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float;
using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int;
using intx4 = __attribute__((__vector_size__(4 * sizeof(int)))) int;
union bigType {
char f8[A_CHUNK];
char2 c2[A_CHUNK / 2];
scalar_t h[A_CHUNK / 2];
float f[A_CHUNK / 4];
int i[A_CHUNK / 4];
long l[A_CHUNK / 8];
intx4 l2[A_CHUNK / 16];
scalar8 h8;
};
__shared__ fp8_t s[1024 * 64];
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) {
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
}
__syncthreads();
if (threadIdx.y >= _WvPrGrp) return;
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
floatx16 sum[N][YTILE];
float sA = *s_A;
float sB = *s_B;
while (m < M) {
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++) sum[n][i] = {0};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
// Fetch the weight matrix from memory!
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
const fp8_t* B_ = &B[(m + 0) * Kp + k_];
for (int y = 0; y < YTILE; ++y) {
if (y + m >= M) break; // To avoid mem access fault.
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * Kp])));
}
}
// Fetch activation matrix from either just LDS or from both LDS / memory
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
for (int n = 0; n < N; n++) {
if (k_ + K * n < 64 * 1024)
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
else
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
for (uint32_t n = 0; n < N; n++) {
for (int i = 0; i < A_CHUNK; i += 8) {
for (int y = 0; y < YTILE; ++y) {
sum[n][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bigA[n][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[n][y], 0, 0,
0);
}
}
}
}
}
// Final reduction
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
float accm0 = sum[n][y][0];
float accm16 = sum[n][y][8];
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][1]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][9]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][2]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][10]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][3]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][11]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][4]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][12]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][5]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][13]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][6]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][14]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][7]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][15]), "v"(accm16));
accm0 += __shfl(accm0, 36);
accm16 += __shfl(accm16, 52);
sum[n][y][0] = accm0 + __shfl(accm16, 16);
}
}
if (threadIdx.x == 0) {
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
if (y + m >= M) break; // To avoid mem access fault.
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0] * sA * sB);
}
}
}
m += CuCount * _WvPrGrp * YTILE;
}
}
#else // !defined(__HIP__MI300__) TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
const fp8_t* B, const fp8_t* __restrict__ A,
scalar_t* C, const float* __restrict__ s_A,
const float* __restrict__ s_B, const int _WvPrGrp,
const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300__) TODO: Add NAVI support
void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
at::Tensor& scale_a, at::Tensor& scale_b,
const int64_t CuCount) {
static c10::ScalarType kFp8Type = is_fp8_ocp()
? c10::ScalarType::Float8_e4m3fn
: c10::ScalarType::Float8_e4m3fnuz;
auto M_in = in_a.size(0);
auto K_in = in_a.size(1);
auto N_in = in_b.size(0);
auto Kp_in = in_a.stride(0);
TORCH_CHECK(K_in % 16 == 0, "k % 16 == 0");
TORCH_CHECK(in_a.dtype() == in_b.dtype() && in_a.dtype() == kFp8Type);
TORCH_CHECK(out_c.dtype() == torch::kFloat16 ||
out_c.dtype() == torch::kBFloat16);
dim3 grid(CuCount);
const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
_N) \
{ \
dim3 block(64, _WvPrGrp); \
if ((K_in * N_in <= 64 * 1024) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \
s_a, s_b, __wvPrGrp, CuCount); \
} else { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
wvSplitKQ_hf_<fptype, fp8_t, 64, _YTILEm, _WvPrGrp, 16, _UNRLm, _N> \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \
s_a, s_b, __wvPrGrp, CuCount); \
} \
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_c.scalar_type(), "wvSplitKQ", [&] {
using fptype = typename scalar<scalar_t>::type;
auto c_ptr = reinterpret_cast<fptype*>(out_c.data_ptr());
auto s_a = scale_a.data_ptr<float>();
auto s_b = scale_b.data_ptr<float>();
VLLM_DISPATCH_FP8_TYPES(in_a.scalar_type(), "wvSplitKQ", [&] {
auto a_ptr = in_a.data_ptr<fp8_t>();
auto b_ptr = in_b.data_ptr<fp8_t>();
switch (N_in) {
case 1:
WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 1)
break;
case 2:
WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 2)
break;
case 3:
WVSPLITKQ(16, 4, 7, 7, 1, 1, 1, 3)
break;
case 4:
WVSPLITKQ(16, 4, 7, 7, 1, 1, 1, 4)
break;
default:
throw std::runtime_error(
"Unsupported N value: " + std::to_string(M_in) + "," +
std::to_string(K_in) + "," + std::to_string(N_in));
}
});
});
}
......@@ -14,6 +14,24 @@
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
// vLLM custom ops for rocm
// Custom gemm op for matrix-vector multiplication
rocm_ops.def(
"LLMM1(Tensor in_a, Tensor in_b, int rows_per_block) -> "
"Tensor");
rocm_ops.impl("LLMM1", torch::kCUDA, &LLMM1);
// Custom gemm op for skinny matrix-matrix multiplication
rocm_ops.def(
"wvSplitK(Tensor in_a, Tensor in_b, int CuCount) -> "
"Tensor");
rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK);
// wvSplitK for fp8
rocm_ops.def(
"wvSplitKQ(Tensor in_a, Tensor in_b, Tensor! out_c, Tensor scale_a, "
" Tensor scale_b, int CuCount) -> ()");
rocm_ops.impl("wvSplitKQ", torch::kCUDA, &wvSplitKQ);
// Custom attention op
// Compute the attention between an input query and the cached
// keys/values using PagedAttention.
......
......@@ -294,6 +294,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
") -> ()");
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
// Compute MLA decode using cutlass.
// ops.def(
// "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
// " Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
// " Tensor page_table, float scale) -> ()");
// ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
......
......@@ -162,6 +162,9 @@ ENV UV_HTTP_TIMEOUT=500
COPY requirements/lint.txt requirements/lint.txt
COPY requirements/test.txt requirements/test.txt
COPY requirements/dev.txt requirements/dev.txt
# Workaround for #17068
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system mamba-ssm==2.2.4 --no-build-isolation
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/dev.txt
#################### DEV IMAGE ####################
......@@ -240,6 +243,8 @@ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
uv pip install --system https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post2/flashinfer_python-0.2.1.post2+cu124torch2.6-cp38-abi3-linux_x86_64.whl ; \
fi
COPY examples examples
COPY benchmarks benchmarks
COPY ./vllm/collect_env.py .
# Although we build Flashinfer with AOT mode, there's still
# some issues w.r.t. JIT compilation. Therefore we need to
......@@ -263,6 +268,9 @@ ADD . /vllm-workspace/
ENV UV_HTTP_TIMEOUT=500
# install development dependencies (for testing)
# Workaround for #17068
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system mamba-ssm==2.2.4 --no-build-isolation
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/dev.txt
......@@ -289,6 +297,7 @@ RUN mv vllm test_docs/
#################### OPENAI API SERVER ####################
# base openai image with additional requirements, for any subsequent openai-style images
FROM vllm-base AS vllm-openai-base
ARG TARGETPLATFORM
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
......
......@@ -121,6 +121,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
ADD ./tests/ ./tests/
ADD ./examples/ ./examples/
ADD ./benchmarks/ ./benchmarks/
ADD ./vllm/collect_env.py .
# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
......
# The vLLM Dockerfile is used to construct vLLM image against torch nightly that can be directly used for testing
# for torch nightly, cuda >=12.6 is required,
# use 12.8 due to FlashAttention issue with cuda 12.6 (https://github.com/vllm-project/vllm/issues/15435#issuecomment-2775924628)
ARG CUDA_VERSION=12.8.0
#
#################### BASE BUILD IMAGE ####################
# prepare basic build environment
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base
ARG CUDA_VERSION=12.8.0
ARG PYTHON_VERSION=3.12
ARG TARGETPLATFORM
ENV DEBIAN_FRONTEND=noninteractive
# Install Python and other dependencies
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
&& apt-get update -y \
&& apt-get install -y ccache software-properties-common git curl sudo \
&& add-apt-repository ppa:deadsnakes/ppa \
&& apt-get update -y \
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
&& python3 --version \
&& python3 -m pip --version
# Install uv for faster pip installs
RUN --mount=type=cache,target=/root/.cache/uv \
python3 -m pip install uv
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500
# Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519
# as it was causing spam when compiling the CUTLASS kernels
RUN apt-get install -y gcc-10 g++-10
RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 110 --slave /usr/bin/g++ g++ /usr/bin/g++-10
RUN <<EOF
gcc --version
EOF
# Workaround for https://github.com/openai/triton/issues/2507 and
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
# this won't be needed for future versions of this docker image
# or future versions of triton.
RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
WORKDIR /workspace
# install build and runtime dependencies
COPY requirements/common.txt requirements/common.txt
COPY use_existing_torch.py use_existing_torch.py
COPY pyproject.toml pyproject.toml
# install build and runtime dependencies without stable torch version
RUN python3 use_existing_torch.py
# install torch nightly
ARG PINNED_TORCH_VERSION
RUN --mount=type=cache,target=/root/.cache/uv \
if [ -n "$PINNED_TORCH_VERSION" ]; then \
pkgs="$PINNED_TORCH_VERSION"; \
else \
pkgs="torch torchaudio torchvision"; \
fi && \
uv pip install --system $pkgs --index-url https://download.pytorch.org/whl/nightly/cu128
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system numba==0.61.2
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/common.txt
# must put before installing xformers, so it can install the correct version of xfomrers.
ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
# Build xformers with cuda and torch nightly
# following official xformers guidance: https://github.com/facebookresearch/xformers#build
# todo(elainewy): cache xformers build result for faster build
ARG max_jobs=16
ENV MAX_JOBS=${max_jobs}
ARG XFORMERS_COMMIT=f2de641ef670510cadab099ce6954031f52f191c
ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \
--mount=type=cache,target=/root/.cache/uv \
echo 'git clone xformers...' \
&& git clone https://github.com/facebookresearch/xformers.git --recursive \
&& cd xformers \
&& git checkout ${XFORMERS_COMMIT} \
&& git submodule update --init --recursive \
&& echo 'finish git clone xformers...' \
&& rm -rf build \
&& python3 setup.py bdist_wheel --dist-dir=../xformers-dist --verbose \
&& cd .. \
&& rm -rf xformers
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system xformers-dist/*.whl --verbose
# build can take a long time, and the torch nightly version fetched from url can be different in next docker stage.
# track the nightly torch version used in the build, when we set up runtime environment we can make sure the version is the same
RUN uv pip freeze | grep -i '^torch\|^torchvision\|^torchaudio' > torch_build_versions.txt
RUN cat torch_build_versions.txt
# cuda arch list used by torch
# can be useful for `test`
# explicitly set the list to avoid issues with torch 2.2
# see https://github.com/pytorch/pytorch/pull/123243
# Override the arch list for flash-attn to reduce the binary size
ARG vllm_fa_cmake_gpu_arches='80-real;90-real'
ENV VLLM_FA_CMAKE_GPU_ARCHES=${vllm_fa_cmake_gpu_arches}
#################### BASE BUILD IMAGE ####################
#################### WHEEL BUILD IMAGE ####################
FROM base AS build
ARG TARGETPLATFORM
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500
COPY . .
RUN python3 use_existing_torch.py
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/build.txt
ARG GIT_REPO_CHECK=0
RUN --mount=type=bind,source=.git,target=.git \
if [ "$GIT_REPO_CHECK" != "0" ]; then bash tools/check_repo.sh ; fi
# Max jobs used by Ninja to build extensions
ARG max_jobs=16
ENV MAX_JOBS=${max_jobs}
ARG nvcc_threads=2
ENV NVCC_THREADS=$nvcc_threads
ARG USE_SCCACHE
ARG SCCACHE_BUCKET_NAME=vllm-build-sccache
ARG SCCACHE_REGION_NAME=us-west-2
ARG SCCACHE_S3_NO_CREDENTIALS=0
# if USE_SCCACHE is set, use sccache to speed up compilation
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=.git,target=.git \
if [ "$USE_SCCACHE" = "1" ]; then \
echo "Installing sccache..." \
&& curl -L -o sccache.tar.gz https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz \
&& tar -xzf sccache.tar.gz \
&& sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \
&& rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \
&& export SCCACHE_BUCKET=${SCCACHE_BUCKET_NAME} \
&& export SCCACHE_REGION=${SCCACHE_REGION_NAME} \
&& export SCCACHE_S3_NO_CREDENTIALS=${SCCACHE_S3_NO_CREDENTIALS} \
&& export SCCACHE_IDLE_TIMEOUT=0 \
&& export CMAKE_BUILD_TYPE=Release \
&& sccache --show-stats \
&& python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 \
&& sccache --show-stats; \
fi
ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \
--mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=.git,target=.git \
if [ "$USE_SCCACHE" != "1" ]; then \
# Clean any existing CMake artifacts
rm -rf .deps && \
mkdir -p .deps && \
python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \
fi
#################### WHEEL BUILD IMAGE ####################
################### VLLM INSTALLED IMAGE ####################
# Setup clean environment for vLLM and its dependencies for test and api server using ubuntu22.04 with AOT flashinfer
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS vllm-base
# prepare for environment starts
ARG CUDA_VERSION=12.8.0
ARG PYTHON_VERSION=3.12
WORKDIR /vllm-workspace
ENV DEBIAN_FRONTEND=noninteractive
ARG TARGETPLATFORM
RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \
echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment
# Install Python and other dependencies
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
&& apt-get update -y \
&& apt-get install -y ccache software-properties-common git curl wget sudo vim python3-pip \
&& apt-get install -y ffmpeg libsm6 libxext6 libgl1 \
&& add-apt-repository ppa:deadsnakes/ppa \
&& apt-get update -y \
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
&& python3 --version && python3 -m pip --version
RUN --mount=type=cache,target=/root/.cache/uv \
python3 -m pip install uv
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500
# Workaround for https://github.com/openai/triton/issues/2507 and
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
# this won't be needed for future versions of this docker image
# or future versions of triton.
RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
# get the nightly torch version used in the build to make sure the version is the same
COPY --from=base /workspace/torch_build_versions.txt ./torch_build_versions.txt
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system $(cat torch_build_versions.txt | xargs) --index-url https://download.pytorch.org/whl/nightly/cu128
# install the vllm wheel
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/vllm-dist \
--mount=type=cache,target=/root/.cache/uv \
uv pip install --system vllm-dist/*.whl --verbose
# install xformers again for the new environment
RUN --mount=type=bind,from=base,src=/workspace/xformers-dist,target=/vllm-workspace/xformers-dist \
--mount=type=cache,target=/root/.cache/uv \
uv pip install --system /vllm-workspace/xformers-dist/*.whl --verbose
ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0'
# install package for build flashinfer
# see issue: https://github.com/flashinfer-ai/flashinfer/issues/738
RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2.post1
# build flashinfer for torch nightly from source around 10 mins
# release version: v0.2.2.post1
# todo(elainewy): cache flashinfer build result for faster build
ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \
--mount=type=cache,target=/root/.cache/uv \
echo "git clone flashinfer..." \
&& git clone --recursive https://github.com/flashinfer-ai/flashinfer.git \
&& cd flashinfer \
&& git checkout v0.2.2.post1 \
&& git submodule update --init --recursive \
&& echo "finish git clone flashinfer..." \
&& rm -rf build \
&& export TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} \
&& FLASHINFER_ENABLE_AOT=1 python3 setup.py bdist_wheel --dist-dir=../flashinfer-dist --verbose \
&& cd .. \
&& rm -rf flashinfer
# install flashinfer
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system flashinfer-dist/*.whl --verbose
# install common packages
COPY requirements/common.txt requirements/common.txt
COPY use_existing_torch.py use_existing_torch.py
COPY pyproject.toml pyproject.toml
COPY examples examples
COPY benchmarks benchmarks
COPY ./vllm/collect_env.py .
RUN python3 use_existing_torch.py
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/common.txt
################### VLLM INSTALLED IMAGE ####################
#################### UNITTEST IMAGE #############################
FROM vllm-base as test
COPY tests/ tests/
# install build and runtime dependencies without stable torch version
COPY requirements/nightly_torch_test.txt requirements/nightly_torch_test.txt
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500
# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -e tests/vllm_test_utils
# enable fast downloads from hf (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system hf_transfer
ENV HF_HUB_ENABLE_HF_TRANSFER 1
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/nightly_torch_test.txt
#################### UNITTEST IMAGE #############################
......@@ -126,13 +126,16 @@ RUN --mount=type=cache,target=/root/.cache/uv \
FROM base-builder AS cv-builder
ARG MAX_JOBS
ARG OPENCV_VERSION=84
ARG OPENCV_VERSION=86
# patch for version 4.11.0.86
ARG OPENCV_PATCH=97f3f39
ARG ENABLE_HEADLESS=1
RUN --mount=type=cache,target=/root/.cache/uv \
source /opt/rh/gcc-toolset-13/enable && \
git clone --recursive https://github.com/opencv/opencv-python.git -b ${OPENCV_VERSION} && \
cd opencv-python && \
sed -i 's/"setuptools==59.2.0",/"setuptools<70.0",/g' pyproject.toml && \
sed -i -E -e 's/"setuptools.+",/"setuptools",/g' pyproject.toml && \
cd opencv && git cherry-pick --no-commit $OPENCV_PATCH && cd .. && \
python -m build --wheel --installer=uv --outdir /opencvwheels/
###############################################################
......@@ -148,9 +151,15 @@ COPY --from=arrow-builder /tmp/control /dev/null
COPY --from=cv-builder /tmp/control /dev/null
ARG VLLM_TARGET_DEVICE=cpu
ARG GRPC_PYTHON_BUILD_SYSTEM_OPENSSL=1
# this step installs vllm and populates uv cache
# with all the transitive dependencies
RUN --mount=type=cache,target=/root/.cache/uv \
source /opt/rh/gcc-toolset-13/enable && \
git clone https://github.com/huggingface/xet-core.git && cd xet-core/hf_xet/ && \
uv pip install maturin && \
uv build --wheel --out-dir /hf_wheels/
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,from=torch-builder,source=/torchwheels/,target=/torchwheels/,ro \
--mount=type=bind,from=arrow-builder,source=/arrowwheels/,target=/arrowwheels/,ro \
......@@ -159,7 +168,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
source /opt/rh/gcc-toolset-13/enable && \
uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl && \
sed -i -e 's/.*torch.*//g' /src/pyproject.toml /src/requirements/*.txt && \
uv pip install pandas pythran pybind11 && \
uv pip install pandas pythran pybind11 /hf_wheels/*.whl && \
# sentencepiece.pc is in some pkgconfig inside uv cache
export PKG_CONFIG_PATH=$(find / -type d -name "pkgconfig" 2>/dev/null | tr '\n' ':') && \
uv pip install -r /src/requirements/common.txt -r /src/requirements/cpu.txt -r /src/requirements/build.txt --no-build-isolation && \
......@@ -247,8 +256,9 @@ RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,from=torch-builder,source=/torchwheels/,target=/torchwheels/,ro \
--mount=type=bind,from=arrow-builder,source=/arrowwheels/,target=/arrowwheels/,ro \
--mount=type=bind,from=cv-builder,source=/opencvwheels/,target=/opencvwheels/,ro \
--mount=type=bind,from=vllmcache-builder,source=/hf_wheels/,target=/hf_wheels/,ro \
--mount=type=bind,from=vllmcache-builder,source=/vllmwheel/,target=/vllmwheel/,ro \
HOME=/root uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl /vllmwheel/*.whl
HOME=/root uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl /hf_wheels/*.whl /vllmwheel/*.whl
COPY ./ /workspace/vllm
WORKDIR /workspace/vllm
......
......@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="1a7f4dfa"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="8970b25b"
ARG AITER_BRANCH="7e1ed08"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
FROM ${BASE_IMAGE} AS base
......
......@@ -58,7 +58,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
cd ../../python && \
export PYARROW_PARALLEL=4 && \
export ARROW_BUILD_TYPE=release && \
uv pip install -r requirements/build.txt && \
uv pip install -r requirements-build.txt && \
python setup.py build_ext --build-type=$ARROW_BUILD_TYPE --bundle-arrow-cpp bdist_wheel
FROM python-install AS numa-build
......@@ -96,6 +96,22 @@ RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install -v torch==${TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/nightly/cpu && \
python setup.py bdist_wheel
FROM python-install AS hf-xet-builder
# Install hf-xet
WORKDIR /tmp
ENV CARGO_HOME=/root/.cargo
ENV RUSTUP_HOME=/root/.rustup
ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH"
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,from=rust,source=/root/.cargo,target=/root/.cargo,rw \
--mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \
git clone https://github.com/huggingface/xet-core.git && \
cd xet-core/hf_xet/ && \
uv pip install maturin patchelf && \
python -m maturin build --release --out dist && \
mkdir -p /tmp/hf-xet/dist && \
cp dist/*.whl /tmp/hf-xet/dist/
# Final build stage
FROM python-install AS vllm-cpu
ARG PYTHON_VERSION
......@@ -120,12 +136,15 @@ RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \
--mount=type=bind,from=pyarrow,source=/tmp/arrow/python/dist,target=/tmp/arrow-wheels \
--mount=type=bind,from=torch-vision,source=/tmp/vision/dist,target=/tmp/vision-wheels/ \
--mount=type=bind,from=hf-xet-builder,source=/tmp/hf-xet/dist,target=/tmp/hf-xet-wheels/ \
sed -i '/^torch/d' requirements/build.txt && \
ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl | head -n 1) && \
VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl | head -n 1) && \
HF_XET_WHL_FILE=$(ls /tmp/hf-xet-wheels/*.whl | head -n 1) && \
uv pip install -v \
$ARROW_WHL_FILE \
$VISION_WHL_FILE \
$HF_XET_WHL_FILE \
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
--index-strategy unsafe-best-match \
-r requirements/build.txt \
......@@ -149,4 +168,5 @@ USER 2000
WORKDIR /home/vllm
# Set the default entrypoint
ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"]
\ No newline at end of file
ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment