Commit 006693ed authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.11.2' into v0.11.2-ori

parents 4b51e6f1 275de341
......@@ -67,8 +67,9 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
std::optional<torch::Tensor> const& bias);
#endif
#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \
defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100
#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \
defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100 || \
defined(ENABLE_SCALED_MM_SM120) && ENABLE_SCALED_MM_SM120
void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
......@@ -253,7 +254,7 @@ void cutlass_moe_mm(
bool per_act_token, bool per_out_ch) {
int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
if (version_num >= 100) {
if (version_num >= 100 && version_num < 110) {
cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
......@@ -261,7 +262,7 @@ void cutlass_moe_mm(
}
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
if (version_num >= 90) {
if (version_num >= 90 && version_num < 100) {
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
......
......@@ -7,7 +7,7 @@
#include <hip/hip_bf16.h>
#include <hip/hip_bfloat16.h>
#include "../../../attention/attention_dtypes.h"
#include "../../../../attention/attention_dtypes.h"
namespace vllm {
#ifdef USE_ROCM
......
#include "common.cuh"
#include "dispatch_utils.h"
#include "../../cub_helpers.h"
#include "../vectorization_utils.cuh"
#include "cub_helpers.h"
#include "quantization/vectorization_utils.cuh"
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/Exceptions.h>
......
#pragma once
#include "../../../attention/attention_dtypes.h"
#include "../../../../attention/attention_dtypes.h"
#include <assert.h>
#include <float.h>
#include <stdint.h>
......
#include <ATen/cuda/CUDAContext.h>
#include "../per_token_group_quant_8bit.h"
#include "quantization/w8a8/per_token_group_quant_8bit.h"
#include <cmath>
......@@ -8,9 +8,9 @@
#include <torch/all.h>
#include "../vectorization.cuh"
#include "../vectorization_utils.cuh"
#include "../../dispatch_utils.h"
#include "quantization/vectorization.cuh"
#include "quantization/vectorization_utils.cuh"
#include "dispatch_utils.h"
__device__ __forceinline__ float GroupReduceMax(float val) {
unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff;
......@@ -212,4 +212,4 @@ void per_token_group_quant_fp8(const torch::Tensor& input,
double fp8_max, bool scale_ue8m0) {
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
fp8_min, fp8_max, scale_ue8m0);
}
}
\ No newline at end of file
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include "quantization/w8a8/per_token_group_quant_8bit.h"
void per_token_group_quant_int8(const torch::Tensor& input,
torch::Tensor& output_q,
torch::Tensor& output_s, int64_t group_size,
double eps, double int8_min, double int8_max) {
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
int8_min, int8_max);
}
\ No newline at end of file
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#ifndef USE_ROCM
#include "../per_token_group_quant_8bit.h"
#endif
#include <c10/cuda/CUDAGuard.h>
#include <cmath>
#include "../../cub_helpers.h"
#include "../../dispatch_utils.h"
#include "../vectorization_utils.cuh"
#include "dispatch_utils.h"
#include "quantization/vectorization_utils.cuh"
#include "cub_helpers.h"
static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
......@@ -25,7 +22,6 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
float dst = std::nearbyint(x);
// saturate
// See https://github.com/pytorch/pytorch/issues/127666
// See https://github.com/llvm/llvm-project/issues/95183
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
......@@ -84,7 +80,6 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
static_cast<int32_t>(std::numeric_limits<int8_t>::max());
// saturate
// See https://github.com/pytorch/pytorch/issues/127666
// See https://github.com/llvm/llvm-project/issues/95183
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
......@@ -176,7 +171,6 @@ __global__ void dynamic_scaled_int8_quant_kernel(
float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax;
// 2. quantize
vectorize_with_alignment<16>(
row_in, row_out, hidden_size, tid, stride,
[=] __device__(int8_t& dst, const scalar_t& src) {
......@@ -194,7 +188,6 @@ struct MinMax {
__host__ __device__ explicit MinMax(float v) : min(v), max(v) {}
// add a value to the MinMax
__host__ __device__ MinMax& operator+=(float v) {
min = fminf(min, v);
max = fmaxf(max, v);
......@@ -228,7 +221,6 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
const scalar_t* row_in = input + token_idx * hidden_size;
int8_t* row_out = output + token_idx * hidden_size;
// 1. calculate min & max
MinMax thread_mm;
vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride,
[&] __device__(const scalar_t& src) {
......@@ -261,7 +253,6 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
const float inv_s = 1.f / scale_sh;
const azp_t azp = azp_sh;
// 2. quantize
vectorize_with_alignment<16>(
row_in, row_out, hidden_size, tid, stride,
[=] __device__(int8_t& dst, const scalar_t& src) {
......@@ -285,6 +276,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 256));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
......@@ -316,6 +308,7 @@ void dynamic_scaled_int8_quant(
int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 256));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
......@@ -332,14 +325,4 @@ void dynamic_scaled_int8_quant(
hidden_size);
}
});
}
#ifndef USE_ROCM
void per_token_group_quant_int8(const torch::Tensor& input,
torch::Tensor& output_q,
torch::Tensor& output_s, int64_t group_size,
double eps, double int8_min, double int8_max) {
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
int8_min, int8_max);
}
#endif
}
\ No newline at end of file
#pragma once
#include <torch/all.h>
// TODO(wentao): refactor the folder to 8bit, then includes fp8 and int8 folders
// 8-bit per-token-group quantization helper used by both FP8 and INT8
void per_token_group_quant_8bit(const torch::Tensor& input,
torch::Tensor& output_q,
......
......@@ -22,13 +22,14 @@ template <typename AllReduceKernel, typename T>
__global__ __quickreduce_launch_bounds_two_shot__ static void
allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks,
int rank, uint8_t** dbuffer_list,
uint32_t data_offset, uint32_t flag_color) {
uint32_t data_offset, uint32_t flag_color,
int64_t data_size_per_phase) {
int block = blockIdx.x;
int grid = gridDim.x;
while (block < num_blocks) {
AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset,
flag_color);
flag_color, data_size_per_phase);
block += grid;
flag_color++;
}
......@@ -41,21 +42,21 @@ allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks,
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
num_blocks, rank, dbuffer_list, data_offset, \
flag_color); \
flag_color, this->kMaxProblemSize); \
} else if (world_size == 4) { \
using LineCodec = __codec<T, 4>; \
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
num_blocks, rank, dbuffer_list, data_offset, \
flag_color); \
flag_color, this->kMaxProblemSize); \
} else if (world_size == 8) { \
using LineCodec = __codec<T, 8>; \
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
num_blocks, rank, dbuffer_list, data_offset, \
flag_color); \
flag_color, this->kMaxProblemSize); \
}
enum QuickReduceQuantLevel {
......
......@@ -553,13 +553,12 @@ struct AllReduceTwoshot {
int const rank, // rank index
uint8_t** __restrict__ buffer_list, // communication buffers
uint32_t const data_offset, // offset to start of the data buffer
uint32_t flag_color) {
uint32_t flag_color, int64_t data_size_per_phase) {
// Topology
int thread = threadIdx.x + threadIdx.y * kWavefront;
uint8_t* rank_buffer = buffer_list[rank];
Codec codec(thread, rank);
int block_id = blockIdx.x;
int grid_size = gridDim.x;
// --------------------------------------------------------
// Read input into registers
int32x4_t tA[kAtoms];
......@@ -588,12 +587,10 @@ struct AllReduceTwoshot {
// rank responsible for this segment.
uint32_t comm_data0_offset =
data_offset + block_id * Codec::kTransmittedTileSize;
uint32_t comm_data1_offset =
grid_size * Codec::kTransmittedTileSize + comm_data0_offset;
uint32_t comm_data1_offset = data_size_per_phase + comm_data0_offset;
uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t));
uint32_t comm_flags1_offset =
grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset;
uint32_t comm_flags1_offset = (data_offset / 2) + comm_flags0_offset;
for (int r = 0; r < kWorldSize; r++) {
int32x4_t* send_buffer =
......
......@@ -23,7 +23,7 @@
#include <algorithm>
#include "../attention/dtype_fp8.cuh"
#include "../quantization/fp8/amd/quant_utils.cuh"
#include "../quantization/w8a8/fp8/amd/quant_utils.cuh"
// ROCm 6.2 compatibility: map OCP fp8 types to FNUZ variants if OCP is absent
#if !defined(HIP_FP8_TYPE_OCP)
......@@ -40,7 +40,8 @@ using __hip_fp8_e5m2 = __hip_fp8_e5m2_fnuz;
#define __HIP__FP8MFMA__
#endif
#if defined(__HIPCC__) && (defined(__gfx1100__) || defined(__gfx1101__))
#if defined(__HIPCC__) && (defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1150__) || defined(__gfx1151__))
#define __HIP__GFX11__
#endif
......
......@@ -11,7 +11,7 @@
#include "../cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/fp8/common.cuh"
#include "quantization/w8a8/fp8/common.cuh"
#if defined(__HIPCC__) && \
(defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
......
......@@ -44,6 +44,270 @@ __global__ void apply_repetition_penalties_kernel(
}
}
static inline __device__ uint16_t extractBinIdx(float x) {
union {
__half h;
uint16_t u16;
} tmp;
tmp.h = __float2half_rn(x);
tmp.u16 = (x < 0.f) ? (~tmp.u16 & 0xffff) : (tmp.u16 | 0x8000);
return 511 - (tmp.u16 >> 7);
}
template <int kNumThreadsPerBlock = 512, int kNumBins = 512, int kTopK = 2048>
__device__ void topKPerRowJob(const float* logits, const int rowStart,
const int rowEnd, const int rowIdx,
int* outIndices, int stride0, int stride1) {
// The number of elements per thread for the final top-k sort.
static constexpr int kNumTopKItemsPerThread = kTopK / kNumThreadsPerBlock;
// The class to sort the elements during the final top-k sort.
#ifdef USE_ROCM
using TopKSort = hipcub::BlockRadixSort<float, kNumThreadsPerBlock,
kNumTopKItemsPerThread, int>;
#else
using TopKSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
kNumTopKItemsPerThread, int>;
#endif
// The number of slots for the final pass.
static constexpr int kNumFinalItems = 3072;
// The number of elements per thread for the final sort.
static constexpr int kNumFinalItemsPerThread =
kNumFinalItems / kNumThreadsPerBlock;
// The class to sort the elements during the final pass.
#ifdef USE_ROCM
using FinalSort = hipcub::BlockRadixSort<float, kNumThreadsPerBlock,
kNumFinalItemsPerThread, int>;
#else
using FinalSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
kNumFinalItemsPerThread, int>;
#endif
// The class to compute the inclusive prefix-sum over the histogram.
#ifdef USE_ROCM
using Scan = hipcub::BlockScan<int, kNumThreadsPerBlock>;
#else
using Scan = cub::BlockScan<int, kNumThreadsPerBlock>;
#endif
// Shared memory to compute the block scan.
__shared__ typename Scan::TempStorage smemScan;
// The structure to store the final items (for the final pass).
struct FinalItems {
// Shared memory to store the indices for the final pass.
int indices[kNumFinalItems];
// Shared memory to store the logits for the final pass.
float logits[kNumFinalItems];
};
// Shared memory to compute the block sort.
__shared__ union {
FinalItems items;
typename FinalSort::TempStorage finalSort;
typename TopKSort::TempStorage topKSort;
} smemFinal;
// Shared memory to store the histogram.
__shared__ int smemHistogram[kNumBins];
// Shared memory to store the selected indices.
__shared__ int smemIndices[kTopK];
// Shared memory to store the threshold bin.
__shared__ int smemThresholdBinIdx[1];
// Shared memory counter to register the candidates for the final phase.
__shared__ int smemFinalDstIdx[1];
// The length of the row.
int rowLen = rowEnd - rowStart;
// Shortcut if the length of the row is smaller than Top-K. Indices are not
// sorted by their corresponding logit.
if (rowLen <= kTopK) {
for (int rowIt = threadIdx.x; rowIt < rowLen;
rowIt += kNumThreadsPerBlock) {
int idx = rowStart + rowIt;
outIndices[rowIdx * kTopK + rowIt] = idx - rowStart;
}
for (int rowIt = rowLen + threadIdx.x; rowIt < kTopK;
rowIt += kNumThreadsPerBlock) {
outIndices[rowIdx * kTopK + rowIt] = -1;
}
return;
}
// Clear the histogram.
if (threadIdx.x < kNumBins) {
smemHistogram[threadIdx.x] = 0;
}
// Make sure the histogram is ready.
__syncthreads();
// Fetch elements one-by-one.
for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
rowIt += kNumThreadsPerBlock) {
uint16_t idx = extractBinIdx(logits[rowIdx * stride0 + rowIt * stride1]);
atomicAdd(&smemHistogram[idx], 1);
}
// Make sure the histogram is ready.
__syncthreads();
// Read the values from SMEM.
int binCount{0};
if (threadIdx.x < kNumBins) {
binCount = smemHistogram[threadIdx.x];
}
// Make sure each thread has read its value.
__syncthreads();
// Compute the prefix sum.
int prefixSum{0}, totalSum{0};
Scan(smemScan).ExclusiveSum(binCount, prefixSum, totalSum);
// Update the histogram with the prefix sums.
if (threadIdx.x < kNumBins) {
smemHistogram[threadIdx.x] = prefixSum;
}
// Make sure the data is in shared memory.
__syncthreads();
// Find the last valid bin.
if (threadIdx.x < kNumBins) {
int nextPrefixSum =
threadIdx.x == kNumBins - 1 ? totalSum : smemHistogram[threadIdx.x + 1];
if (prefixSum < kTopK && nextPrefixSum >= kTopK) {
smemThresholdBinIdx[0] = threadIdx.x;
}
}
// Clear the counter to store the items for the final phase.
if (threadIdx.x == 0) {
smemFinalDstIdx[0] = 0;
}
// Make sure the data is in shared memory.
__syncthreads();
// The threshold bin.
int thresholdBinIdx = smemThresholdBinIdx[0];
// Fetch elements one-by-one and populate the shared memory buffers.
for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
rowIt += kNumThreadsPerBlock) {
float logit = logits[rowIdx * stride0 + rowIt * stride1];
uint16_t idx = extractBinIdx(logit);
if (idx < thresholdBinIdx) {
int dstIdx = atomicAdd(&smemHistogram[idx], 1);
smemIndices[dstIdx] = rowIt;
} else if (idx == thresholdBinIdx) {
int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1);
if (dstIdx < kNumFinalItems) {
smemFinal.items.logits[dstIdx] = logit;
smemFinal.items.indices[dstIdx] = rowIt;
}
}
}
// Make sure the elements are in shared memory.
__syncthreads();
// The logits of the elements to be sorted in the final pass.
float finalLogits[kNumFinalItemsPerThread];
// The indices of the elements to be sorted in the final pass.
int finalIndices[kNumFinalItemsPerThread];
// Init.
#pragma unroll
for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
finalLogits[ii] = -FLT_MAX;
}
// Read the elements from SMEM.
#pragma unroll
for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
if (srcIdx < smemFinalDstIdx[0]) {
finalLogits[ii] = smemFinal.items.logits[srcIdx];
finalIndices[ii] = smemFinal.items.indices[srcIdx];
}
}
// Make sure the shared memory has been read.
__syncthreads();
// Sort the elements.
FinalSort(smemFinal.finalSort)
.SortDescendingBlockedToStriped(finalLogits, finalIndices);
// Copy the data back to the shared memory storage.
int baseIdx = thresholdBinIdx > 0 ? smemHistogram[thresholdBinIdx - 1] : 0;
#pragma unroll
for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
int dstIdx = baseIdx + srcIdx;
if (dstIdx < kTopK) {
smemIndices[dstIdx] = finalIndices[ii];
}
}
// Make sure the data is in shared memory.
__syncthreads();
// Store to global memory.
#pragma unroll
for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
int offset = rowIdx * kTopK + ii * kNumThreadsPerBlock + threadIdx.x;
outIndices[offset] =
smemIndices[ii * kNumThreadsPerBlock + threadIdx.x] - rowStart;
}
}
template <int kNumThreadsPerBlock = 512>
static __global__ void topKPerRow(const float* logits, const int* rowStarts,
const int* rowEnds, int* outIndices,
int stride0, int stride1) {
// The number of bins in the histogram.
static constexpr int kNumBins = 512;
// The top-k width.
static constexpr int kTopK = 2048;
// The row computed by this block.
int rowIdx = blockIdx.x;
// The range of logits within the row.
int rowStart = rowStarts[rowIdx];
int rowEnd = rowEnds[rowIdx];
topKPerRowJob<kNumThreadsPerBlock, kNumBins, kTopK>(
logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1);
}
template <int kNumThreadsPerBlock = 512>
static __global__ void topKPerRowDecode(const float* logits, const int* seqLens,
int* outIndices, int stride0,
int stride1, int next_n) {
// The number of bins in the histogram.
static constexpr int kNumBins = 512;
// The top-k width.
static constexpr int kTopK = 2048;
// The row computed by this block.
int rowIdx = blockIdx.x;
// The range of logits within the row.
int rowStart = 0;
int seq_len = seqLens[rowIdx / next_n];
int rowEnd = seq_len - next_n + (rowIdx % next_n) + 1;
topKPerRowJob<kNumThreadsPerBlock, kNumBins, kTopK>(
logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1);
}
} // namespace vllm
void apply_repetition_penalties_(
......@@ -85,4 +349,32 @@ void apply_repetition_penalties_(
repetition_penalties.data_ptr<scalar_t>(), num_seqs, vocab_size,
tile_size);
});
}
\ No newline at end of file
}
void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
const torch::Tensor& seqLens, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1) {
// Compute the results on the device.
constexpr int kNumThreadsPerBlock = 512;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
vllm::topKPerRowDecode<kNumThreadsPerBlock>
<<<numRows, kNumThreadsPerBlock, 0, stream>>>(
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
indices.data_ptr<int>(), static_cast<int>(stride0),
static_cast<int>(stride1), static_cast<int>(next_n));
}
void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
const torch::Tensor& rowEnds, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1) {
// Compute the results on the device.
constexpr int kNumThreadsPerBlock = 512;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
vllm::topKPerRow<kNumThreadsPerBlock>
<<<numRows, kNumThreadsPerBlock, 0, stream>>>(
logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
static_cast<int>(stride0), static_cast<int>(stride1));
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment