Commit 8d75f22e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori

parents ce888aa4 7d80c73d
...@@ -1241,33 +1241,16 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx, ...@@ -1241,33 +1241,16 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
} }
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support #endif // defined(__HIP__GFX9__) TODO: Add NAVI support
// Find the min val of div2 that doesn't increase N/(div1*div2)
int mindiv(int N, int div1, int div2) { int mindiv(int N, int div1, int div2) {
int nPrRnd = div1 * div2; int nPrRnd = div1 * div2;
int rnds0 = N / nPrRnd; int rnds[13];
nPrRnd -= div1 * 3; for (int i = 0; i < 13; i++) {
int rnds3 = N / nPrRnd; rnds[i] = (N + nPrRnd - 1) / nPrRnd;
nPrRnd -= div1; nPrRnd -= div1;
int rnds4 = N / nPrRnd; }
nPrRnd -= div1; for (int i = 12; i >= 0; i--)
int rnds5 = N / nPrRnd; if (rnds[0] == rnds[i]) return (div2 - i);
nPrRnd -= div1;
int rnds6 = N / nPrRnd;
nPrRnd -= div1;
int rnds7 = N / nPrRnd;
nPrRnd -= div1;
int rnds8 = N / nPrRnd;
nPrRnd -= div1;
int rnds9 = N / nPrRnd;
nPrRnd -= div1;
int rtn = div2;
if (rnds0 == rnds3) rtn = div2 - 3;
if (rnds0 == rnds4) rtn = div2 - 4;
if (rnds0 == rnds5) rtn = div2 - 5;
if (rnds0 == rnds6) rtn = div2 - 6;
if (rnds0 == rnds7) rtn = div2 - 7;
if (rnds0 == rnds8) rtn = div2 - 8;
if (rnds0 == rnds9) rtn = div2 - 9;
return rtn;
} }
torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
...@@ -1300,26 +1283,37 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, ...@@ -1300,26 +1283,37 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int max_lds_len = get_lds_size() / 2; const int max_lds_len = get_lds_size() / 2;
#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ #define WVSPLITK(_YTILE, _UNRL, _N) \
_N) \ { \
{ \ dim3 block(64, 16); \
dim3 block(64, _WvPrGrp); \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \ if ((K_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ biasf4, c, __wvPrGrp, CuCount); \
biasf4, c, __wvPrGrp, CuCount); \ else if (K_in * N_in <= max_lds_len * 1.2) \
} else if (K_in * N_in <= max_lds_len * 1.2) { \ wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ biasf4, c, __wvPrGrp, CuCount); \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ else \
biasf4, c, __wvPrGrp, CuCount); \ wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
} else { \ <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \ biasf4, c, __wvPrGrp, CuCount); \
wvSplitK_hf_big_<fptype, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \ }
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \ #define WVSPLIT_TILE(_sYT, __N) \
} \ { \
bool fit_lds = (K_in * N_in <= max_lds_len); \
if (_sYT <= 1) \
WVSPLITK(1, 4, __N) \
else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \
WVSPLITK(2, 2, __N) \
else if (_sYT <= 4 * 3) \
WVSPLITK(3, 2, __N) \
else if (__N == 4) \
WVSPLITK(4, 1, __N) \
else \
WVSPLITK(4, 2, __N) \
} }
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] { AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] {
...@@ -1331,18 +1325,23 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, ...@@ -1331,18 +1325,23 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
? reinterpret_cast<const fptype*>(in_bias->data_ptr()) ? reinterpret_cast<const fptype*>(in_bias->data_ptr())
: nullptr; : nullptr;
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr()); fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
// first shoot for biggest tile-size that keeps all simd busy,
// then cut the active waves to balance their distribution...
int sYT = (M_in + CuCount * 4 - 1) / (CuCount * 4);
switch (N_in) { switch (N_in) {
case 1: case 1:
WVSPLITK(16, 2, 2, 2, 2, 2, 2, 1) WVSPLIT_TILE(sYT, 1)
break; break;
case 2: case 2:
WVSPLITK(16, 2, 2, 2, 2, 2, 2, 2) WVSPLIT_TILE(sYT, 2)
break; break;
case 3: case 3:
WVSPLITK(16, 4, 7, 7, 1, 1, 1, 3) WVSPLIT_TILE(sYT, 3)
break; break;
case 4: case 4:
WVSPLITK(16, 4, 7, 7, 1, 1, 1, 4) WVSPLIT_TILE(sYT, 4)
break; break;
default: default:
throw std::runtime_error( throw std::runtime_error(
......
...@@ -44,33 +44,293 @@ __global__ void apply_repetition_penalties_kernel( ...@@ -44,33 +44,293 @@ __global__ void apply_repetition_penalties_kernel(
} }
} }
static inline __device__ uint16_t extractBinIdx(float x) { __device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t {
union { uint32_t bits = __float_as_uint(x);
__half h; return (bits & 0x80000000) ? bits : ~bits & 0x7fffffff;
uint16_t u16;
} tmp;
tmp.h = __float2half_rn(x);
tmp.u16 = (x < 0.f) ? (~tmp.u16 & 0xffff) : (tmp.u16 | 0x8000);
return 511 - (tmp.u16 >> 7);
} }
template <int kNumThreadsPerBlock = 512, int kNumBins = 512, int kTopK = 2048> template <int step>
__device__ void topKPerRowJob(const float* logits, const int rowStart, static inline __device__ uint32_t extractBinIdx(float x) {
const int rowEnd, const int rowIdx, if constexpr (step == 0) {
int* outIndices, int stride0, int stride1) { __half hx = __float2half(x);
// The number of elements per thread for the final top-k sort. uint16_t bits = __half_as_ushort(hx);
static constexpr int kNumTopKItemsPerThread = kTopK / kNumThreadsPerBlock; bits = (bits & 0x8000) ? bits : ~bits & 0x7fff;
// The class to sort the elements during the final top-k sort. return bits >> 5;
#ifdef USE_ROCM } else {
using TopKSort = hipcub::BlockRadixSort<float, kNumThreadsPerBlock, uint32_t bits = __float_as_uint(x);
kNumTopKItemsPerThread, int>; bits = (bits & 0x80000000) ? bits : ~bits & 0x7fffffff;
#else
using TopKSort = cub::BlockRadixSort<float, kNumThreadsPerBlock, if constexpr (step == 1) {
kNumTopKItemsPerThread, int>; return bits >> 21;
#endif } else if constexpr (step == 2) {
return (bits >> 10) & 0x7ff;
} else if constexpr (step == 3) {
return bits & 0x3ff;
}
}
}
template <int shift>
static inline __device__ bool isPartialMatch(float x, uint32_t pattern) {
if constexpr (shift == 0) {
return true;
}
uint32_t bits = __float_as_uint(x);
bits = (bits & 0x80000000) ? bits : ~bits & 0x7fffffff;
return (bits ^ pattern) >> shift == 0;
}
/**
* Map a Func over the input data, using vectorized load instructions if
* possible.
*
* @tparam T element type
* @tparam IdxT indexing type
* @tparam Func void (T x, IdxT idx)
*
* @param thread_rank rank of the calling thread among all participating threads
* @param num_threads number of the threads that participate in processing
* @param in the input data
* @param len the number of elements to read
* @param f the lambda taking two arguments (T x, IdxT idx)
*/
template <typename T, typename idxT, typename Func>
__device__ void vectorized_process(size_t thread_rank, size_t num_threads,
const T* in, idxT len, Func f) {
constexpr int WARP_SIZE = 32;
using WideT = float4;
if constexpr (sizeof(T) >= sizeof(WideT)) {
for (idxT i = thread_rank; i < len; i += num_threads) {
f(in[i], i);
}
} else {
static_assert(sizeof(WideT) % sizeof(T) == 0);
constexpr int items_per_scalar = sizeof(WideT) / sizeof(T);
// TODO: it's UB
union {
WideT scalar;
T array[items_per_scalar];
} wide;
int skip_cnt =
(reinterpret_cast<size_t>(in) % sizeof(WideT))
? ((sizeof(WideT) - reinterpret_cast<size_t>(in) % sizeof(WideT)) /
sizeof(T))
: 0;
if (skip_cnt > len) {
skip_cnt = len;
}
const WideT* in_cast = reinterpret_cast<decltype(in_cast)>(in + skip_cnt);
const idxT len_cast = (len - skip_cnt) / items_per_scalar;
for (idxT i = thread_rank; i < len_cast; i += num_threads) {
wide.scalar = in_cast[i];
const idxT real_i = skip_cnt + i * items_per_scalar;
#pragma unroll
for (int j = 0; j < items_per_scalar; ++j) {
f(wide.array[j], real_i + j);
}
}
static_assert(WARP_SIZE >= items_per_scalar);
// and because items_per_scalar > skip_cnt, WARP_SIZE > skip_cnt
// no need to use loop
if (thread_rank < skip_cnt) {
f(in[thread_rank], thread_rank);
}
// because len_cast = (len - skip_cnt) / items_per_scalar,
// len_cast * items_per_scalar + items_per_scalar > len - skip_cnt;
// and so
// len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <=
// WARP_SIZE no need to use loop
const idxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank;
if (remain_i < len) {
f(in[remain_i], remain_i);
}
}
}
template <int step, int kNumThreadsPerBlock, int kNumBins, int kNumFinalItems,
bool multipleBlocksPerRow, bool mergeBlocks, typename SmemFinalType,
typename SmemOutputType>
__device__ bool processHistogramStep(
const int* indices, const float* logits, int rowEnd, uint32_t& logitPattern,
int& thresholdBinIdx, SmemOutputType& smemOutput, int* smemThresholdBinIdx,
int* smemFinalDstIdx, int* smemFinalBinSize, int* smemFoundTopKValues,
SmemFinalType& smemFinal, int stride1, int rowStart, int topK) {
// Clear the histogram.
#pragma unroll
for (int idx = threadIdx.x; idx < kNumBins; idx += kNumThreadsPerBlock) {
smemFinal.histo.data[idx] = 0;
}
// Make sure the histogram is ready.
__syncthreads();
// Update pattern
constexpr auto patternShift = step < 2 ? 0 : step == 2 ? 21 : 10;
if constexpr (step == 2) {
logitPattern = static_cast<uint32_t>(thresholdBinIdx & 0x7ff)
<< patternShift;
} else if constexpr (step == 3) {
logitPattern |= static_cast<uint32_t>(thresholdBinIdx & 0x7ff)
<< patternShift;
}
auto distributeToBins = [&](float logit, int /* idx */ = 0) {
if (isPartialMatch<patternShift>(logit, logitPattern)) {
uint32_t binIdx = extractBinIdx<step>(logit);
atomicAdd(&smemFinal.histo.data[binIdx], 1);
}
};
// Distribute the elements to the histogram bins.
if (stride1 == 1) {
vectorized_process(threadIdx.x, kNumThreadsPerBlock, logits + rowStart,
rowEnd - rowStart, distributeToBins);
} else {
for (int idx = rowStart + threadIdx.x; idx < rowEnd;
idx += kNumThreadsPerBlock) {
float logit = logits[idx * stride1];
distributeToBins(logit, idx);
}
}
// Make sure the histogram is ready.
__syncthreads();
// Reads the value of the starting position in the smemOutput array
int lastValue = smemFoundTopKValues[0];
for (int round = 0; round < kNumBins / kNumThreadsPerBlock; round++) {
// Read the values from SMEM.
int idx = threadIdx.x + kNumThreadsPerBlock * round;
int binCount{0};
binCount = smemFinal.histo.data[idx];
// Make sure each thread has read its value.
__syncthreads();
// Compute the prefix sum.
int prefixSum{0}, totalSum{0};
#ifdef USE_ROCM
using Scan = hipcub::BlockScan<int, kNumThreadsPerBlock>;
#else
using Scan = cub::BlockScan<int, kNumThreadsPerBlock>;
#endif
Scan(smemFinal.histo.scan).ExclusiveSum(binCount, prefixSum, totalSum);
// Update the histogram with the prefix sums.
prefixSum += lastValue;
totalSum += lastValue;
smemFinal.histo.data[idx] = prefixSum;
// Make sure the data is in shared memory.
__syncthreads();
// Find the last valid bin.
bool foundThreshold = false;
if (prefixSum < topK) {
int nextPrefixSum = threadIdx.x == kNumThreadsPerBlock - 1
? totalSum
: smemFinal.histo.data[idx + 1];
if (nextPrefixSum >= topK) {
smemThresholdBinIdx[0] = idx;
smemFinalBinSize[0] = nextPrefixSum - prefixSum;
foundThreshold = true;
}
}
// Early exit: if any thread found the threshold, we can skip remaining
// rounds
if (__syncthreads_or(foundThreshold)) {
break;
}
lastValue = totalSum;
}
// Make sure the data is in shared memory.
__syncthreads();
// The threshold bin.
thresholdBinIdx = smemThresholdBinIdx[0];
auto processBins = [&](float logit, int idx) {
if (isPartialMatch<patternShift>(logit, logitPattern)) {
uint32_t binIdx = extractBinIdx<step>(logit);
if (binIdx < thresholdBinIdx) {
// The element is part of the top-k selection
int dstIdx = atomicAdd(&smemFoundTopKValues[0], 1);
if constexpr (mergeBlocks) {
smemOutput[dstIdx] = indices[idx];
} else if constexpr (multipleBlocksPerRow) {
smemOutput[dstIdx] = idx + rowStart;
reinterpret_cast<float*>(smemOutput + topK)[dstIdx] = logit;
} else {
smemOutput[dstIdx] = idx;
}
}
if constexpr (step < 3) {
// Only fill the final items for sorting if the threshold bin fits
if (binIdx == thresholdBinIdx &&
smemFinalBinSize[0] <= kNumFinalItems) {
int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1);
smemFinal.items.logits[dstIdx] = logit;
if constexpr (mergeBlocks) {
smemFinal.items.indices[dstIdx] = indices[idx];
} else if constexpr (multipleBlocksPerRow) {
smemFinal.items.indices[dstIdx] = idx + rowStart;
} else {
smemFinal.items.indices[dstIdx] = idx;
}
}
} else {
if (binIdx == thresholdBinIdx) {
// The elements in the threshold bin share the same 32 bits at step 3
int dstIdx = atomicAdd(&smemFinal.histo.data[binIdx], 1);
if (dstIdx < topK) {
if constexpr (mergeBlocks) {
smemOutput[dstIdx] = indices[idx];
} else if constexpr (multipleBlocksPerRow) {
smemOutput[dstIdx] = idx + rowStart;
reinterpret_cast<float*>(smemOutput + topK)[dstIdx] = logit;
} else {
smemOutput[dstIdx] = idx;
}
}
}
}
}
};
if (stride1 == 1) {
vectorized_process(threadIdx.x, kNumThreadsPerBlock, logits + rowStart,
rowEnd - rowStart, processBins);
} else {
for (int idx = rowStart + threadIdx.x; idx < rowEnd;
idx += kNumThreadsPerBlock) {
float logit = logits[idx * stride1];
processBins(logit, idx);
}
}
// Make sure the elements are in shared memory.
__syncthreads();
// Check if we should continue to next step
return smemFinalBinSize[0] > kNumFinalItems;
}
// Follows half - 11 - 11 - 10 bit iterations
template <int kNumThreadsPerBlock, int kNumBins, bool useRadixSort,
bool multipleBlocksPerRow = false, bool mergeBlocks = false>
static __device__ void topKPerRowJob(const int* indices, const float* logits,
int rowStart, int rowEnd, int* outIndices,
float* outLogits, int stride1, int topK) {
// The number of slots for the final pass. // The number of slots for the final pass.
static constexpr int kNumFinalItems = 3072; static constexpr int kNumFinalItems = 2048;
// The number of elements per thread for the final sort. // The number of elements per thread for the final sort.
static constexpr int kNumFinalItemsPerThread = static constexpr int kNumFinalItemsPerThread =
kNumFinalItems / kNumThreadsPerBlock; kNumFinalItems / kNumThreadsPerBlock;
...@@ -83,6 +343,9 @@ __device__ void topKPerRowJob(const float* logits, const int rowStart, ...@@ -83,6 +343,9 @@ __device__ void topKPerRowJob(const float* logits, const int rowStart,
kNumFinalItemsPerThread, int>; kNumFinalItemsPerThread, int>;
#endif #endif
using FinalSortTempStorage =
std::conditional_t<useRadixSort, typename FinalSort::TempStorage, int>;
// The class to compute the inclusive prefix-sum over the histogram. // The class to compute the inclusive prefix-sum over the histogram.
#ifdef USE_ROCM #ifdef USE_ROCM
using Scan = hipcub::BlockScan<int, kNumThreadsPerBlock>; using Scan = hipcub::BlockScan<int, kNumThreadsPerBlock>;
...@@ -90,9 +353,6 @@ __device__ void topKPerRowJob(const float* logits, const int rowStart, ...@@ -90,9 +353,6 @@ __device__ void topKPerRowJob(const float* logits, const int rowStart,
using Scan = cub::BlockScan<int, kNumThreadsPerBlock>; using Scan = cub::BlockScan<int, kNumThreadsPerBlock>;
#endif #endif
// Shared memory to compute the block scan.
__shared__ typename Scan::TempStorage smemScan;
// The structure to store the final items (for the final pass). // The structure to store the final items (for the final pass).
struct FinalItems { struct FinalItems {
// Shared memory to store the indices for the final pass. // Shared memory to store the indices for the final pass.
...@@ -101,200 +361,225 @@ __device__ void topKPerRowJob(const float* logits, const int rowStart, ...@@ -101,200 +361,225 @@ __device__ void topKPerRowJob(const float* logits, const int rowStart,
float logits[kNumFinalItems]; float logits[kNumFinalItems];
}; };
struct Histogram {
typename Scan::TempStorage scan;
int data[kNumBins];
};
// Shared memory to compute the block sort. // Shared memory to compute the block sort.
__shared__ union { __shared__ union {
FinalItems items; FinalItems items;
typename FinalSort::TempStorage finalSort; FinalSortTempStorage finalSort;
typename TopKSort::TempStorage topKSort; Histogram histo;
} smemFinal; } smemFinal;
// Shared memory to store the histogram.
__shared__ int smemHistogram[kNumBins];
// Shared memory to store the selected indices. // Shared memory to store the selected indices.
__shared__ int smemIndices[kTopK]; // If we are processing using multiple blocks, we need to store the logits and
// indices.
extern __shared__ int32_t smemOutput[];
// Shared memory to store the threshold bin. // Shared memory to store the threshold bin.
__shared__ int smemThresholdBinIdx[1]; __shared__ int smemThresholdBinIdx[1];
// Shared memory counter to register the candidates for the final phase. // Shared memory counter to register the candidates for the final phase.
__shared__ int smemFinalDstIdx[1]; __shared__ int smemFinalDstIdx[1];
// Shared memory to determine if the threshold bin fits in the final items.
__shared__ int smemFinalBinSize[1];
// Shared memory to keep track of the top-k values found so far by the
// previous iterations
__shared__ int smemFoundTopKValues[1];
// The length of the row. // The length of the row.
int rowLen = rowEnd - rowStart; int rowLen = rowEnd - rowStart;
// Shortcut if the length of the row is smaller than Top-K. Indices are not // Shortcut if the length of the row is smaller than Top-K. Indices are not
// sorted by their corresponding logit. // sorted by their corresponding logit.
if (rowLen <= kTopK) { if (rowLen <= topK) {
for (int rowIt = threadIdx.x; rowIt < rowLen; for (int rowIt = threadIdx.x; rowIt < rowLen;
rowIt += kNumThreadsPerBlock) { rowIt += kNumThreadsPerBlock) {
int idx = rowStart + rowIt; if constexpr (multipleBlocksPerRow) {
outIndices[rowIdx * kTopK + rowIt] = idx - rowStart; outIndices[rowIt] = rowIt + rowStart;
outLogits[rowIt] = logits[rowIt + rowStart];
} else {
outIndices[rowIt] = rowIt;
}
} }
for (int rowIt = rowLen + threadIdx.x; rowIt < kTopK; for (int rowIt = rowLen + threadIdx.x; rowIt < topK;
rowIt += kNumThreadsPerBlock) { rowIt += kNumThreadsPerBlock) {
outIndices[rowIdx * kTopK + rowIt] = -1; outIndices[rowIt] = -1;
if constexpr (multipleBlocksPerRow) {
outLogits[rowIt] = -FLT_MAX;
}
} }
return;
}
// Clear the histogram.
if (threadIdx.x < kNumBins) {
smemHistogram[threadIdx.x] = 0;
}
// Make sure the histogram is ready.
__syncthreads();
// Fetch elements one-by-one.
for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
rowIt += kNumThreadsPerBlock) {
uint16_t idx = extractBinIdx(logits[rowIdx * stride0 + rowIt * stride1]);
atomicAdd(&smemHistogram[idx], 1);
}
// Make sure the histogram is ready.
__syncthreads();
// Read the values from SMEM.
int binCount{0};
if (threadIdx.x < kNumBins) {
binCount = smemHistogram[threadIdx.x];
}
// Make sure each thread has read its value.
__syncthreads();
// Compute the prefix sum.
int prefixSum{0}, totalSum{0};
Scan(smemScan).ExclusiveSum(binCount, prefixSum, totalSum);
// Update the histogram with the prefix sums.
if (threadIdx.x < kNumBins) {
smemHistogram[threadIdx.x] = prefixSum;
}
// Make sure the data is in shared memory.
__syncthreads();
// Find the last valid bin. return;
if (threadIdx.x < kNumBins) {
int nextPrefixSum =
threadIdx.x == kNumBins - 1 ? totalSum : smemHistogram[threadIdx.x + 1];
if (prefixSum < kTopK && nextPrefixSum >= kTopK) {
smemThresholdBinIdx[0] = threadIdx.x;
}
} }
// Initialize values
// Clear the counter to store the items for the final phase.
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
smemFinalDstIdx[0] = 0; smemFinalDstIdx[0] = 0;
smemFoundTopKValues[0] = 0;
} }
// Make sure the data is in shared memory.
__syncthreads(); __syncthreads();
int thresholdBinIdx = -1;
uint32_t logitPattern = 0;
// Step 0: Process first 11 bits of half representation
bool continueToNextStep =
processHistogramStep<0, kNumThreadsPerBlock, kNumBins, kNumFinalItems,
multipleBlocksPerRow, mergeBlocks>(
indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput,
smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize,
smemFoundTopKValues, smemFinal, stride1, rowStart, topK);
if (continueToNextStep) {
// Step 1: Process next 11 bits
continueToNextStep =
processHistogramStep<1, kNumThreadsPerBlock, kNumBins, kNumFinalItems,
multipleBlocksPerRow, mergeBlocks>(
indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput,
smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize,
smemFoundTopKValues, smemFinal, stride1, rowStart, topK);
}
// The threshold bin. if (continueToNextStep) {
int thresholdBinIdx = smemThresholdBinIdx[0]; // Step 2: Process next 11 bits
continueToNextStep =
// Fetch elements one-by-one and populate the shared memory buffers. processHistogramStep<2, kNumThreadsPerBlock, kNumBins, kNumFinalItems,
for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd; multipleBlocksPerRow, mergeBlocks>(
rowIt += kNumThreadsPerBlock) { indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput,
float logit = logits[rowIdx * stride0 + rowIt * stride1]; smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize,
uint16_t idx = extractBinIdx(logit); smemFoundTopKValues, smemFinal, stride1, rowStart, topK);
if (idx < thresholdBinIdx) {
int dstIdx = atomicAdd(&smemHistogram[idx], 1);
smemIndices[dstIdx] = rowIt;
} else if (idx == thresholdBinIdx) {
int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1);
if (dstIdx < kNumFinalItems) {
smemFinal.items.logits[dstIdx] = logit;
smemFinal.items.indices[dstIdx] = rowIt;
}
}
} }
// Make sure the elements are in shared memory. if (continueToNextStep) {
__syncthreads(); // Step 3: Process last 10 bits
processHistogramStep<3, kNumThreadsPerBlock, kNumBins, kNumFinalItems,
multipleBlocksPerRow, mergeBlocks>(
indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput,
smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize,
smemFoundTopKValues, smemFinal, stride1, rowStart, topK);
}
// The logits of the elements to be sorted in the final pass. if (!continueToNextStep) {
float finalLogits[kNumFinalItemsPerThread]; // The histogram did not proceed to the final 10 bits, therefore we need to
// The indices of the elements to be sorted in the final pass. // sort the final items The logits of the elements to be sorted in the final
int finalIndices[kNumFinalItemsPerThread]; // pass.
if constexpr (useRadixSort) {
// Sorting with radix sort
float finalLogits[kNumFinalItemsPerThread];
// The indices of the elements to be sorted in the final pass.
int finalIndices[kNumFinalItemsPerThread];
// Init.
#pragma unroll #pragma unroll
for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) { for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
finalLogits[ii] = -FLT_MAX; finalLogits[ii] = -FLT_MAX;
} }
// Read the elements from SMEM. // Read the elements from SMEM.
#pragma unroll #pragma unroll
for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) { for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x; int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
if (srcIdx < smemFinalDstIdx[0]) { if (srcIdx < smemFinalDstIdx[0]) {
finalLogits[ii] = smemFinal.items.logits[srcIdx]; finalLogits[ii] = smemFinal.items.logits[srcIdx];
finalIndices[ii] = smemFinal.items.indices[srcIdx]; finalIndices[ii] = smemFinal.items.indices[srcIdx];
} }
} }
// Make sure the shared memory has been read.
__syncthreads();
// Make sure the shared memory has been read. // Sort the elements.
__syncthreads(); FinalSort(smemFinal.finalSort)
.SortDescendingBlockedToStriped(finalLogits, finalIndices);
// Sort the elements. // Copy the data back to the shared memory storage.
FinalSort(smemFinal.finalSort) int baseIdx = smemFoundTopKValues[0];
.SortDescendingBlockedToStriped(finalLogits, finalIndices);
// Copy the data back to the shared memory storage.
int baseIdx = thresholdBinIdx > 0 ? smemHistogram[thresholdBinIdx - 1] : 0;
#pragma unroll #pragma unroll
for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) { for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x; int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
int dstIdx = baseIdx + srcIdx; int dstIdx = baseIdx + srcIdx;
if (dstIdx < kTopK) {
smemIndices[dstIdx] = finalIndices[ii]; if (dstIdx < topK) {
smemOutput[dstIdx] = finalIndices[ii];
if constexpr (multipleBlocksPerRow) {
reinterpret_cast<float*>(smemOutput + topK)[dstIdx] =
finalLogits[ii];
}
}
}
} else {
// Sorting with insertion sort
auto baseIdx = smemFoundTopKValues[0];
for (int i = threadIdx.x; i < smemFinalDstIdx[0];
i += kNumThreadsPerBlock) {
int outIndex = 0;
auto logit = smemFinal.items.logits[i];
for (int j = 0; j < smemFinalDstIdx[0]; j++) {
auto otherLogit = smemFinal.items.logits[j];
if (logit < otherLogit || (logit == otherLogit && i < j)) {
outIndex++;
}
}
// Store if outIndex is in bounds
if (outIndex + baseIdx < topK) {
smemOutput[outIndex + baseIdx] = smemFinal.items.indices[i];
if constexpr (multipleBlocksPerRow) {
reinterpret_cast<float*>(smemOutput + topK)[outIndex + baseIdx] =
smemFinal.items.logits[i];
}
}
}
} }
__syncthreads();
} }
// Make sure the data is in shared memory. // Store to global memory.
__syncthreads(); for (int i = threadIdx.x; i < topK; i += kNumThreadsPerBlock) {
if constexpr (multipleBlocksPerRow) {
// Store to global memory. outIndices[i] = smemOutput[i];
#pragma unroll outLogits[i] = reinterpret_cast<float*>(smemOutput + topK)[i];
for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) { } else {
int offset = rowIdx * kTopK + ii * kNumThreadsPerBlock + threadIdx.x; if (stride1 == 1) {
outIndices[offset] = // stride1 == 1 will use vectorized_process, which indexes already skip
smemIndices[ii * kNumThreadsPerBlock + threadIdx.x] - rowStart; // the rowStart.
outIndices[i] = smemOutput[i];
} else {
outIndices[i] = smemOutput[i] - rowStart;
}
}
} }
} }
template <int kNumThreadsPerBlock = 512> template <int kNumThreadsPerBlock, bool useRadixSort>
static __global__ void topKPerRow(const float* logits, const int* rowStarts, static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill(
const int* rowEnds, int* outIndices, const float* logits, const int* rowStarts, const int* rowEnds,
int stride0, int stride1) { int* outIndices, int stride0, int stride1, const int topK,
const int offsetIndex) {
// The number of bins in the histogram. // The number of bins in the histogram.
static constexpr int kNumBins = 512; static constexpr int kNumBins = 2048;
// The top-k width.
static constexpr int kTopK = 2048;
// The row computed by this block. // The row computed by this block.
int rowIdx = blockIdx.x; int rowIdx = blockIdx.x + offsetIndex;
// The range of logits within the row. // The range of logits within the row.
int rowStart = rowStarts[rowIdx]; int rowStart = rowStarts[rowIdx];
int rowEnd = rowEnds[rowIdx]; int rowEnd = rowEnds[rowIdx];
topKPerRowJob<kNumThreadsPerBlock, kNumBins, kTopK>( // Local pointers to this block
logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1); outIndices += rowIdx * topK;
logits += rowIdx * stride0;
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort>(
nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK);
} }
template <int kNumThreadsPerBlock = 512> template <int kNumThreadsPerBlock, bool useRadixSort,
static __global__ void topKPerRowDecode(const float* logits, const int* seqLens, bool multipleBlocksPerRow = false, bool mergeBlocks = false>
int* outIndices, int stride0, static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
int stride1, int next_n) { const float* logits, const int* seqLens, int* outIndices, int stride0,
int stride1, const int topK, int next_n, float* outLogits = nullptr,
const int numBlocksToMerge = 0, const int* indices = nullptr) {
// The number of bins in the histogram. // The number of bins in the histogram.
static constexpr int kNumBins = 512; static constexpr int kNumBins = 2048;
// The top-k width.
static constexpr int kTopK = 2048;
// The row computed by this block. // The row computed by this block.
int rowIdx = blockIdx.x; int rowIdx = blockIdx.x;
...@@ -304,8 +589,25 @@ static __global__ void topKPerRowDecode(const float* logits, const int* seqLens, ...@@ -304,8 +589,25 @@ static __global__ void topKPerRowDecode(const float* logits, const int* seqLens,
int seq_len = seqLens[rowIdx / next_n]; int seq_len = seqLens[rowIdx / next_n];
int rowEnd = seq_len - next_n + (rowIdx % next_n) + 1; int rowEnd = seq_len - next_n + (rowIdx % next_n) + 1;
topKPerRowJob<kNumThreadsPerBlock, kNumBins, kTopK>( // Local pointers to this block
logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1); if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
outIndices += rowIdx * topK;
} else if constexpr (multipleBlocksPerRow) {
const auto blockSize = rowEnd / gridDim.y; // 16384 / 2 = 8192
rowStart = blockSize * blockIdx.y; // 8192 * 1 = 8192
rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize;
outIndices += rowIdx * gridDim.y * topK + blockIdx.y * topK;
outLogits += rowIdx * gridDim.y * topK + blockIdx.y * topK;
} else if constexpr (mergeBlocks) {
rowEnd = numBlocksToMerge * topK;
indices += rowIdx * numBlocksToMerge * topK;
outIndices += rowIdx * topK;
}
logits += rowIdx * stride0;
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort,
multipleBlocksPerRow, mergeBlocks>(
indices, logits, rowStart, rowEnd, outIndices, outLogits, stride1, topK);
} }
} // namespace vllm } // namespace vllm
...@@ -353,28 +655,84 @@ void apply_repetition_penalties_( ...@@ -353,28 +655,84 @@ void apply_repetition_penalties_(
void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n, void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
const torch::Tensor& seqLens, torch::Tensor& indices, const torch::Tensor& seqLens, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1) { int64_t numRows, int64_t stride0, int64_t stride1,
// Compute the results on the device. int64_t topK) {
constexpr int kSortingAlgorithmThreshold = 12288;
constexpr int kSplitWorkThreshold = 200 * 1000;
constexpr int kNumThreadsPerBlock = 512; constexpr int kNumThreadsPerBlock = 512;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const auto numColumns = logits.size(1);
vllm::topKPerRowDecode<kNumThreadsPerBlock>
<<<numRows, kNumThreadsPerBlock, 0, stream>>>( if (numColumns < kSortingAlgorithmThreshold) {
logits.data_ptr<float>(), seqLens.data_ptr<int>(), // Use insertion sort
indices.data_ptr<int>(), static_cast<int>(stride0), vllm::topKPerRowDecode<kNumThreadsPerBlock, false>
static_cast<int>(stride1), static_cast<int>(next_n)); <<<numRows, kNumThreadsPerBlock, topK * sizeof(int32_t), stream>>>(
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
indices.data_ptr<int>(), static_cast<int>(stride0),
static_cast<int>(stride1), static_cast<int>(topK),
static_cast<int>(next_n));
} else if (numColumns < kSplitWorkThreshold) {
// From this threshold, use radix sort instead
vllm::topKPerRowDecode<kNumThreadsPerBlock, true>
<<<numRows, kNumThreadsPerBlock, topK * sizeof(int32_t), stream>>>(
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
indices.data_ptr<int>(), static_cast<int>(stride0),
static_cast<int>(stride1), static_cast<int>(topK),
static_cast<int>(next_n));
} else {
// Long sequences are run in two steps
constexpr auto multipleBlocksPerRowConfig = 10;
const auto outIndicesAux =
torch::empty({numRows, multipleBlocksPerRowConfig, topK},
torch::dtype(torch::kInt32).device(logits.device()));
const auto outLogitsAux =
torch::empty({numRows, multipleBlocksPerRowConfig, topK},
torch::dtype(torch::kFloat).device(logits.device()));
vllm::topKPerRowDecode<kNumThreadsPerBlock, true, true>
<<<dim3(numRows, multipleBlocksPerRowConfig), kNumThreadsPerBlock,
2 * topK * sizeof(int32_t), stream>>>(
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
outIndicesAux.data_ptr<int>(), static_cast<int>(stride0),
static_cast<int>(stride1), static_cast<int>(topK),
static_cast<int>(next_n), outLogitsAux.data_ptr<float>());
constexpr int kNumThreadsPerBlockMerge = 1024;
vllm::topKPerRowDecode<kNumThreadsPerBlockMerge, true, false, true>
<<<numRows, kNumThreadsPerBlockMerge, topK * sizeof(int32_t), stream>>>(
outLogitsAux.data_ptr<float>(), seqLens.data_ptr<int>(),
indices.data_ptr<int>(), multipleBlocksPerRowConfig * topK, 1,
static_cast<int>(topK), static_cast<int>(next_n), nullptr,
multipleBlocksPerRowConfig, outIndicesAux.data_ptr<int>());
}
} }
void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts, void top_k_per_row_prefill(const torch::Tensor& logits,
const torch::Tensor& rowEnds, torch::Tensor& indices, const torch::Tensor& rowStarts,
int64_t numRows, int64_t stride0, int64_t stride1) { const torch::Tensor& rowEnds, torch::Tensor& indices,
// Compute the results on the device. int64_t numRows, int64_t stride0, int64_t stride1,
int64_t topK) {
constexpr int kSortingAlgorithmThreshold = 12288;
constexpr int kNumThreadsPerBlock = 512; constexpr int kNumThreadsPerBlock = 512;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
vllm::topKPerRow<kNumThreadsPerBlock> int numInsertionBlocks =
<<<numRows, kNumThreadsPerBlock, 0, stream>>>( std::min(static_cast<int>(numRows), kSortingAlgorithmThreshold);
logits.data_ptr<float>(), rowStarts.data_ptr<int>(), vllm::topKPerRowPrefill<kNumThreadsPerBlock, false>
rowEnds.data_ptr<int>(), indices.data_ptr<int>(), <<<numInsertionBlocks, kNumThreadsPerBlock, topK * sizeof(int32_t),
static_cast<int>(stride0), static_cast<int>(stride1)); stream>>>(logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
static_cast<int>(stride0), static_cast<int>(stride1),
static_cast<int>(topK), 0);
if (numRows > kSortingAlgorithmThreshold) {
int numRadixBlocks = numRows - kSortingAlgorithmThreshold;
vllm::topKPerRowPrefill<kNumThreadsPerBlock, true>
<<<numRadixBlocks, kNumThreadsPerBlock, topK * sizeof(int32_t),
stream>>>(logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
static_cast<int>(stride0), static_cast<int>(stride1),
static_cast<int>(topK), kSortingAlgorithmThreshold);
}
} }
...@@ -179,15 +179,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -179,15 +179,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Optimized top-k per row operation // Optimized top-k per row operation
ops.def( ops.def(
"top_k_per_row(Tensor logits, Tensor rowStarts, Tensor rowEnds, " "top_k_per_row_prefill(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
"Tensor! indices, int numRows, int stride0, " "Tensor! indices, int numRows, int stride0, "
"int stride1) -> ()"); "int stride1, int topK) -> ()");
ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row); ops.impl("top_k_per_row_prefill", torch::kCUDA, &top_k_per_row_prefill);
ops.def( ops.def(
"top_k_per_row_decode(Tensor logits, int next_n, " "top_k_per_row_decode(Tensor logits, int next_n, "
"Tensor seq_lens, Tensor! indices, int numRows, " "Tensor seq_lens, Tensor! indices, "
"int stride0, int stride1) -> ()"); "int numRows, int stride0, int stride1, int topK) -> ()");
ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode); ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
// Layernorm-quant // Layernorm-quant
...@@ -215,6 +215,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -215,6 +215,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA, ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA,
&rms_norm_dynamic_per_token_quant); &rms_norm_dynamic_per_token_quant);
// Fused Layernorm + Block quant kernels
ops.def(
"rms_norm_per_block_quant(Tensor! result, Tensor input, "
"Tensor weight, Tensor! scale, float epsilon, "
"Tensor? scale_ub, Tensor!? residual, int group_size, "
"bool is_scale_transposed) -> ()");
ops.impl("rms_norm_per_block_quant", torch::kCUDA, &rms_norm_per_block_quant);
// Rotary embedding // Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key. // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def( ops.def(
...@@ -346,6 +354,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -346,6 +354,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor"); ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");
// conditionally compiled so impl registration is in source file // conditionally compiled so impl registration is in source file
// CUTLASS w4a8 grouped GEMM
ops.def(
"cutlass_w4a8_moe_mm("
" Tensor! out_tensors,"
" Tensor a_tensors,"
" Tensor b_tensors,"
" Tensor a_scales,"
" Tensor b_scales,"
" Tensor b_group_scales,"
" int b_group_size,"
" Tensor expert_offsets,"
" Tensor problem_sizes,"
" Tensor a_strides,"
" Tensor b_strides,"
" Tensor c_strides,"
" Tensor group_scale_strides,"
" str? maybe_schedule"
") -> ()");
ops.def(
"cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, "
"Tensor)");
// conditionally compiled so impl registration is in source file
#endif #endif
// Dequantization for GGML. // Dequantization for GGML.
...@@ -462,7 +493,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -462,7 +493,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! problem_sizes1, " " Tensor! problem_sizes1, "
" Tensor! problem_sizes2, " " Tensor! problem_sizes2, "
" int num_experts, int n, int k, " " int num_experts, int n, int k, "
" Tensor? blockscale_offsets) -> ()"); " Tensor? blockscale_offsets, "
" bool? force_swap_ab) -> ()");
ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA, ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA,
&get_cutlass_moe_mm_problem_sizes); &get_cutlass_moe_mm_problem_sizes);
...@@ -621,6 +653,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -621,6 +653,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("per_token_group_fp8_quant", torch::kCUDA, ops.impl("per_token_group_fp8_quant", torch::kCUDA,
&per_token_group_quant_fp8); &per_token_group_quant_fp8);
// Compute per-token-group 8-bit quantized tensor and UE8M0-packed,
// TMA-aligned scales for DeepGEMM.
ops.def(
"per_token_group_fp8_quant_packed(Tensor input, Tensor! output_q, "
"Tensor! output_s_packed, int group_size, float eps, float fp8_min, "
"float fp8_max) -> ()");
ops.impl("per_token_group_fp8_quant_packed", torch::kCUDA,
&per_token_group_quant_8bit_packed);
// Compute per-token-group INT8 quantized tensor and scaling factor. // Compute per-token-group INT8 quantized tensor and scaling factor.
ops.def( ops.def(
"per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! " "per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
......
...@@ -150,8 +150,8 @@ ARG torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0 10.0 12.0' ...@@ -150,8 +150,8 @@ ARG torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0 10.0 12.0'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
#################### BASE BUILD IMAGE #################### #################### BASE BUILD IMAGE ####################
#################### WHEEL BUILD IMAGE #################### #################### CSRC BUILD IMAGE ####################
FROM base AS build FROM base AS csrc-build
ARG TARGETPLATFORM ARG TARGETPLATFORM
ARG PIP_INDEX_URL UV_INDEX_URL ARG PIP_INDEX_URL UV_INDEX_URL
...@@ -172,10 +172,13 @@ RUN --mount=type=cache,target=/root/.cache/uv \ ...@@ -172,10 +172,13 @@ RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --python /opt/venv/bin/python3 -r requirements/build.txt \ uv pip install --python /opt/venv/bin/python3 -r requirements/build.txt \
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
COPY . . WORKDIR /workspace
ARG GIT_REPO_CHECK=0
RUN --mount=type=bind,source=.git,target=.git \ COPY pyproject.toml setup.py CMakeLists.txt ./
if [ "$GIT_REPO_CHECK" != "0" ]; then bash tools/check_repo.sh ; fi COPY cmake cmake/
COPY csrc csrc/
COPY vllm/envs.py vllm/envs.py
COPY vllm/__init__.py vllm/__init__.py
# max jobs used by Ninja to build extensions # max jobs used by Ninja to build extensions
ARG max_jobs=2 ARG max_jobs=2
...@@ -193,11 +196,14 @@ ARG SCCACHE_S3_NO_CREDENTIALS=0 ...@@ -193,11 +196,14 @@ ARG SCCACHE_S3_NO_CREDENTIALS=0
# Flag to control whether to use pre-built vLLM wheels # Flag to control whether to use pre-built vLLM wheels
ARG VLLM_USE_PRECOMPILED="" ARG VLLM_USE_PRECOMPILED=""
ARG VLLM_MERGE_BASE_COMMIT=""
ARG VLLM_MAIN_CUDA_VERSION="" ARG VLLM_MAIN_CUDA_VERSION=""
# Use dummy version for csrc-build wheel (only .so files are extracted, version doesn't matter)
ENV SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0+csrc.build"
# if USE_SCCACHE is set, use sccache to speed up compilation # if USE_SCCACHE is set, use sccache to speed up compilation
RUN --mount=type=cache,target=/root/.cache/uv \ RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=.git,target=.git \
if [ "$USE_SCCACHE" = "1" ]; then \ if [ "$USE_SCCACHE" = "1" ]; then \
echo "Installing sccache..." \ echo "Installing sccache..." \
&& curl -L -o sccache.tar.gz ${SCCACHE_DOWNLOAD_URL} \ && curl -L -o sccache.tar.gz ${SCCACHE_DOWNLOAD_URL} \
...@@ -211,6 +217,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ ...@@ -211,6 +217,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
&& export SCCACHE_IDLE_TIMEOUT=0 \ && export SCCACHE_IDLE_TIMEOUT=0 \
&& export CMAKE_BUILD_TYPE=Release \ && export CMAKE_BUILD_TYPE=Release \
&& export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" \ && export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" \
&& export VLLM_PRECOMPILED_WHEEL_COMMIT="${VLLM_MERGE_BASE_COMMIT}" \
&& export VLLM_MAIN_CUDA_VERSION="${VLLM_MAIN_CUDA_VERSION}" \ && export VLLM_MAIN_CUDA_VERSION="${VLLM_MAIN_CUDA_VERSION}" \
&& export VLLM_DOCKER_BUILD_CONTEXT=1 \ && export VLLM_DOCKER_BUILD_CONTEXT=1 \
&& sccache --show-stats \ && sccache --show-stats \
...@@ -223,15 +230,61 @@ ENV VLLM_TARGET_DEVICE=${vllm_target_device} ...@@ -223,15 +230,61 @@ ENV VLLM_TARGET_DEVICE=${vllm_target_device}
ENV CCACHE_DIR=/root/.cache/ccache ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \ RUN --mount=type=cache,target=/root/.cache/ccache \
--mount=type=cache,target=/root/.cache/uv \ --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=.git,target=.git \
if [ "$USE_SCCACHE" != "1" ]; then \ if [ "$USE_SCCACHE" != "1" ]; then \
# Clean any existing CMake artifacts # Clean any existing CMake artifacts
rm -rf .deps && \ rm -rf .deps && \
mkdir -p .deps && \ mkdir -p .deps && \
export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" && \ export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" && \
export VLLM_PRECOMPILED_WHEEL_COMMIT="${VLLM_MERGE_BASE_COMMIT}" && \
export VLLM_DOCKER_BUILD_CONTEXT=1 && \ export VLLM_DOCKER_BUILD_CONTEXT=1 && \
python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \ python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \
fi fi
#################### CSRC BUILD IMAGE ####################
#################### WHEEL BUILD IMAGE ####################
FROM base AS build
ARG TARGETPLATFORM
ARG PIP_INDEX_URL UV_INDEX_URL
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
ARG PYTORCH_CUDA_INDEX_BASE_URL
# install build dependencies
COPY requirements/build.txt requirements/build.txt
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500
ENV UV_INDEX_STRATEGY="unsafe-best-match"
# Use copy mode to avoid hardlink failures with Docker cache mounts
ENV UV_LINK_MODE=copy
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --python /opt/venv/bin/python3 -r requirements/build.txt \
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
WORKDIR /workspace
COPY --from=csrc-build /workspace/dist /precompiled-wheels
COPY . .
ARG GIT_REPO_CHECK=0
RUN --mount=type=bind,source=.git,target=.git \
if [ "$GIT_REPO_CHECK" != "0" ]; then bash tools/check_repo.sh ; fi
ARG vllm_target_device="cuda"
ENV VLLM_TARGET_DEVICE=${vllm_target_device}
# Skip adding +precompiled suffix to version (preserves git-derived version)
ENV VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX=1
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=.git,target=.git \
if [ "${vllm_target_device}" = "cuda" ]; then \
export VLLM_PRECOMPILED_WHEEL_LOCATION=$(ls /precompiled-wheels/*.whl); \
fi && \
python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38
# Install DeepGEMM from source # Install DeepGEMM from source
ARG DEEPGEMM_GIT_REF ARG DEEPGEMM_GIT_REF
...@@ -527,7 +580,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ ...@@ -527,7 +580,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
else \ else \
BITSANDBYTES_VERSION="0.46.1"; \ BITSANDBYTES_VERSION="0.46.1"; \
fi; \ fi; \
uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3,gcs]>=0.15.0' uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3,gcs]>=0.15.3'
ENV VLLM_USAGE_SOURCE production-docker-image ENV VLLM_USAGE_SOURCE production-docker-image
......
...@@ -65,7 +65,6 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests ...@@ -65,7 +65,6 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/docker/Dockerfile.rocm /docker/ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/docker/Dockerfile.rocm /docker/
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite
# Centralized v1 package - copied to both test and final stages
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/vllm/v1 /vllm_v1 COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/vllm/v1 /vllm_v1
# ----------------------- # -----------------------
...@@ -98,7 +97,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ ...@@ -98,7 +97,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system hf_transfer uv pip install --system hf_transfer
ENV HF_HUB_ENABLE_HF_TRANSFER=1 ENV HF_HUB_ENABLE_HF_TRANSFER=1
# Copy in the v1 package # Copy in the v1 package (for python-only install test group)
COPY --from=export_vllm /vllm_v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1 COPY --from=export_vllm /vllm_v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1
# Source code is used in the `python_only_compile.sh` test # Source code is used in the `python_only_compile.sh` test
...@@ -130,9 +129,6 @@ RUN --mount=type=bind,from=export_vllm,src=/,target=/install \ ...@@ -130,9 +129,6 @@ RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
&& pip uninstall -y vllm \ && pip uninstall -y vllm \
&& uv pip install --system *.whl && uv pip install --system *.whl
# Copy in the v1 package
COPY --from=export_vllm /vllm_v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1
ARG COMMON_WORKDIR ARG COMMON_WORKDIR
# Copy over the benchmark scripts as well # Copy over the benchmark scripts as well
......
...@@ -5,11 +5,7 @@ nav: ...@@ -5,11 +5,7 @@ nav:
- Getting Started: - Getting Started:
- getting_started/quickstart.md - getting_started/quickstart.md
- getting_started/installation - getting_started/installation
- Examples: - Examples: examples
- examples/README.md
- Offline Inference: examples/offline_inference
- Online Serving: examples/online_serving
- Others: examples/others
- General: - General:
- usage/v1_guide.md - usage/v1_guide.md
- usage/* - usage/*
...@@ -63,6 +59,7 @@ nav: ...@@ -63,6 +59,7 @@ nav:
- CLI Reference: cli - CLI Reference: cli
- Community: - Community:
- community/* - community/*
- Governance: governance
- Blog: https://blog.vllm.ai - Blog: https://blog.vllm.ai
- Forum: https://discuss.vllm.ai - Forum: https://discuss.vllm.ai
- Slack: https://slack.vllm.ai - Slack: https://slack.vllm.ai
...@@ -15,6 +15,7 @@ API documentation for vLLM's configuration classes. ...@@ -15,6 +15,7 @@ API documentation for vLLM's configuration classes.
- [vllm.config.MultiModalConfig][] - [vllm.config.MultiModalConfig][]
- [vllm.config.PoolerConfig][] - [vllm.config.PoolerConfig][]
- [vllm.config.StructuredOutputsConfig][] - [vllm.config.StructuredOutputsConfig][]
- [vllm.config.ProfilerConfig][]
- [vllm.config.ObservabilityConfig][] - [vllm.config.ObservabilityConfig][]
- [vllm.config.KVTransferConfig][] - [vllm.config.KVTransferConfig][]
- [vllm.config.CompilationConfig][] - [vllm.config.CompilationConfig][]
......
...@@ -670,6 +670,35 @@ vllm bench serve \ ...@@ -670,6 +670,35 @@ vllm bench serve \
</details> </details>
### 🧪 Hashing Benchmarks
<details class="admonition abstract" markdown="1">
<summary>Show more</summary>
Two helper scripts live in `benchmarks/` to compare hashing options used by prefix caching and related utilities. They are standalone (no server required) and help choose a hash algorithm before enabling prefix caching in production.
- `benchmarks/benchmark_hash.py`: Micro-benchmark that measures per-call latency of three implementations on a representative `(bytes, tuple[int])` payload.
```bash
python benchmarks/benchmark_hash.py --iterations 20000 --seed 42
```
- `benchmarks/benchmark_prefix_block_hash.py`: End-to-end block hashing benchmark that runs the full prefix-cache hash pipeline (`hash_block_tokens`) across many fake blocks and reports throughput.
```bash
python benchmarks/benchmark_prefix_block_hash.py --num-blocks 20000 --block-size 32 --trials 5
```
Supported algorithms: `sha256`, `sha256_cbor`, `xxhash`, `xxhash_cbor`. Install optional deps to exercise all variants:
```bash
uv pip install xxhash cbor2
```
If an algorithm’s dependency is missing, the script will skip it and continue.
</details>
### ⚡ Request Prioritization Benchmark ### ⚡ Request Prioritization Benchmark
<details class="admonition abstract" markdown="1"> <details class="admonition abstract" markdown="1">
......
...@@ -18,6 +18,7 @@ Compute Resources: ...@@ -18,6 +18,7 @@ Compute Resources:
- Alibaba Cloud - Alibaba Cloud
- AMD - AMD
- Anyscale - Anyscale
- Arm
- AWS - AWS
- Crusoe Cloud - Crusoe Cloud
- Databricks - Databricks
......
# Nightly Builds of vLLM Wheels
vLLM maintains a per-commit wheel repository (commonly referred to as "nightly") at `https://wheels.vllm.ai` that provides pre-built wheels for every commit on the `main` branch since `v0.5.3`. This document explains how the nightly wheel index mechanism works.
## Build and Upload Process on CI
### Wheel Building
Wheels are built in the `Release` pipeline (`.buildkite/release-pipeline.yaml`) after a PR is merged into the main branch, with multiple variants:
- **Backend variants**: `cpu` and `cuXXX` (e.g., `cu129`, `cu130`).
- **Architecture variants**: `x86_64` and `aarch64`.
Each build step:
1. Builds the wheel in a Docker container.
2. Renames the wheel filename to use the correct manylinux tag (currently `manylinux_2_31`) for PEP 600 compliance.
3. Uploads the wheel to S3 bucket `vllm-wheels` under `/{commit_hash}/`.
### Index Generation
After uploading each wheel, the `.buildkite/scripts/upload-wheels.sh` script:
1. **Lists all existing wheels** in the commit directory from S3
2. **Generates indices** using `.buildkite/scripts/generate-nightly-index.py`:
- Parses wheel filenames to extract metadata (version, variant, platform tags).
- Creates HTML index files (`index.html`) for PyPI compatibility.
- Generates machine-readable `metadata.json` files.
3. **Uploads indices** to multiple locations (overriding existing ones):
- `/{commit_hash}/` - Always uploaded for commit-specific access.
- `/nightly/` - Only for commits on `main` branch (not PRs).
- `/{version}/` - Only for release wheels (no `dev` in its version).
!!! tip "Handling Concurrent Builds"
The index generation script can handle multiple variants being built concurrently by always listing all wheels in the commit directory before generating indices, avoiding race conditions.
## Directory Structure
The S3 bucket structure follows this pattern:
```text
s3://vllm-wheels/
├── {commit_hash}/ # Commit-specific wheels and indices
│ ├── vllm-*.whl # All wheel files
│ ├── index.html # Project list (default variant)
│ ├── vllm/
│ │ ├── index.html # Package index (default variant)
│ │ └── metadata.json # Metadata (default variant)
│ ├── cu129/ # Variant subdirectory
│ │ ├── index.html # Project list (cu129 variant)
│ │ └── vllm/
│ │ ├── index.html # Package index (cu129 variant)
│ │ └── metadata.json # Metadata (cu129 variant)
│ ├── cu130/ # Variant subdirectory
│ ├── cpu/ # Variant subdirectory
│ └── .../ # More variant subdirectories
├── nightly/ # Latest main branch wheels (mirror of latest commit)
└── {version}/ # Release version indices (e.g., 0.11.2)
```
All built wheels are stored in `/{commit_hash}/`, while different indices are generated and reference them.
This avoids duplication of wheel files.
For example, you can specify the following URLs to use different indices:
- `https://wheels.vllm.ai/nightly/cu130` for the latest main branch wheels built with CUDA 13.0.
- `https://wheels.vllm.ai/{commit_hash}` for wheels built at a specific commit (default variant).
- `https://wheels.vllm.ai/0.12.0/cpu` for 0.12.0 release wheels built for CPU variant.
Please note that not all variants are present on every commit. The available variants are subject to change over time, e.g., changing cu130 to cu131.
### Variant Organization
Indices are organized by variant:
- **Default variant**: Wheels without variant suffix (i.e., built with the current `VLLM_MAIN_CUDA_VERSION`) are placed in the root.
- **Variant subdirectories**: Wheels with variant suffixes (e.g., `+cu130`, `.cpu`) are organized in subdirectories.
- **Alias to default**: The default variant can have an alias (e.g., `cu129` for now) for consistency and convenience.
The variant is extracted from the wheel filename (as described in the [file name convention](https://packaging.python.org/en/latest/specifications/binary-distribution-format/#file-name-convention)):
- The variant is encoded in the local version identifier (e.g. `+cu129` or `dev<N>+g<hash>.cu130`).
- Examples:
- `vllm-0.11.2.dev278+gdbc3d9991-cp38-abi3-manylinux1_x86_64.whl` → default variant
- `vllm-0.10.2rc2+cu129-cp38-abi3-manylinux2014_aarch64.whl``cu129` variant
- `vllm-0.11.1rc8.dev14+gaa384b3c0.cu130-cp38-abi3-manylinux1_x86_64.whl``cu130` variant
## Index Generation Details
The `generate-nightly-index.py` script performs the following:
1. **Parses wheel filenames** using regex to extract:
- Package name
- Version (with variant extracted)
- Python tag, ABI tag, platform tag
- Build tag (if present)
2. **Groups wheels by variant**, then by package name:
- Currently only `vllm` is built, but the structure supports multiple packages in the future.
3. **Generates HTML indices** (compliant with the [Simple repository API](https://packaging.python.org/en/latest/specifications/simple-repository-api/#simple-repository-api)):
- Top-level `index.html`: Lists all packages and variant subdirectories
- Package-level `index.html`: Lists all wheel files for that package
- Uses relative paths to wheel files for portability
4. **Generates metadata.json**:
- Machine-readable JSON containing all wheel metadata
- Includes `path` field with URL-encoded relative path to wheel file
- Used by `setup.py` to locate compatible pre-compiled wheels during Python-only builds
### Special Handling for AWS Services
The wheels and indices are directly stored on AWS S3, and we use AWS CloudFront as a CDN in front of the S3 bucket.
Since S3 does not provide proper directory listing, to support PyPI-compatible simple repository API behavior, we deploy a CloudFront Function that:
- redirects any URL that does not end with `/` and does not look like a file (i.e., does not contain a dot `.` in the last path segment) to the same URL with a trailing `/`
- appends `/index.html` to any URL that ends with `/`
For example, the following requests would be handled as:
- `/nightly` -> `/nightly/index.html`
- `/nightly/cu130/` -> `/nightly/cu130/index.html`
- `/nightly/index.html` or `/nightly/vllm.whl` -> unchanged
!!! note "AWS S3 Filename Escaping"
S3 will automatically escape filenames upon upload according to its [naming rule](https://docs.aws.amazon.com/AmazonS3/latest/userguide/object-keys.html). The direct impact on vllm is that `+` in filenames will be converted to `%2B`. We take special care in the index generation script to escape filenames properly when generating the HTML indices and JSON metadata, to ensure the URLs are correct and can be directly used.
## Usage of precompiled wheels in `setup.py` {#precompiled-wheels-usage}
When installing vLLM with `VLLM_USE_PRECOMPILED=1`, the `setup.py` script:
1. **Determines wheel location** via `precompiled_wheel_utils.determine_wheel_url()`:
- Env var `VLLM_PRECOMPILED_WHEEL_LOCATION` (user-specified URL/path) always takes precedence and skips all other steps.
- Determines the variant from `VLLM_MAIN_CUDA_VERSION` (can be overridden with env var `VLLM_PRECOMPILED_WHEEL_VARIANT`); the default variant will also be tried as a fallback.
- Determines the _base commit_ (explained later) of this branch (can be overridden with env var `VLLM_PRECOMPILED_WHEEL_COMMIT`).
2. **Fetches metadata** from `https://wheels.vllm.ai/{commit}/vllm/metadata.json` (for the default variant) or `https://wheels.vllm.ai/{commit}/{variant}/vllm/metadata.json` (for a specific variant).
3. **Selects compatible wheel** based on:
- Package name (`vllm`)
- Platform tag (architecture match)
4. **Downloads and extracts** precompiled binaries from the wheel:
- C++ extension modules (`.so` files)
- Flash Attention Python modules
- Triton kernel Python files
5. **Patches package_data** to include extracted files in the installation
!!! note "What is the base commit?"
The base commit is determined by finding the merge-base
between the current branch and upstream `main`, ensuring
compatibility between source code and precompiled binaries.
_Note: it's users' responsibility to ensure there is no native code (e.g., C++ or CUDA) changes before using precompiled wheels._
## Implementation Files
Key files involved in the nightly wheel mechanism:
- **`.buildkite/release-pipeline.yaml`**: CI pipeline that builds wheels
- **`.buildkite/scripts/upload-wheels.sh`**: Script that uploads wheels and generates indices
- **`.buildkite/scripts/generate-nightly-index.py`**: Python script that generates PyPI-compatible indices
- **`setup.py`**: Contains `precompiled_wheel_utils` class for fetching and using precompiled wheels
...@@ -5,16 +5,15 @@ ...@@ -5,16 +5,15 @@
## Profile with PyTorch Profiler ## Profile with PyTorch Profiler
We support tracing vLLM workers using the `torch.profiler` module. You can enable tracing by setting the `VLLM_TORCH_PROFILER_DIR` environment variable to the directory where you want to save the traces: `VLLM_TORCH_PROFILER_DIR=/mnt/traces/`. Additionally, you can control the profiling content by specifying the following environment variables: We support tracing vLLM workers using the `torch.profiler` module. You can enable the torch profiler by setting `--profiler-config`
when launching the server, and setting the entries `profiler` to `'torch'` and `torch_profiler_dir` to the directory where you want to save the traces. Additionally, you can control the profiling content by specifying the following additional arguments in the config:
- `VLLM_TORCH_PROFILER_RECORD_SHAPES=1` to enable recording Tensor Shapes, off by default - `torch_profiler_record_shapes` to enable recording Tensor Shapes, off by default
- `VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1` to record memory, off by default - `torch_profiler_with_memory` to record memory, off by default
- `VLLM_TORCH_PROFILER_WITH_STACK=1` to enable recording stack information, on by default - `torch_profiler_with_stack` to enable recording stack information, on by default
- `VLLM_TORCH_PROFILER_WITH_FLOPS=1` to enable recording FLOPs, off by default - `torch_profiler_with_flops` to enable recording FLOPs, off by default
- `VLLM_TORCH_PROFILER_USE_GZIP=0` to disable gzip-compressing profiling files, on by default - `torch_profiler_use_gzip` to control gzip-compressing profiling files, on by default
- `VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL=0` to disable dumping and printing the aggregated CUDA self time table, on by default - `torch_profiler_dump_cuda_time_total` to control dumping and printing the aggregated CUDA self time table, on by default
The OpenAI server also needs to be started with the `VLLM_TORCH_PROFILER_DIR` environment variable set.
When using `vllm bench serve`, you can enable profiling by passing the `--profile` flag. When using `vllm bench serve`, you can enable profiling by passing the `--profile` flag.
...@@ -40,8 +39,7 @@ Refer to [examples/offline_inference/simple_profiling.py](../../examples/offline ...@@ -40,8 +39,7 @@ Refer to [examples/offline_inference/simple_profiling.py](../../examples/offline
#### OpenAI Server #### OpenAI Server
```bash ```bash
VLLM_TORCH_PROFILER_DIR=./vllm_profile \ vllm serve meta-llama/Llama-3.1-8B-Instruct --profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile"}'
vllm serve meta-llama/Llama-3.1-8B-Instruct
``` ```
vllm bench command: vllm bench command:
...@@ -104,13 +102,12 @@ To profile the server, you will want to prepend your `vllm serve` command with ` ...@@ -104,13 +102,12 @@ To profile the server, you will want to prepend your `vllm serve` command with `
```bash ```bash
# server # server
VLLM_TORCH_CUDA_PROFILE=1 \
nsys profile \ nsys profile \
--trace-fork-before-exec=true \ --trace-fork-before-exec=true \
--cuda-graph-trace=node \ --cuda-graph-trace=node \
--capture-range=cudaProfilerApi \ --capture-range=cudaProfilerApi \
--capture-range-end repeat \ --capture-range-end repeat \
vllm serve meta-llama/Llama-3.1-8B-Instruct vllm serve meta-llama/Llama-3.1-8B-Instruct --profiler-config.profiler cuda
# client # client
vllm bench serve \ vllm bench serve \
......
# Kthena
[**Kthena**](https://github.com/volcano-sh/kthena) is a Kubernetes-native LLM inference platform that transforms how organizations deploy and manage Large Language Models in production. Built with declarative model lifecycle management and intelligent request routing, it provides high performance and enterprise-grade scalability for LLM inference workloads.
This guide shows how to deploy a production-grade, **multi-node vLLM** service on Kubernetes.
We’ll:
- Install the required components (Kthena + Volcano).
- Deploy a multi-node vLLM model via Kthena’s `ModelServing` CR.
- Validate the deployment.
---
## 1. Prerequisites
You need:
- A Kubernetes cluster with **GPU nodes**.
- `kubectl` access with cluster-admin or equivalent permissions.
- **Volcano** installed for gang scheduling.
- **Kthena** installed with the `ModelServing` CRD available.
- A valid **Hugging Face token** if loading models from Hugging Face Hub.
### 1.1 Install Volcano
```bash
helm repo add volcano-sh https://volcano-sh.github.io/helm-charts
helm repo update
helm install volcano volcano-sh/volcano -n volcano-system --create-namespace
```
This provides the gang-scheduling and network topology features used by Kthena.
### 1.2 Install Kthena
```bash
helm install kthena oci://ghcr.io/volcano-sh/charts/kthena --version v0.1.0 --namespace kthena-system --create-namespace
```
- The `kthena-system` namespace is created.
- Kthena controllers and CRDs, including `ModelServing`, are installed and healthy.
Validate:
```bash
kubectl get crd | grep modelserving
```
You should see:
```text
modelservings.workload.serving.volcano.sh ...
```
---
## 2. The Multi-Node vLLM `ModelServing` Example
Kthena provides an example manifest to deploy a **multi-node vLLM cluster running Llama**. Conceptually this is equivalent to the vLLM production stack Helm deployment, but expressed with `ModelServing`.
A simplified version of the example (`llama-multinode`) looks like:
- `spec.replicas: 1` – one `ServingGroup` (one logical model deployment).
- `roles`:
- `entryTemplate` – defines **leader** pods that run:
- vLLM’s **multi-node cluster bootstrap script** (Ray cluster).
- vLLM **OpenAI-compatible API server**.
- `workerTemplate` – defines **worker** pods that join the leader’s Ray cluster.
Key points from the example YAML:
- **Image**: `vllm/vllm-openai:latest` (matches upstream vLLM images).
- **Command** (leader):
```yaml
command:
- sh
- -c
- >
bash /vllm-workspace/examples/online_serving/multi-node-serving.sh leader --ray_cluster_size=2;
python3 -m vllm.entrypoints.openai.api_server
--port 8080
--model meta-llama/Llama-3.1-405B-Instruct
--tensor-parallel-size 8
--pipeline-parallel-size 2
```
- **Command** (worker):
```yaml
command:
- sh
- -c
- >
bash /vllm-workspace/examples/online_serving/multi-node-serving.sh worker --ray_address=$(ENTRY_ADDRESS)
```
---
## 3. Deploying Multi-Node llama vLLM via Kthena
### 3.1 Prepare the Manifest
**Recommended**: use a Secret instead of a raw env var:
```bash
kubectl create secret generic hf-token \
-n default \
--from-literal=HUGGING_FACE_HUB_TOKEN='<your-token>'
```
### 3.2 Apply the `ModelServing`
```bash
cat <<EOF | kubectl apply -f -
apiVersion: workload.serving.volcano.sh/v1alpha1
kind: ModelServing
metadata:
name: llama-multinode
namespace: default
spec:
schedulerName: volcano
replicas: 1 # group replicas
template:
restartGracePeriodSeconds: 60
gangPolicy:
minRoleReplicas:
405b: 1
roles:
- name: 405b
replicas: 2
entryTemplate:
spec:
containers:
- name: leader
image: vllm/vllm-openai:latest
env:
- name: HUGGING_FACE_HUB_TOKEN
valueFrom:
secretKeyRef:
name: hf-token
key: HUGGING_FACE_HUB_TOKEN
command:
- sh
- -c
- "bash /vllm-workspace/examples/online_serving/multi-node-serving.sh leader --ray_cluster_size=2;
python3 -m vllm.entrypoints.openai.api_server --port 8080 --model meta-llama/Llama-3.1-405B-Instruct --tensor-parallel-size 8 --pipeline-parallel-size 2"
resources:
limits:
nvidia.com/gpu: "8"
memory: 1124Gi
ephemeral-storage: 800Gi
requests:
ephemeral-storage: 800Gi
cpu: 125
ports:
- containerPort: 8080
readinessProbe:
tcpSocket:
port: 8080
initialDelaySeconds: 15
periodSeconds: 10
volumeMounts:
- mountPath: /dev/shm
name: dshm
volumes:
- name: dshm
emptyDir:
medium: Memory
sizeLimit: 15Gi
workerReplicas: 1
workerTemplate:
spec:
containers:
- name: worker
image: vllm/vllm-openai:latest
command:
- sh
- -c
- "bash /vllm-workspace/examples/online_serving/multi-node-serving.sh worker --ray_address=$(ENTRY_ADDRESS)"
resources:
limits:
nvidia.com/gpu: "8"
memory: 1124Gi
ephemeral-storage: 800Gi
requests:
ephemeral-storage: 800Gi
cpu: 125
env:
- name: HUGGING_FACE_HUB_TOKEN
valueFrom:
secretKeyRef:
name: hf-token
key: HUGGING_FACE_HUB_TOKEN
volumeMounts:
- mountPath: /dev/shm
name: dshm
volumes:
- name: dshm
emptyDir:
medium: Memory
sizeLimit: 15Gi
EOF
```
Kthena will:
- Create a `ModelServing` object.
- Derive a `PodGroup` for Volcano gang scheduling.
- Create the leader and worker pods for each `ServingGroup` and `Role`.
---
## 4. Verifying the Deployment
### 4.1 Check ModelServing Status
Use the snippet from the Kthena docs:
```bash
kubectl get modelserving -oyaml | grep status -A 10
```
You should see something like:
```yaml
status:
availableReplicas: 1
conditions:
- type: Available
status: "True"
reason: AllGroupsReady
message: All Serving groups are ready
- type: Progressing
status: "False"
...
replicas: 1
updatedReplicas: 1
```
### 4.2 Check Pods
List pods for your deployment:
```bash
kubectl get pod -owide -l modelserving.volcano.sh/name=llama-multinode
```
Example output (from docs):
```text
NAMESPACE NAME READY STATUS RESTARTS AGE IP NODE ...
default llama-multinode-0-405b-0-0 1/1 Running 0 15m 10.244.0.56 192.168.5.12 ...
default llama-multinode-0-405b-0-1 1/1 Running 0 15m 10.244.0.58 192.168.5.43 ...
default llama-multinode-0-405b-1-0 1/1 Running 0 15m 10.244.0.57 192.168.5.58 ...
default llama-multinode-0-405b-1-1 1/1 Running 0 15m 10.244.0.53 192.168.5.36 ...
```
Pod name pattern:
- `llama-multinode-<group-idx>-<role-name>-<replica-idx>-<ordinal>`.
The first number indicates `ServingGroup`. The second (`405b`) is the `Role`. The remaining indices identify the pod within the role.
---
## 6. Accessing the vLLM OpenAI-Compatible API
Expose the entry via a Service:
```yaml
apiVersion: v1
kind: Service
metadata:
name: llama-multinode-openai
namespace: default
spec:
selector:
modelserving.volcano.sh/name: llama-multinode
modelserving.volcano.sh/entry: "true"
# optionally further narrow to leader role if you label it
ports:
- name: http
port: 80
targetPort: 8080
type: ClusterIP
```
Port-forward from your local machine:
```bash
kubectl port-forward svc/llama-multinode-openai 30080:80 -n default
```
Then:
- List models:
```bash
curl -s http://localhost:30080/v1/models
```
- Send a completion request (mirroring vLLM production stack docs):
```bash
curl -X POST http://localhost:30080/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Llama-3.1-405B-Instruct",
"prompt": "Once upon a time,",
"max_tokens": 10
}'
```
You should see an OpenAI-style response from vLLM.
---
## 7. Clean Up
To remove the deployment and its resources:
```bash
kubectl delete modelserving llama-multinode -n default
```
If you’re done with the entire stack:
```bash
helm uninstall kthena -n kthena-system # or your Kthena release name
helm uninstall volcano -n volcano-system
```
...@@ -14,6 +14,7 @@ Alternatively, you can deploy vLLM to Kubernetes using any of the following: ...@@ -14,6 +14,7 @@ Alternatively, you can deploy vLLM to Kubernetes using any of the following:
- [InftyAI/llmaz](integrations/llmaz.md) - [InftyAI/llmaz](integrations/llmaz.md)
- [KAITO](integrations/kaito.md) - [KAITO](integrations/kaito.md)
- [KServe](integrations/kserve.md) - [KServe](integrations/kserve.md)
- [Kthena](integrations/kthena.md)
- [KubeRay](integrations/kuberay.md) - [KubeRay](integrations/kuberay.md)
- [kubernetes-sigs/lws](frameworks/lws.md) - [kubernetes-sigs/lws](frameworks/lws.md)
- [meta-llama/llama-stack](integrations/llamastack.md) - [meta-llama/llama-stack](integrations/llamastack.md)
......
...@@ -86,7 +86,7 @@ LLM(model, enforce_eager=True) ...@@ -86,7 +86,7 @@ LLM(model, enforce_eager=True)
``` ```
To turn off just torch.compile, pass `mode = NONE` to the compilation config. To turn off just torch.compile, pass `mode = NONE` to the compilation config.
(`-cc` is short for `--compilation_config`; `-O.*` dotted syntax is deprecated): (`-cc` is short for `--compilation_config`):
```sh ```sh
# Online # Online
......
...@@ -79,7 +79,7 @@ The `post_process*` methods take `PoolingRequestOutput` objects as input and gen ...@@ -79,7 +79,7 @@ The `post_process*` methods take `PoolingRequestOutput` objects as input and gen
The `validate_or_generate_params` method is used for validating with the plugin any `SamplingParameters`/`PoolingParameters` received with the user request, or to generate new ones if none are specified. The function always returns the validated/generated parameters. The `validate_or_generate_params` method is used for validating with the plugin any `SamplingParameters`/`PoolingParameters` received with the user request, or to generate new ones if none are specified. The function always returns the validated/generated parameters.
The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/pooling/pooling/serving.py). The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/pooling/pooling/serving.py).
An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/online_serving/pooling/prithvi_geospatial_mae.py](../../examples/online_serving/pooling/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py)) inference examples. An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/pooling/plugin/prithvi_geospatial_mae_client.py](../../examples/pooling/plugin/prithvi_geospatial_mae_client.py)) and offline ([examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py](../../examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py)) inference examples.
## Using an IO Processor plugin ## Using an IO Processor plugin
......
...@@ -57,15 +57,15 @@ vLLM also provides [a reference example](../../examples/online_serving/prometheu ...@@ -57,15 +57,15 @@ vLLM also provides [a reference example](../../examples/online_serving/prometheu
The subset of metrics exposed in the Grafana dashboard gives us an indication of which metrics are especially important: The subset of metrics exposed in the Grafana dashboard gives us an indication of which metrics are especially important:
- `vllm:e2e_request_latency_seconds_bucket` - End to end request latency measured in seconds. - `vllm:e2e_request_latency_seconds_bucket` - End to end request latency measured in seconds.
- `vllm:prompt_tokens_total` - Prompt tokens. - `vllm:prompt_tokens` - Prompt tokens.
- `vllm:generation_tokens_total` - Generation tokens. - `vllm:generation_tokens` - Generation tokens.
- `vllm:time_per_output_token_seconds` - Inter-token latency (Time Per Output Token, TPOT) in seconds. - `vllm:time_per_output_token_seconds` - Inter-token latency (Time Per Output Token, TPOT) in seconds.
- `vllm:time_to_first_token_seconds` - Time to First Token (TTFT) latency in seconds. - `vllm:time_to_first_token_seconds` - Time to First Token (TTFT) latency in seconds.
- `vllm:num_requests_running` (also, `_swapped` and `_waiting`) - Number of requests in the RUNNING, WAITING, and SWAPPED states. - `vllm:num_requests_running` (also, `_swapped` and `_waiting`) - Number of requests in the RUNNING, WAITING, and SWAPPED states.
- `vllm:gpu_cache_usage_perc` - Percentage of used cache blocks by vLLM. - `vllm:kv_cache_usage_perc` - Percentage of used cache blocks by vLLM.
- `vllm:request_prompt_tokens` - Request prompt length. - `vllm:request_prompt_tokens` - Request prompt length.
- `vllm:request_generation_tokens` - Request generation length. - `vllm:request_generation_tokens` - Request generation length.
- `vllm:request_success_total` - Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached. - `vllm:request_success` - Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached.
- `vllm:request_queue_time_seconds` - Queue time. - `vllm:request_queue_time_seconds` - Queue time.
- `vllm:request_prefill_time_seconds` - Requests prefill time. - `vllm:request_prefill_time_seconds` - Requests prefill time.
- `vllm:request_decode_time_seconds` - Requests decode time. - `vllm:request_decode_time_seconds` - Requests decode time.
...@@ -571,9 +571,9 @@ model and then validate those tokens with the larger model. ...@@ -571,9 +571,9 @@ model and then validate those tokens with the larger model.
- `vllm:spec_decode_draft_acceptance_rate` (Gauge) - `vllm:spec_decode_draft_acceptance_rate` (Gauge)
- `vllm:spec_decode_efficiency` (Gauge) - `vllm:spec_decode_efficiency` (Gauge)
- `vllm:spec_decode_num_accepted_tokens_total` (Counter) - `vllm:spec_decode_num_accepted_tokens` (Counter)
- `vllm:spec_decode_num_draft_tokens_total` (Counter) - `vllm:spec_decode_num_draft_tokens` (Counter)
- `vllm:spec_decode_num_emitted_tokens_total` (Counter) - `vllm:spec_decode_num_emitted_tokens` (Counter)
There is a PR under review (<https://github.com/vllm-project/vllm/pull/12193>) to add "prompt lookup (ngram)" There is a PR under review (<https://github.com/vllm-project/vllm/pull/12193>) to add "prompt lookup (ngram)"
speculative decoding to v1. Other techniques will follow. We should speculative decoding to v1. Other techniques will follow. We should
......
...@@ -90,7 +90,6 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels ...@@ -90,7 +90,6 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
| cutlass_fp8 | standard,</br>batched | fp8 | A,T | silu, gelu | Y | Y | [`cutlass_moe_fp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp8],</br>[`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],</br>[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] | | cutlass_fp8 | standard,</br>batched | fp8 | A,T | silu, gelu | Y | Y | [`cutlass_moe_fp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp8],</br>[`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],</br>[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] |
| flashinfer | standard | nvfp4,</br>fp8 | T | <sup>5</sup> | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],</br>[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] | | flashinfer | standard | nvfp4,</br>fp8 | T | <sup>5</sup> | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],</br>[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
| gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] | | gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
| deep gemm+triton<sup>2</sup> | standard,</br>batched | all<sup>1</sup> | G(128),A,T | silu, gelu | <sup>6</sup> | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],</br>[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] |
| marlin | standard,</br>batched | <sup>3</sup> / N/A | <sup>3</sup> / N/A | silu,</br>swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],</br>[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] | | marlin | standard,</br>batched | <sup>3</sup> / N/A | <sup>3</sup> / N/A | silu,</br>swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],</br>[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] | | trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] | | pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
...@@ -114,5 +113,5 @@ The following table shows "families" of modular kernels that are intended to wor ...@@ -114,5 +113,5 @@ The following table shows "families" of modular kernels that are intended to wor
| backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses | | backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses |
|---------|-----------------------------------------|----------------------------------------------| |---------|-----------------------------------------|----------------------------------------------|
| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` | | deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` |
| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts` | | deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts` |
| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` | | flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
...@@ -22,8 +22,8 @@ In the example above, the KV cache in the first block can be uniquely identified ...@@ -22,8 +22,8 @@ In the example above, the KV cache in the first block can be uniquely identified
We only cache full blocks. We only cache full blocks.
!!! note "Note 2" !!! note "Note 2"
The above hash key structure is not 100% collision free. Theoretically it’s still possible for the different prefix tokens to have the same hash value. To avoid any hash collisions **in a multi-tenant setup, we advise to use SHA256** as hash function instead of the default builtin hash. The above hash key structure is not 100% collision free. Theoretically it’s still possible for the different prefix tokens to have the same hash value. To avoid any hash collisions **in a multi-tenant setup, we use SHA256** as hash function instead of the builtin hash.
SHA256 is supported since vLLM v0.8.3 and must be enabled with a command line argument. It comes with a performance impact of about 100-200ns per token (~6ms for 50k tokens of context). SHA256 is supported since vLLM v0.8.3 and the default since v0.10.2. It comes with a negligible performance impact of about 75ns per token (<4ms for 50k tokens of context).
**A hashing example with multi-modality inputs** **A hashing example with multi-modality inputs**
In this example, we illustrate how prefix caching works with multi-modality inputs (e.g., images). Assuming we have a request with the following messages: In this example, we illustrate how prefix caching works with multi-modality inputs (e.g., images). Assuming we have a request with the following messages:
......
...@@ -54,7 +54,7 @@ th:not(:first-child) { ...@@ -54,7 +54,7 @@ th:not(:first-child) {
| beam-search | ✅ | ✅ | ✅ | [](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [](https://github.com/vllm-project/vllm/issues/7968) | ❔ | ✅ | ✅ | | | beam-search | ✅ | ✅ | ✅ | [](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [](https://github.com/vllm-project/vllm/issues/7968) | ❔ | ✅ | ✅ | |
| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ | | [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |
\* Chunked prefill and prefix caching are only applicable to last-token pooling. \* Chunked prefill and prefix caching are only applicable to last-token or all pooling with causal attention.
<sup>^</sup> LoRA is only applicable to the language backbone of multimodal models. <sup>^</sup> LoRA is only applicable to the language backbone of multimodal models.
### Feature x Hardware ### Feature x Hardware
...@@ -68,8 +68,8 @@ th:not(:first-child) { ...@@ -68,8 +68,8 @@ th:not(:first-child) {
| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | [](https://github.com/vllm-project/vllm/issues/26970) | | CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | [](https://github.com/vllm-project/vllm/issues/26970) |
| [pooling](../models/pooling_models.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [pooling](../models/pooling_models.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| <abbr title="Encoder-Decoder Models">enc-dec</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | | <abbr title="Encoder-Decoder Models">enc-dec</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
| [mm](multimodal_inputs.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | [🟠](https://github.com/vllm-project/vllm/issues/26965) | | [mm](multimodal_inputs.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ |
| <abbr title="Logprobs">logP</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | <abbr title="Logprobs">logP</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| <abbr title="Prompt Logprobs">prmpt logP</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | <abbr title="Prompt Logprobs">prmpt logP</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| <abbr title="Async Output Processing">async output</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | | <abbr title="Async Output Processing">async output</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ |
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment