Commit dcb5624a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.5' into v0.8.5-dev

parents 55880ca2 ba41cc90
...@@ -269,6 +269,12 @@ void advance_step_flashinfer( ...@@ -269,6 +269,12 @@ void advance_step_flashinfer(
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr, torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds); torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
// void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
// torch::Tensor const& q_pe,
// torch::Tensor const& kv_c_and_k_pe_cache,
// torch::Tensor const& seq_lens,
// torch::Tensor const& page_table, double scale);
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor); torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
#ifndef USE_ROCM #ifndef USE_ROCM
......
...@@ -46,14 +46,26 @@ __global__ void compute_expert_offsets( ...@@ -46,14 +46,26 @@ __global__ void compute_expert_offsets(
} }
__global__ void compute_arg_sorts(const int* __restrict__ topk_ids, __global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
const int32_t* __restrict__ expert_offsets,
int32_t* input_permutation, int32_t* input_permutation,
int32_t* output_permutation, int32_t* output_permutation,
int32_t* atomic_buffer, const int topk_length, int32_t* atomic_buffer, const int topk_length,
const int topk) { const int topk) {
int expert_id = blockIdx.x; int const blk_expert_id = blockIdx.x;
int const num_experts = gridDim.x;
int32_t const num_tokens = expert_offsets[num_experts];
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) { for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
if (topk_ids[i] == expert_id) { int const expert_id = topk_ids[i];
if (expert_id == -1 && blockIdx.x == 0) {
// output_permutation is used to re-order the moe outputs. It is
// used as c2 = c2[c_map], where c2 is a torch.tensor that is the
// output of the cutlass kernels and c_map is the output_permutation.
// c2 is initialized to zeros, therefore by setting the output_permutation
// to num_tokens, we are guaranteed to fill the moe outputs to zero
// for "invalid" topk_ids.
output_permutation[i] = num_tokens;
} else if (expert_id == blk_expert_id) {
int start = atomicAdd(&atomic_buffer[expert_id], 1); int start = atomicAdd(&atomic_buffer[expert_id], 1);
input_permutation[start] = i / topk; input_permutation[start] = i / topk;
output_permutation[i] = start; output_permutation[i] = start;
...@@ -83,6 +95,7 @@ void get_cutlass_moe_mm_data_caller( ...@@ -83,6 +95,7 @@ void get_cutlass_moe_mm_data_caller(
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts); static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>( compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()), static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<const int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(input_permutation.data_ptr()), static_cast<int32_t*>(input_permutation.data_ptr()),
static_cast<int32_t*>(output_permutation.data_ptr()), static_cast<int32_t*>(output_permutation.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
......
...@@ -336,7 +336,7 @@ inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out, ...@@ -336,7 +336,7 @@ inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
uint32_t const m = a.size(0); uint32_t const m = a.size(0);
uint32_t const mp2 = uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2 std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
if (mp2 <= 16) { if (mp2 <= 16) {
// M in [1, 16] // M in [1, 16]
......
...@@ -321,7 +321,7 @@ inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out, ...@@ -321,7 +321,7 @@ inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
uint32_t const m = a.size(0); uint32_t const m = a.size(0);
uint32_t const mp2 = uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2 std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
if (mp2 <= 16) { if (mp2 <= 16) {
// M in [1, 16] // M in [1, 16]
......
...@@ -134,7 +134,7 @@ typename T::Gemm::Arguments args_from_options( ...@@ -134,7 +134,7 @@ typename T::Gemm::Arguments args_from_options(
using StrideB = typename T::StrideB; using StrideB = typename T::StrideB;
using StrideD = typename T::StrideD; using StrideD = typename T::StrideD;
using Sm100BlkScaledConfig = using Sm100BlkScaledConfig =
typename T::Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig; typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
int m = static_cast<int>(M); int m = static_cast<int>(M);
int n = static_cast<int>(N); int n = static_cast<int>(N);
......
...@@ -9,7 +9,11 @@ ...@@ -9,7 +9,11 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <iostream> #include <iostream>
namespace marlin { #ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif
namespace MARLIN_NAMESPACE_NAME {
// Marlin params // Marlin params
...@@ -23,6 +27,7 @@ static constexpr int pipe_stages = ...@@ -23,6 +27,7 @@ static constexpr int pipe_stages =
static constexpr int min_thread_n = 64; static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64; static constexpr int min_thread_k = 64;
static constexpr int max_thread_n = 256;
static constexpr int tile_size = 16; static constexpr int tile_size = 16;
static constexpr int max_par = 16; static constexpr int max_par = 16;
...@@ -84,4 +89,4 @@ __device__ inline void cp_async_wait() { ...@@ -84,4 +89,4 @@ __device__ inline void cp_async_wait() {
#endif #endif
} // namespace marlin } // namespace MARLIN_NAMESPACE_NAME
...@@ -5,7 +5,11 @@ ...@@ -5,7 +5,11 @@
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_bf16.h> #include <cuda_bf16.h>
namespace marlin { #ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif
namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t> template <typename scalar_t>
class ScalarType {}; class ScalarType {};
...@@ -54,7 +58,7 @@ class ScalarType<nv_bfloat16> { ...@@ -54,7 +58,7 @@ class ScalarType<nv_bfloat16> {
using FragS = Vec<nv_bfloat162, 1>; using FragS = Vec<nv_bfloat162, 1>;
using FragZP = Vec<nv_bfloat162, 4>; using FragZP = Vec<nv_bfloat162, 4>;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
static __device__ float inline num2float(const nv_bfloat16 x) { static __device__ float inline num2float(const nv_bfloat16 x) {
return __bfloat162float(x); return __bfloat162float(x);
} }
...@@ -74,6 +78,6 @@ class ScalarType<nv_bfloat16> { ...@@ -74,6 +78,6 @@ class ScalarType<nv_bfloat16> {
#endif #endif
}; };
} // namespace marlin } // namespace MARLIN_NAMESPACE_NAME
#endif #endif
...@@ -2,6 +2,15 @@ ...@@ -2,6 +2,15 @@
#include <torch/all.h> #include <torch/all.h>
torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
const int64_t rows_per_block);
torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
const int64_t CuCount);
void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount);
void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& max_logits, torch::Tensor& tmp_out,
torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& query, torch::Tensor& key_cache,
......
This diff is collapsed.
...@@ -14,6 +14,24 @@ ...@@ -14,6 +14,24 @@
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
// vLLM custom ops for rocm // vLLM custom ops for rocm
// Custom gemm op for matrix-vector multiplication
rocm_ops.def(
"LLMM1(Tensor in_a, Tensor in_b, int rows_per_block) -> "
"Tensor");
rocm_ops.impl("LLMM1", torch::kCUDA, &LLMM1);
// Custom gemm op for skinny matrix-matrix multiplication
rocm_ops.def(
"wvSplitK(Tensor in_a, Tensor in_b, int CuCount) -> "
"Tensor");
rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK);
// wvSplitK for fp8
rocm_ops.def(
"wvSplitKQ(Tensor in_a, Tensor in_b, Tensor! out_c, Tensor scale_a, "
" Tensor scale_b, int CuCount) -> ()");
rocm_ops.impl("wvSplitKQ", torch::kCUDA, &wvSplitKQ);
// Custom attention op // Custom attention op
// Compute the attention between an input query and the cached // Compute the attention between an input query and the cached
// keys/values using PagedAttention. // keys/values using PagedAttention.
......
...@@ -294,6 +294,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -294,6 +294,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
") -> ()"); ") -> ()");
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer); ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
// Compute MLA decode using cutlass.
// ops.def(
// "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
// " Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
// " Tensor page_table, float scale) -> ()");
// ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
// Layernorm // Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor. // Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def( ops.def(
......
...@@ -162,6 +162,9 @@ ENV UV_HTTP_TIMEOUT=500 ...@@ -162,6 +162,9 @@ ENV UV_HTTP_TIMEOUT=500
COPY requirements/lint.txt requirements/lint.txt COPY requirements/lint.txt requirements/lint.txt
COPY requirements/test.txt requirements/test.txt COPY requirements/test.txt requirements/test.txt
COPY requirements/dev.txt requirements/dev.txt COPY requirements/dev.txt requirements/dev.txt
# Workaround for #17068
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system mamba-ssm==2.2.4 --no-build-isolation
RUN --mount=type=cache,target=/root/.cache/uv \ RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/dev.txt uv pip install --system -r requirements/dev.txt
#################### DEV IMAGE #################### #################### DEV IMAGE ####################
...@@ -240,6 +243,8 @@ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \ ...@@ -240,6 +243,8 @@ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
uv pip install --system https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post2/flashinfer_python-0.2.1.post2+cu124torch2.6-cp38-abi3-linux_x86_64.whl ; \ uv pip install --system https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post2/flashinfer_python-0.2.1.post2+cu124torch2.6-cp38-abi3-linux_x86_64.whl ; \
fi fi
COPY examples examples COPY examples examples
COPY benchmarks benchmarks
COPY ./vllm/collect_env.py .
# Although we build Flashinfer with AOT mode, there's still # Although we build Flashinfer with AOT mode, there's still
# some issues w.r.t. JIT compilation. Therefore we need to # some issues w.r.t. JIT compilation. Therefore we need to
...@@ -263,6 +268,9 @@ ADD . /vllm-workspace/ ...@@ -263,6 +268,9 @@ ADD . /vllm-workspace/
ENV UV_HTTP_TIMEOUT=500 ENV UV_HTTP_TIMEOUT=500
# install development dependencies (for testing) # install development dependencies (for testing)
# Workaround for #17068
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system mamba-ssm==2.2.4 --no-build-isolation
RUN --mount=type=cache,target=/root/.cache/uv \ RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/dev.txt uv pip install --system -r requirements/dev.txt
...@@ -289,6 +297,7 @@ RUN mv vllm test_docs/ ...@@ -289,6 +297,7 @@ RUN mv vllm test_docs/
#################### OPENAI API SERVER #################### #################### OPENAI API SERVER ####################
# base openai image with additional requirements, for any subsequent openai-style images # base openai image with additional requirements, for any subsequent openai-style images
FROM vllm-base AS vllm-openai-base FROM vllm-base AS vllm-openai-base
ARG TARGETPLATFORM
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out # This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694 # Reference: https://github.com/astral-sh/uv/pull/1694
......
...@@ -121,6 +121,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ ...@@ -121,6 +121,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
ADD ./tests/ ./tests/ ADD ./tests/ ./tests/
ADD ./examples/ ./examples/ ADD ./examples/ ./examples/
ADD ./benchmarks/ ./benchmarks/ ADD ./benchmarks/ ./benchmarks/
ADD ./vllm/collect_env.py .
# install development dependencies (for testing) # install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \ RUN --mount=type=cache,target=/root/.cache/uv \
......
This diff is collapsed.
This diff is collapsed.
...@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" ...@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="1a7f4dfa" ARG FA_BRANCH="1a7f4dfa"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="8970b25b" ARG AITER_BRANCH="7e1ed08"
ARG AITER_REPO="https://github.com/ROCm/aiter.git" ARG AITER_REPO="https://github.com/ROCm/aiter.git"
FROM ${BASE_IMAGE} AS base FROM ${BASE_IMAGE} AS base
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment