Commit 539aa992 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.2' into v0.6.2-dev

parents 93872128 7193774b
......@@ -71,7 +71,7 @@ torch::Tensor run_impl(PyTorchArguments args) {
auto arguments = MacheteKernel::create_arguments(
stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr,
layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0),
args.group_size.value_or(K));
args.group_size);
TORCH_CHECK(MacheteKernel::can_implement(arguments),
"Machete kernel cannot be run with these arguments");
......
......@@ -53,7 +53,7 @@ torch::Tensor prepack_impl(torch::Tensor const B) {
// clang-format on
// Allocate output
torch::Tensor D = torch::empty_like(B);
torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous);
prepack_B<PrepackedLayoutB>(stream, B_ptr, layout_Bt,
static_cast<ElementB*>(D.mutable_data_ptr()));
......
/*
* Copyright (c) 2024, The vLLM team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <hip/hip_bf16.h>
#include "cuda_compat.h"
#include <algorithm>
#include "../attention/dtype_fp8.cuh"
#include "../quantization/fp8/amd/quant_utils.cuh"
#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \
defined(__gfx941__) || defined(__gfx942__))
#define __HIP__MI300_MI250__
#endif
#if defined(NDEBUG)
#undef NDEBUG
#include <assert.h>
#define UNREACHABLE_CODE assert(false);
#define NDEBUG
#else
#define UNREACHABLE_CODE assert(false);
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
#define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32
#define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16
using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
using float16x4 =
__attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16;
typedef float16x4 _Half4;
typedef struct _Half8 {
_Half4 xy[2];
} _Half8;
using bit16_t = uint16_t;
using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t;
typedef bit16x4 _B16x4;
typedef struct _B16x8 {
_B16x4 xy[2];
} _B16x8;
using _B8x8 = uint2;
////// Non temporal load stores ///////
template <typename T>
__device__ __forceinline__ T load(T* addr) {
return addr[0];
}
template <typename T>
__device__ __forceinline__ void store(T value, T* addr) {
addr[0] = value;
}
template <typename T, int absz, int cbid, int blgp>
__device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA,
const _B16x4& inpB,
const floatx4& inpC) {
if constexpr (std::is_same<T, _Float16>::value) {
return __builtin_amdgcn_mfma_f32_4x4x4f16(inpA, inpB, inpC, absz, cbid,
blgp);
} else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
return __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(inpA, inpB, inpC, absz, cbid,
blgp);
} else {
static_assert(false, "unsupported 16b dtype");
}
}
template <typename T>
__device__ __forceinline__ float to_float(const T& inp) {
if constexpr (std::is_same<T, _Float16>::value) {
return (float)inp;
} else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
return __bfloat162float(inp);
} else {
static_assert(false, "unsupported 16b dtype");
}
}
template <typename T>
__device__ __forceinline__ T from_float(const float& inp) {
if constexpr (std::is_same<T, _Float16>::value) {
return (_Float16)inp;
} else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
return __float2bfloat16(inp);
} else {
static_assert(false, "unsupported 16b dtype");
}
}
template <typename T>
__device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) {
union tmpcvt {
uint16_t u;
_Float16 f;
__hip_bfloat16 b;
} t16;
_B16x4 ret;
if constexpr (std::is_same<T, _Float16>::value) {
#pragma unroll
for (int i = 0; i < 4; i++) {
t16.f = (_Float16)inp[i];
ret[i] = t16.u;
}
return ret;
} else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
#pragma unroll
for (int i = 0; i < 4; i++) {
t16.b = __float2bfloat16(inp[i]);
ret[i] = t16.u;
}
return ret;
} else {
static_assert(false, "unsupported 16b dtype");
}
}
template <typename T>
__device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1,
const _B16x4& inp2) {
union tmpcvt {
uint16_t u;
_Float16 f;
__hip_bfloat16 b;
} t1, t2, res;
_B16x4 ret;
if constexpr (std::is_same<T, _Float16>::value) {
#pragma unroll
for (int i = 0; i < 4; i++) {
t1.u = inp1[i];
t2.u = inp2[i];
res.f = t1.f + t2.f;
ret[i] = res.u;
}
return ret;
} else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
#pragma unroll
for (int i = 0; i < 4; i++) {
t1.u = inp1[i];
t2.u = inp2[i];
res.b = t1.b + t2.b;
ret[i] = res.u;
}
return ret;
} else {
static_assert(false, "unsupported 16b dtype");
}
}
template <typename T, vllm::Fp8KVCacheDataType KV_DTYPE>
__device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input,
const float scale) {
union alignas(16) {
uint4 u4;
_B16x8 u16x8;
vllm::bf16_8_t b16x8;
} tmp;
if constexpr (std::is_same<T, _Float16>::value) {
tmp.u4 = vllm::fp8::scaled_convert<uint4, _B8x8, KV_DTYPE>(input, scale);
return tmp.u16x8;
} else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
tmp.b16x8 = vllm::fp8::scaled_convert<vllm::bf16_8_t, _B8x8, KV_DTYPE>(
input, scale);
return tmp.u16x8;
} else {
static_assert(false, "unsupported 16b dtype");
}
}
///////////////////////////////////////
// grid (num_seqs, num_partitions,num_heads/gqa_ratio)
// block (partition size)
template <typename scalar_t, typename cache_t,
vllm::Fp8KVCacheDataType KV_DTYPE, int BLOCK_SIZE, int HEAD_SIZE,
int NUM_THREADS,
int GQA_RATIO>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size]
int max_ctx_blocks, float k_scale, float v_scale) {
constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
const int warpid = threadIdx.x / WARP_SIZE;
const int laneid = threadIdx.x % WARP_SIZE;
const int lane4id = laneid % 4;
const int seq_idx = blockIdx.x;
const int partition_idx = blockIdx.y;
const int partition_size = blockDim.x;
const int max_num_partitions = gridDim.y;
const int context_len = context_lens[seq_idx];
const int partition_start_token_idx = partition_idx * partition_size;
// exit if partition is out of context for seq
if (partition_start_token_idx >= context_len) {
return;
}
constexpr int QHLOOP =
DIVIDE_ROUND_UP(GQA_RATIO, 4); // each 4 lanes fetch 4 different qheads,
// total qheads =8, so qhloop is 2
constexpr int GQA_RATIO4 = 4 * QHLOOP;
__shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1];
__shared__ float shared_exp_sum[NWARPS][GQA_RATIO4 + 1];
_B16x8 Qlocal[QHLOOP];
constexpr int x = 16 / sizeof(scalar_t);
constexpr int KHELOOP = HEAD_SIZE / x;
_B16x8 Klocal[KHELOOP];
_B8x8 Klocalb8[KHELOOP];
constexpr int VHELOOP =
HEAD_SIZE /
WARP_SIZE; // v head_size dimension is distributed across lanes
constexpr int VTLOOP = 8; // 16 separate 4xtokens across warp -> 16/2
// 8xtokens
_B16x8 Vlocal[VHELOOP][VTLOOP];
_B8x8 Vlocalb8[VHELOOP][VTLOOP];
floatx4 dout[QHLOOP];
float qk_max[QHLOOP];
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
dout[h] = {0};
qk_max[h] = -FLT_MAX;
}
const int wg_start_head_idx = blockIdx.z * GQA_RATIO;
const int wg_start_kv_head_idx = blockIdx.z;
const int warp_start_token_idx =
partition_start_token_idx + warpid * WARP_SIZE;
if (warp_start_token_idx >= context_len) { // warp out of context
#pragma unroll
for (int h = 0; h < GQA_RATIO4; h++) {
shared_qk_max[warpid][h] = -FLT_MAX;
shared_exp_sum[warpid][h] = 0.0f;
}
} else { // warp within context
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
const int last_ctx_block = num_context_blocks - 1;
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
const int local_token_idx = threadIdx.x;
const int global_token_idx = partition_start_token_idx + local_token_idx;
const int block_idx = (global_token_idx < context_len)
? global_token_idx / BLOCK_SIZE
: last_ctx_block;
// fetch block number for q and k
// int32 physical_block_number leads to overflow when multiplied with
// kv_block_stride
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
// fetch vphysical block numbers up front
constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE;
int vphysical_blocks[VBLOCKS];
const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE;
#pragma unroll
for (int b = 0; b < VBLOCKS; b++) {
const int vblock_idx = warp_start_block_idx + b;
const int vblock_idx_ctx =
(vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
vphysical_blocks[b] = block_table[vblock_idx_ctx];
}
// each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems
const scalar_t* q_ptr =
q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE;
const _B16x8* q_ptrh8 = reinterpret_cast<const _B16x8*>(q_ptr);
const int qhead_elemh8 = laneid / 4;
#pragma unroll
for (int h = 0; h < QHLOOP - 1; h++) {
const int qhead_idx = h * 4 + lane4id;
Qlocal[h] = q_ptrh8[qhead_idx * HEAD_SIZE / 8 + qhead_elemh8];
}
const int final_qhead_idx = 4 * (QHLOOP - 1) + lane4id;
if (final_qhead_idx < GQA_RATIO) {
Qlocal[QHLOOP - 1] =
q_ptrh8[final_qhead_idx * HEAD_SIZE / 8 + qhead_elemh8];
} else {
Qlocal[QHLOOP - 1].xy[0] = {0};
Qlocal[QHLOOP - 1].xy[1] = {0};
}
const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride +
wg_start_kv_head_idx * kv_head_stride;
const int physical_block_offset =
local_token_idx % BLOCK_SIZE; // since x=half8, physical_block_offset
// is already cast as _H8
if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) {
const _B16x8* k_ptrh8 = reinterpret_cast<const _B16x8*>(k_ptr);
#pragma unroll
for (int d = 0; d < KHELOOP; d++) {
Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset];
}
} else {
constexpr int X = 16 / sizeof(cache_t);
const cache_t* k_ptr2 = k_ptr + physical_block_offset * X;
#pragma unroll
for (int d = 0; d < KHELOOP; d++) {
const int head_elem = d * 8;
const int offset1 = head_elem / X;
const int offset2 = head_elem % X;
const cache_t* k_ptr3 = k_ptr2 + offset1 * BLOCK_SIZE * X + offset2;
Klocalb8[d] = *reinterpret_cast<const _B8x8*>(k_ptr3);
}
}
float alibi_slope[QHLOOP];
if (alibi_slopes != nullptr) {
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
const int qhead_idx = h * 4 + lane4id;
alibi_slope[h] = (qhead_idx < GQA_RATIO)
? alibi_slopes[wg_start_head_idx + qhead_idx]
: 0.f;
}
}
const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride;
if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) {
const _B16x8* v_ptrh8 = reinterpret_cast<const _B16x8*>(v_ptr);
// iterate over each v block
#pragma unroll
for (int b = 0; b < VBLOCKS; b++) {
// int32 physical_block_number leads to overflow when multiplied with
// kv_block_stride
const int64_t vphysical_block_number =
static_cast<int64_t>(vphysical_blocks[b]);
const _B16x8* v_ptrh8b =
v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8;
// iterate over each head elem (within head_size)
#pragma unroll
for (int h = 0; h < VHELOOP; h++) {
const int head_size_elem = h * WARP_SIZE + laneid;
const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8;
// iterate over all velems within block
#pragma unroll
for (int d = 0; d < BLOCK_SIZE / 8; d++) {
Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
}
}
}
} else {
const _B8x8* v_ptrh8 = reinterpret_cast<const _B8x8*>(v_ptr);
// iterate over each v block
#pragma unroll
for (int b = 0; b < VBLOCKS; b++) {
// int32 physical_block_number leads to overflow when multiplied with
// kv_block_stride
const int64_t vphysical_block_number =
static_cast<int64_t>(vphysical_blocks[b]);
const _B8x8* v_ptrh8b =
v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8;
// iterate over each head elem (within head_size)
#pragma unroll
for (int h = 0; h < VHELOOP; h++) {
const int head_size_elem = h * WARP_SIZE + laneid;
const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8;
// iterate over all velems within block
#pragma unroll
for (int d = 0; d < BLOCK_SIZE / 8; d++) {
// Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
const _B8x8 Vlocalb8 = v_ptrh8be[d];
Vlocal[h][b * BLOCK_SIZE / 8 + d] =
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Vlocalb8, v_scale);
}
}
}
}
if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) {
#pragma unroll
for (int d = 0; d < KHELOOP; d++) {
Klocal[d] =
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Klocalb8[d], k_scale);
}
}
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
dout[h] = gcn_mfma_instr<scalar_t, 4, 0, 0>(Qlocal[h].xy[0],
Klocal[0].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 0, 0>(Qlocal[h].xy[1],
Klocal[0].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 1, 0>(Qlocal[h].xy[0],
Klocal[1].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 1, 0>(Qlocal[h].xy[1],
Klocal[1].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 2, 0>(Qlocal[h].xy[0],
Klocal[2].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 2, 0>(Qlocal[h].xy[1],
Klocal[2].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 3, 0>(Qlocal[h].xy[0],
Klocal[3].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 3, 0>(Qlocal[h].xy[1],
Klocal[3].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 4, 0>(Qlocal[h].xy[0],
Klocal[4].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 4, 0>(Qlocal[h].xy[1],
Klocal[4].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 5, 0>(Qlocal[h].xy[0],
Klocal[5].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 5, 0>(Qlocal[h].xy[1],
Klocal[5].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 6, 0>(Qlocal[h].xy[0],
Klocal[6].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 6, 0>(Qlocal[h].xy[1],
Klocal[6].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 7, 0>(Qlocal[h].xy[0],
Klocal[7].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 7, 0>(Qlocal[h].xy[1],
Klocal[7].xy[1], dout[h]);
if constexpr (KHELOOP > 8) {
dout[h] = gcn_mfma_instr<scalar_t, 4, 8, 0>(Qlocal[h].xy[0],
Klocal[8].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 8, 0>(Qlocal[h].xy[1],
Klocal[8].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 9, 0>(Qlocal[h].xy[0],
Klocal[9].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 9, 0>(Qlocal[h].xy[1],
Klocal[9].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 10, 0>(Qlocal[h].xy[0],
Klocal[10].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 10, 0>(Qlocal[h].xy[1],
Klocal[10].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 11, 0>(Qlocal[h].xy[0],
Klocal[11].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 11, 0>(Qlocal[h].xy[1],
Klocal[11].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 12, 0>(Qlocal[h].xy[0],
Klocal[12].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 12, 0>(Qlocal[h].xy[1],
Klocal[12].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 13, 0>(Qlocal[h].xy[0],
Klocal[13].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 13, 0>(Qlocal[h].xy[1],
Klocal[13].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 14, 0>(Qlocal[h].xy[0],
Klocal[14].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 14, 0>(Qlocal[h].xy[1],
Klocal[14].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 15, 0>(Qlocal[h].xy[0],
Klocal[15].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 15, 0>(Qlocal[h].xy[1],
Klocal[15].xy[1], dout[h]);
} // KHELOOP>8
dout[h] *= scale;
}
// transpose dout so that 4 token ids are in each lane, and 4 heads are across
// 4 lanes
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
floatx4 tmp = {0};
#pragma unroll
for (int i = 0; i < 4; i++) {
const float B = (lane4id == i) ? 1.0f : 0.0f;
// const float A = (global_token_idx < context_len) ? dout[h][i] : 0.0f;
tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(dout[h][i], B, tmp, 0, 0, 0);
// tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(A, B, tmp, 0, 0, 0);
}
dout[h] = tmp;
}
const int lane4_token_idx = 4 * (global_token_idx >> 2);
const int alibi_offset = lane4_token_idx - context_len + 1;
if (alibi_slopes != nullptr) {
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
#pragma unroll
for (int i = 0; i < 4; i++) {
dout[h][i] += alibi_slope[h] * (alibi_offset + i);
}
}
}
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
qk_max[h] = -FLT_MAX;
#pragma unroll
for (int i = 0; i < 4; i++) {
qk_max[h] = (lane4_token_idx + i < context_len)
? fmaxf(qk_max[h], dout[h][i])
: qk_max[h];
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) {
qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask));
}
}
float exp_sum[QHLOOP];
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
exp_sum[h] = 0.0f;
#pragma unroll
for (int i = 0; i < 4; i++) {
dout[h][i] = (lane4_token_idx + i < context_len)
? __expf(dout[h][i] - qk_max[h])
: 0.0f;
exp_sum[h] += dout[h][i];
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) {
exp_sum[h] += __shfl_xor(exp_sum[h], mask);
}
}
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
const int head_idx = 4 * h + lane4id;
shared_qk_max[warpid][head_idx] = qk_max[h];
shared_exp_sum[warpid][head_idx] = exp_sum[h];
}
} // warp within context
__syncthreads();
const int num_heads = gridDim.z * GQA_RATIO;
float* max_logits_ptr =
max_logits + seq_idx * num_heads * max_num_partitions + partition_idx;
float* exp_sums_ptr =
exp_sums + seq_idx * num_heads * max_num_partitions + partition_idx;
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
float global_qk_max = -FLT_MAX;
float warp_qk_max[NWARPS];
const int head_idx = 4 * h + lane4id;
#pragma unroll
for (int w = 0; w < NWARPS; w++) {
warp_qk_max[w] = shared_qk_max[w][head_idx];
global_qk_max = fmaxf(global_qk_max, warp_qk_max[w]);
}
float global_exp_sum = 0.0f;
#pragma unroll
for (int w = 0; w < NWARPS; w++) {
global_exp_sum +=
shared_exp_sum[w][head_idx] * __expf(warp_qk_max[w] - global_qk_max);
}
if (head_idx < GQA_RATIO) {
max_logits_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] =
global_qk_max;
exp_sums_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] =
global_exp_sum;
}
const float global_inv_sum_scale = __fdividef(1.f, global_exp_sum + 1e-6f) *
__expf(qk_max[h] - global_qk_max);
dout[h] *= global_inv_sum_scale;
}
// logits[h] -> every 4 lanes hold 4 heads, each lane holds 4 tokens, there
// are 4x16 tokens across warp
_B16x4 logits[QHLOOP];
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
logits[h] = from_floatx4<scalar_t>(dout[h]);
}
__shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1];
if (warp_start_token_idx >= context_len) { // warp out of context
#pragma unroll
for (int qh = 0; qh < QHLOOP; qh++) {
#pragma unroll
for (int vh = 0; vh < VHELOOP; vh++) {
vout_shared[qh][vh][laneid][warpid] = {0};
}
}
} else { // warp in context
// iterate across heads
#pragma unroll
for (int qh = 0; qh < QHLOOP; qh++) {
// iterate over each v head elem (within head_size)
#pragma unroll
for (int vh = 0; vh < VHELOOP; vh++) {
floatx4 acc = {0};
// iterate over tokens
acc = gcn_mfma_instr<scalar_t, 4, 0, 0>(logits[qh], Vlocal[vh][0].xy[0],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 1, 0>(logits[qh], Vlocal[vh][0].xy[1],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 2, 0>(logits[qh], Vlocal[vh][1].xy[0],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 3, 0>(logits[qh], Vlocal[vh][1].xy[1],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 4, 0>(logits[qh], Vlocal[vh][2].xy[0],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 5, 0>(logits[qh], Vlocal[vh][2].xy[1],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 6, 0>(logits[qh], Vlocal[vh][3].xy[0],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 7, 0>(logits[qh], Vlocal[vh][3].xy[1],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 8, 0>(logits[qh], Vlocal[vh][4].xy[0],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 9, 0>(logits[qh], Vlocal[vh][4].xy[1],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 10, 0>(logits[qh],
Vlocal[vh][5].xy[0], acc);
acc = gcn_mfma_instr<scalar_t, 4, 11, 0>(logits[qh],
Vlocal[vh][5].xy[1], acc);
acc = gcn_mfma_instr<scalar_t, 4, 12, 0>(logits[qh],
Vlocal[vh][6].xy[0], acc);
acc = gcn_mfma_instr<scalar_t, 4, 13, 0>(logits[qh],
Vlocal[vh][6].xy[1], acc);
acc = gcn_mfma_instr<scalar_t, 4, 14, 0>(logits[qh],
Vlocal[vh][7].xy[0], acc);
acc = gcn_mfma_instr<scalar_t, 4, 15, 0>(logits[qh],
Vlocal[vh][7].xy[1], acc);
vout_shared[qh][vh][laneid][warpid] = from_floatx4<scalar_t>(acc);
}
}
} // warp in context
__syncthreads();
if (warpid == 0) {
_B16x4 vout[QHLOOP][VHELOOP];
// iterate across heads
scalar_t* out_ptr;
int out_num_partitions;
if (context_len > partition_size) {
out_num_partitions = max_num_partitions;
out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
partition_idx * HEAD_SIZE;
} else {
out_num_partitions = 1;
out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE;
}
#pragma unroll
for (int qh = 0; qh < QHLOOP; qh++) {
// iterate over each v head elem (within head_size)
#pragma unroll
for (int vh = 0; vh < VHELOOP; vh++) {
vout[qh][vh] = {0};
#pragma unroll
for (int w = 0; w < NWARPS; w++) {
vout[qh][vh] =
addx4<scalar_t>(vout[qh][vh], vout_shared[qh][vh][laneid][w]);
}
const int head_size_elem = vh * WARP_SIZE + laneid;
bit16_t* out_ptr_b16 = reinterpret_cast<bit16_t*>(out_ptr);
#pragma unroll
for (int i = 0; i < 4; i++) {
const int head_idx = 4 * qh + i;
if (head_idx < GQA_RATIO) {
out_ptr_b16[(wg_start_head_idx + head_idx) * out_num_partitions *
HEAD_SIZE +
head_size_elem] = vout[qh][vh][i];
}
}
}
}
}
}
// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
int PARTITION_SIZE>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_partitions) {
const int num_heads = gridDim.x;
const int head_idx = blockIdx.x;
const int seq_idx = blockIdx.y;
const int context_len = context_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
if (num_partitions == 1) {
// if num_partitions==1, main kernel will write to out directly, no work in
// reduction kernel
return;
}
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int warpid = threadIdx.x / WARP_SIZE;
const int laneid = threadIdx.x % WARP_SIZE;
__shared__ float shared_global_exp_sum;
__shared__ float shared_exp_sums[2 * WARP_SIZE];
if (warpid == 0) {
const float* max_logits_ptr = max_logits +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
// valid partition is the last valid partition in case threadid > num
// partitions
const int valid_partition =
(threadIdx.x < num_partitions) ? threadIdx.x : num_partitions - 1;
const int valid_partition2 = (WARP_SIZE + threadIdx.x < num_partitions)
? WARP_SIZE + threadIdx.x
: num_partitions - 1;
float reg_max_logit = max_logits_ptr[valid_partition];
float reg_max_logit2 = max_logits_ptr[valid_partition2];
float max_logit = fmaxf(reg_max_logit, reg_max_logit2);
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask));
}
const float* exp_sums_ptr = exp_sums +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float global_exp_sum = 0.0f;
float rescaled_exp_sum = exp_sums_ptr[valid_partition];
float rescaled_exp_sum2 = exp_sums_ptr[valid_partition2];
rescaled_exp_sum *=
(threadIdx.x < num_partitions) ? expf(reg_max_logit - max_logit) : 0.0f;
rescaled_exp_sum2 *= (threadIdx.x + WARP_SIZE < num_partitions)
? expf(reg_max_logit2 - max_logit)
: 0.0f;
global_exp_sum += rescaled_exp_sum + rescaled_exp_sum2;
shared_exp_sums[threadIdx.x] = rescaled_exp_sum;
shared_exp_sums[threadIdx.x + WARP_SIZE] = rescaled_exp_sum2;
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
global_exp_sum += __shfl_xor(global_exp_sum, mask);
}
if (threadIdx.x == 0) {
shared_global_exp_sum = global_exp_sum;
}
} // warpid == 0
const scalar_t* tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x;
constexpr int MAX_NPAR = 64;
scalar_t tmps[MAX_NPAR];
const float dzero = 0.0f;
#pragma unroll
for (int j = 0; j < MAX_NPAR; j++) {
tmps[j] = from_float<scalar_t>(dzero);
}
const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE;
const int num_partition_offset = (num_partitions)*HEAD_SIZE;
int idx = 0;
constexpr int JCHUNK = 16;
#pragma unroll
for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) {
// lastj is last valid partition
const int lastj_offset =
(j < num_partition_offset) ? j : last_partition_offset;
tmps[idx] = tmp_out_ptr[lastj_offset];
idx++;
}
__syncthreads();
if (num_partitions > JCHUNK) {
#pragma unroll
for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE;
j += HEAD_SIZE) {
const int lastj_offset =
(j < num_partition_offset) ? j : last_partition_offset;
tmps[idx] = tmp_out_ptr[lastj_offset];
idx++;
}
if (num_partitions > 2 * JCHUNK) {
#pragma unroll
for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE;
j += HEAD_SIZE) {
const int lastj_offset =
(j < num_partition_offset) ? j : last_partition_offset;
tmps[idx] = tmp_out_ptr[lastj_offset];
idx++;
}
}
} // num_partitions > JCHUNK
// Aggregate tmp_out to out.
float acc = 0.0f;
#pragma unroll
for (int j = 0; j < JCHUNK; j++) {
acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
}
if (num_partitions > JCHUNK) {
#pragma unroll
for (int j = JCHUNK; j < 2 * JCHUNK; j++) {
acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
}
if (num_partitions > 2 * JCHUNK) {
#pragma unroll
for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) {
acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
}
}
}
if (num_partitions > MAX_NPAR) {
idx = 0;
#pragma unroll
for (int j = MAX_NPAR * HEAD_SIZE; j < 2 * MAX_NPAR * HEAD_SIZE;
j += HEAD_SIZE) {
// lastj is last valid partition
const int lastj_offset =
(j < num_partition_offset) ? j : last_partition_offset;
tmps[idx] = tmp_out_ptr[lastj_offset];
idx++;
}
#pragma unroll
for (int j = 0; j < MAX_NPAR; j++) {
acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j + MAX_NPAR];
}
}
const float inv_global_exp_sum =
__fdividef(1.0f, shared_global_exp_sum + 1e-6f);
acc *= inv_global_exp_sum;
scalar_t* out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
out_ptr[threadIdx.x] = from_float<scalar_t>(acc);
}
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
template <typename scalar_t, typename cache_t,
vllm::Fp8KVCacheDataType KV_DTYPE, int BLOCK_SIZE, int HEAD_SIZE,
int NUM_THREADS,
int GQA_RATIO>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size]
int max_ctx_blocks, float k_scale, float v_scale) {
UNREACHABLE_CODE
}
// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
int PARTITION_SIZE>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_partitions){UNREACHABLE_CODE}
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \
paged_attention_ll4mi_QKV_kernel<T, KVT, KV_DTYPE, BLOCK_SIZE, HEAD_SIZE, \
NTHR, GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \
k_scale, v_scale);
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
int BLOCK_SIZE, int HEAD_SIZE, int PARTITION_SIZE = 512>
void paged_attention_custom_launcher(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, const int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& context_lens,
int max_context_len, const c10::optional<torch::Tensor>& alibi_slopes,
float k_scale, float v_scale) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
KVT* key_cache_ptr = reinterpret_cast<KVT*>(key_cache.data_ptr());
KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();
const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
const int max_num_partitions =
DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
const int gqa_ratio = num_heads / num_kv_heads;
assert(num_heads % num_kv_heads == 0);
assert(head_size == HEAD_SIZE);
assert(max_num_partitions <= 128);
constexpr int NTHR = PARTITION_SIZE;
dim3 grid(num_seqs, max_num_partitions, num_kv_heads);
dim3 block(NTHR);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (gqa_ratio) {
case 1:
LAUNCH_CUSTOM_ATTENTION(1);
break;
case 2:
LAUNCH_CUSTOM_ATTENTION(2);
break;
case 3:
LAUNCH_CUSTOM_ATTENTION(3);
break;
case 4:
LAUNCH_CUSTOM_ATTENTION(4);
break;
case 5:
LAUNCH_CUSTOM_ATTENTION(5);
break;
case 6:
LAUNCH_CUSTOM_ATTENTION(6);
break;
case 7:
LAUNCH_CUSTOM_ATTENTION(7);
break;
case 8:
LAUNCH_CUSTOM_ATTENTION(8);
break;
case 9:
LAUNCH_CUSTOM_ATTENTION(9);
break;
case 10:
LAUNCH_CUSTOM_ATTENTION(10);
break;
case 11:
LAUNCH_CUSTOM_ATTENTION(11);
break;
case 12:
LAUNCH_CUSTOM_ATTENTION(12);
break;
case 13:
LAUNCH_CUSTOM_ATTENTION(13);
break;
case 14:
LAUNCH_CUSTOM_ATTENTION(14);
break;
case 15:
LAUNCH_CUSTOM_ATTENTION(15);
break;
case 16:
LAUNCH_CUSTOM_ATTENTION(16);
break;
default:
TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio);
break;
}
// dim3 grid2(num_heads,num_seqs,head_size/HEAD_ELEMS_PER_WG);
// dim3 block2(1024);
// LAUNCH_CUSTOM_ATTENTION2;
// reduction kernel is only required if max_context_len > partition size,
// otherwise main kernel writes directly to final output
// note there are cases with graphing where max_context_len is the max
// supported by graphing, not the actual max among all the sequences: in that
// case reduction kernel will still run but return immediately
if (max_context_len > PARTITION_SIZE) {
dim3 reduce_grid(num_heads, num_seqs);
dim3 reduce_block(head_size);
paged_attention_ll4mi_reduce_kernel<T, HEAD_SIZE, HEAD_SIZE, PARTITION_SIZE>
<<<reduce_grid, reduce_block, 0, stream>>>(
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr,
context_lens_ptr, max_num_partitions);
}
}
#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, max_context_len, \
alibi_slopes, k_scale, v_scale);
#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \
switch (block_size) { \
case 16: \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \
break; \
case 32: \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \
switch (head_size) { \
case 64: \
CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); \
break; \
case 128: \
CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); \
break; \
default: \
TORCH_CHECK(false, "Unsupported head size: ", head_size); \
break; \
}
void paged_attention(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor&
tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs]
int64_t block_size, int64_t max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale) {
const int head_size = query.size(2);
if (kv_cache_dtype == "auto") {
if (query.dtype() == at::ScalarType::Half) {
CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16,
vllm::Fp8KVCacheDataType::kAuto);
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, __hip_bfloat16,
vllm::Fp8KVCacheDataType::kAuto);
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
} else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
if (query.dtype() == at::ScalarType::Half) {
CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3);
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
} else {
TORCH_CHECK(false, "Unsupported KV cache dtype: ", kv_cache_dtype);
}
}
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
\ No newline at end of file
#pragma once
#include <torch/all.h>
void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor& max_logits, torch::Tensor& tmp_out,
torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads,
double scale, torch::Tensor& block_tables,
torch::Tensor& context_lens, int64_t block_size,
int64_t max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale,
double v_scale);
#include "core/registration.h"
#include "rocm/ops.h"
// Note on op signatures:
// The X_meta signatures are for the meta functions corresponding to op X.
// They must be kept in sync with the signature for X. Generally, only
// functions that return Tensors require a meta function.
//
// See the following links for detailed docs on op registration and function
// schemas.
// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
// vLLM custom ops for rocm
// Custom attention op
// Compute the attention between an input query and the cached
// keys/values using PagedAttention.
rocm_ops.def(
"paged_attention(Tensor! out, Tensor exp_sums,"
" Tensor max_logits, Tensor tmp_out,"
" Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads,"
" float scale, Tensor block_tables,"
" Tensor context_lens, int block_size,"
" int max_context_len,"
" Tensor? alibi_slopes,"
" str kv_cache_dtype,"
" float k_scale, float v_scale) -> ()");
rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
......@@ -257,6 +257,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"-> Tensor");
ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B);
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
ops.impl("permute_cols", torch::kCUDA, &permute_cols);
// gptq_marlin Optimized Quantized GEMM for GPTQ.
ops.def(
"gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
......@@ -337,15 +340,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor! A, Tensor! B, Tensor! C,"
"Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
"bool delta_softplus,"
"Tensor? index_, Tensor(a! -> *)? x) -> Tensor(a)[]");
"Tensor? index_, Tensor!? x) -> Tensor[]");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
ops.def(
"causal_conv1d_update(Tensor! x,"
"Tensor! conv_state,"
"Tensor! weight,"
"Tensor? bias_,"
"bool silu_activation) -> Tensor");
"Tensor? bias,"
"bool silu_activation,"
"Tensor? conv_state_indices) -> Tensor");
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
ops.def(
......@@ -353,7 +357,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? bias_,"
"Tensor? seq_idx_,"
"Tensor? initial_states_,"
"Tensor? final_states_out_,"
"Tensor!? final_states_out_,"
"bool silu_activation) -> Tensor");
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
#endif
......@@ -402,14 +406,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Compute int8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> "
"()");
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
"Tensor? azp) -> ()");
ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);
// Compute int8 quantized tensor and scaling factor
ops.def(
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
"()");
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
"Tensor!? azp) -> ()");
ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
&dynamic_scaled_int8_quant);
}
......@@ -477,11 +481,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
"bool full_nvlink) -> int");
custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
custom_ar.def(
"should_custom_ar(Tensor inp, int max_size, int world_size, "
"bool full_nvlink) -> bool");
custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);
custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);
......
......@@ -21,8 +21,8 @@ Traces can be visualized using https://ui.perfetto.dev/.
.. tip::
To stop the profiler - it flushes out all the profile trace files to the directory. This takes time, for example for about 100 requests worth of data for a llama 70b, it takes about 10 minutes to flush out on a H100.
Set the env variable VLLM_RPC_GET_DATA_TIMEOUT_MS to a big number before you start the server. Say something like 30 minutes.
``export VLLM_RPC_GET_DATA_TIMEOUT_MS=1800000``
Set the env variable VLLM_RPC_TIMEOUT to a big number before you start the server. Say something like 30 minutes.
``export VLLM_RPC_TIMEOUT=1800000``
Example commands and usage:
===========================
......
......@@ -3,15 +3,17 @@
Installation with ROCm
======================
vLLM supports AMD GPUs with ROCm 6.1.
vLLM supports AMD GPUs with ROCm 6.2.
Requirements
------------
* OS: Linux
* Python: 3.8 -- 3.11
* Python: 3.9 -- 3.12
* GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100)
* ROCm 6.1
* ROCm 6.2
Note: PyTorch 2.5+/ROCm6.2 dropped the support for python 3.8.
Installation options:
......@@ -26,8 +28,18 @@ Option 1: Build from source with docker (recommended)
You can build and install vLLM from source.
First, build a docker image from `Dockerfile.rocm <https://github.com/vllm-project/vllm/blob/main/Dockerfile.rocm>`_ and launch a docker container from the image.
It is important that the user kicks off the docker build using buildkit. Either the user put DOCKER_BUILDKIT=1 as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon:
.. code-block:: console
{
"features": {
"buildkit": true
}
}
`Dockerfile.rocm <https://github.com/vllm-project/vllm/blob/main/Dockerfile.rocm>`_ uses ROCm 6.1 by default, but also supports ROCm 5.7 and 6.0 in older vLLM branches.
`Dockerfile.rocm <https://github.com/vllm-project/vllm/blob/main/Dockerfile.rocm>`_ uses ROCm 6.2 by default, but also supports ROCm 5.7, 6.0 and 6.1 in older vLLM branches.
It provides flexibility to customize the build of docker image using the following arguments:
* `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image.
......@@ -39,13 +51,13 @@ It provides flexibility to customize the build of docker image using the followi
Their values can be passed in when running ``docker build`` with ``--build-arg`` options.
To build vllm on ROCm 6.1 for MI200 and MI300 series, you can use the default:
To build vllm on ROCm 6.2 for MI200 and MI300 series, you can use the default:
.. code-block:: console
$ DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm -t vllm-rocm .
To build vllm on ROCm 6.1 for Radeon RX7900 series (gfx1100), you should specify ``BUILD_FA`` as below:
To build vllm on ROCm 6.2 for Radeon RX7900 series (gfx1100), you should specify ``BUILD_FA`` as below:
.. code-block:: console
......@@ -79,37 +91,55 @@ Option 2: Build from source
- `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
- `PyTorch <https://pytorch.org/>`_
- `hipBLAS <https://rocm.docs.amd.com/projects/hipBLAS/en/latest/install.html>`_
For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging`, `rocm/pytorch-nightly`.
For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0`, `rocm/pytorch-nightly`.
Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guild in PyTorch `Getting Started <https://pytorch.org/get-started/locally/>`_
Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch `Getting Started <https://pytorch.org/get-started/locally/>`_
1. Install `Triton flash attention for ROCm <https://github.com/ROCm/triton>`_
Install ROCm's Triton flash attention (the default triton-mlir branch) following the instructions from `ROCm/triton <https://github.com/ROCm/triton/blob/triton-mlir/README.md>`_
.. code-block:: console
$ python3 -m pip install ninja cmake wheel pybind11
$ pip uninstall -y triton
$ git clone https://github.com/OpenAI/triton.git
$ cd triton
$ git checkout e192dba
$ cd python
$ pip3 install .
$ cd ../..
.. note::
- If you see HTTP issue related to downloading packages during building triton, please try again as the HTTP error is intermittent.
2. Optionally, if you choose to use CK flash attention, you can install `flash attention for ROCm <https://github.com/ROCm/flash-attention/tree/ck_tile>`_
Install ROCm's flash attention (v2.5.9.post1) following the instructions from `ROCm/flash-attention <https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support>`_
Alternatively, wheels intended for vLLM use can be accessed under the releases.
.. note::
- You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
3. Build vLLM.
For example, for ROCm 6.2, suppose your gfx arch is `gfx90a`.
Note to get your gfx architecture, run `rocminfo |grep gfx`.
.. code-block:: console
.. code-block:: console
$ cd vllm
$ pip install -U -r requirements-rocm.txt
$ python setup.py develop # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation
$ git clone https://github.com/ROCm/flash-attention.git
$ cd flash-attention
$ git checkout 3cea2fb
$ git submodule update --init
$ GPU_ARCHS="gfx90a" python3 setup.py install
$ cd ..
.. note::
- You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
.. tip::
3. Build vLLM.
For example, vLLM v0.5.3 on ROCM 6.1 can be built with the following steps:
For example, vLLM on ROCM 6.2 can be built with the following steps:
.. code-block:: console
......@@ -117,7 +147,7 @@ Alternatively, wheels intended for vLLM use can be accessed under the releases.
$ # Install PyTorch
$ pip uninstall torch -y
$ pip install --no-cache-dir --pre torch==2.5.0.dev20240726 --index-url https://download.pytorch.org/whl/nightly/rocm6.1
$ pip install --no-cache-dir --pre torch==2.6.0.dev20240918 --index-url https://download.pytorch.org/whl/nightly/rocm6.2
$ # Build & install AMD SMI
$ pip install /opt/rocm/share/amd_smi
......@@ -127,15 +157,14 @@ Alternatively, wheels intended for vLLM use can be accessed under the releases.
$ pip install "numpy<2"
$ pip install -r requirements-rocm.txt
$ # Apply the patch to ROCM 6.1 (requires root permission)
$ wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P /opt/rocm/lib
$ rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so*
$ # Build vLLM for MI210/MI250/MI300.
$ export PYTORCH_ROCM_ARCH="gfx90a;gfx942"
$ python3 setup.py develop
This may take 5-10 minutes. Currently, :code:`pip install .` does not work for ROCm installation.
.. tip::
- Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers.
......
......@@ -56,7 +56,7 @@ Build from source
.. code-block:: console
$ pip install --upgrade pip
$ pip install wheel packaging ninja "setuptools>=49.4.0" numpy
$ pip install cmake>=3.26 wheel packaging ninja "setuptools-scm>=8" numpy
$ pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
- Third, build and install oneDNN library from source:
......
......@@ -98,6 +98,13 @@ Here are some common issues that can cause hangs:
If the script runs successfully, you should see the message ``sanity check is successful!``.
Note that multi-node environment is more complicated than single-node. If you see errors such as ``torch.distributed.DistNetworkError``, it is likely that the network/DNS setup is incorrect. In that case, you can manually assign node rank and specify the IP via command line arguments:
- In the first node, run ``NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --node-rank 0 --master_addr $MASTER_ADDR test.py``.
- In the second node, run ``NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --node-rank 1 --master_addr $MASTER_ADDR test.py``.
Adjust ``--nproc-per-node``, ``--nnodes``, and ``--node-rank`` according to your setup. The difference is that you need to execute different commands (with different ``--node-rank``) on different nodes.
If the problem persists, feel free to `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_, with a detailed description of the issue, your environment, and the logs.
Some known issues:
......
......@@ -72,6 +72,29 @@ You can also build and install vLLM from source:
$ cd vllm
$ pip install -e . # This may take 5-10 minutes.
.. note::
This will uninstall existing PyTorch, and install the version required by vLLM. If you want to use an existing PyTorch installation, there need to be some changes:
.. code-block:: console
$ git clone https://github.com/vllm-project/vllm.git
$ cd vllm
$ python use_existing_torch.py
$ pip install -r requirements-build.txt
$ pip install -e . --no-build-isolation
The differences are:
- ``python use_existing_torch.py``: This script will remove all the PyTorch versions in the requirements files, so that the existing PyTorch installation will be used.
- ``pip install -r requirements-build.txt``: You need to manually install the requirements for building vLLM.
- ``pip install -e . --no-build-isolation``: You need to disable build isolation, so that the build system can use the existing PyTorch installation.
This is especially useful when the PyTorch dependency cannot be easily installed via pip, e.g.:
- build vLLM with PyTorch nightly or a custom PyTorch build.
- build vLLM with aarch64 and cuda (GH200), where the PyTorch wheels are not available on PyPI. Currently, only PyTorch nightly has wheels for aarch64 with CUDA. You can run ``pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124`` to install PyTorch nightly, and then build vLLM on top of it.
.. note::
vLLM can fully run only on Linux, but you can still build it on other systems (for example, macOS). This build is only for development purposes, allowing for imports and a more convenient dev environment. The binaries will not be compiled and not work on non-Linux systems. You can create such a build with the following commands:
......@@ -95,6 +118,8 @@ You can also build and install vLLM from source:
$ export MAX_JOBS=6
$ pip install -e .
This is especially useful when you are building on less powerful machines. For example, when you use WSL, it only `gives you half of the memory by default <https://learn.microsoft.com/en-us/windows/wsl/wsl-config>`_, and you'd better use ``export MAX_JOBS=1`` to avoid compiling multiple files simultaneously and running out of memory. The side effect is that the build process will be much slower. If you only touch the Python code, slow compilation is okay, as you are building in an editable mode: you can just change the code and run the Python script without any re-compilation or re-installation.
.. tip::
If you have trouble building vLLM, we recommend using the NVIDIA PyTorch Docker image.
......
......@@ -3,8 +3,8 @@
Installation with Neuron
========================
vLLM 0.3.3 onwards supports model inferencing and serving on AWS Trainium/Inferentia with Neuron SDK.
At the moment Paged Attention is not supported in Neuron SDK, but naive continuous batching is supported in transformers-neuronx.
vLLM 0.3.3 onwards supports model inferencing and serving on AWS Trainium/Inferentia with Neuron SDK with continuous batching.
Paged Attention and Chunked Prefill are currently in development and will be available soon.
Data types currently supported in Neuron SDK are FP16 and BF16.
Requirements
......
......@@ -17,8 +17,8 @@ Requirements
------------
* OS: Linux
* Supported Hardware: Intel Data Center GPU (Intel ARC GPU WIP)
* OneAPI requirements: oneAPI 2024.1
* Supported Hardware: Intel Data Center GPU, Intel ARC GPU
* OneAPI requirements: oneAPI 2024.2
.. _xpu_backend_quick_start_dockerfile:
......@@ -40,7 +40,7 @@ Quick start using Dockerfile
Build from source
-----------------
- First, install required driver and intel OneAPI 2024.1 or later.
- First, install required driver and intel OneAPI 2024.2 or later.
- Second, install Python packages for vLLM XPU backend building:
......
......@@ -43,7 +43,7 @@ vLLM is flexible and easy to use with:
* Tensor parallelism and pipeline parallelism support for distributed inference
* Streaming outputs
* OpenAI-compatible API server
* Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron.
* Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Trainium and Inferentia Accelerators.
* Prefix caching support
* Multi-lora support
......@@ -107,6 +107,7 @@ Documentation
quantization/supported_hardware
quantization/auto_awq
quantization/bnb
quantization/gguf
quantization/int8
quantization/fp8
quantization/fp8_e5m2_kvcache
......
......@@ -159,3 +159,67 @@ Example request to unload a LoRA adapter:
-d '{
"lora_name": "sql_adapter"
}'
New format for `--lora-modules`
-------------------------------
In the previous version, users would provide LoRA modules via the following format, either as a key-value pair or in JSON format. For example:
.. code-block:: bash
--lora-modules sql-lora=$HOME/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/
This would only include the `name` and `path` for each LoRA module, but did not provide a way to specify a `base_model_name`.
Now, you can specify a base_model_name alongside the name and path using JSON format. For example:
.. code-block:: bash
--lora-modules '{"name": "sql-lora", "path": "/path/to/lora", "base_model_name": "meta-llama/Llama-2-7b"}'
To provide the backward compatibility support, you can still use the old key-value format (name=path), but the `base_model_name` will remain unspecified in that case.
Lora model lineage in model card
--------------------------------
The new format of `--lora-modules` is mainly to support the display of parent model information in the model card. Here's an explanation of how your current response supports this:
- The `parent` field of LoRA model `sql-lora` now links to its base model `meta-llama/Llama-2-7b-hf`. This correctly reflects the hierarchical relationship between the base model and the LoRA adapter.
- The `root` field points to the artifact location of the lora adapter.
.. code-block:: bash
$ curl http://localhost:8000/v1/models
{
"object": "list",
"data": [
{
"id": "meta-llama/Llama-2-7b-hf",
"object": "model",
"created": 1715644056,
"owned_by": "vllm",
"root": "~/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9/",
"parent": null,
"permission": [
{
.....
}
]
},
{
"id": "sql-lora",
"object": "model",
"created": 1715644056,
"owned_by": "vllm",
"root": "~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/",
"parent": meta-llama/Llama-2-7b-hf,
"permission": [
{
....
}
]
}
]
}
......@@ -107,6 +107,10 @@ Decoder-only Language Models
- MiniCPM
- :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.
-
* - :code:`MiniCPM3ForCausalLM`
- MiniCPM3
- :code:`openbmb/MiniCPM3-4B`, etc.
-
* - :code:`MistralForCausalLM`
- Mistral, Mistral-Instruct
- :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
......@@ -123,6 +127,10 @@ Decoder-only Language Models
- Nemotron-3, Nemotron-4, Minitron
- :code:`nvidia/Minitron-8B-Base`, :code:`mgoin/Nemotron-4-340B-Base-hf-FP8`, etc.
- ✅︎
* - :code:`OLMoEForCausalLM`
- OLMoE
- :code:`allenai/OLMoE-1B-7B-0924`, :code:`allenai/OLMoE-1B-7B-0924-Instruct`, etc.
-
* - :code:`OLMoForCausalLM`
- OLMo
- :code:`allenai/OLMo-1B-hf`, :code:`allenai/OLMo-7B-hf`, etc.
......@@ -175,6 +183,10 @@ Decoder-only Language Models
- Starcoder2
- :code:`bigcode/starcoder2-3b`, :code:`bigcode/starcoder2-7b`, :code:`bigcode/starcoder2-15b`, etc.
-
* - :code:`SolarForCausalLM`
- EXAONE-3
- :code:`upstage/solar-pro-preview-instruct`, etc.
-
* - :code:`XverseForCausalLM`
- Xverse
- :code:`xverse/XVERSE-7B-Chat`, :code:`xverse/XVERSE-13B-Chat`, :code:`xverse/XVERSE-65B-Chat`, etc.
......@@ -230,13 +242,23 @@ Multimodal Language Models
* - :code:`LlavaNextVideoForConditionalGeneration`
- LLaVA-NeXT-Video
- Video
- :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. (see note)
- :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc.
-
* - :code:`LlavaOnevisionForConditionalGeneration`
- LLaVA-Onevision
- Image\ :sup:`+` / Video
- :code:`llava-hf/llava-onevision-qwen2-7b-ov-hf`, :code:`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc.
-
* - :code:`MiniCPMV`
- MiniCPM-V
- Image\ :sup:`+`
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
-
* - :code:`MllamaForConditionalGeneration`
- Llama 3.2
- Image
- :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc.
-
* - :code:`PaliGemmaForConditionalGeneration`
- PaliGemma
- Image\ :sup:`E`
......@@ -276,7 +298,7 @@ Multimodal Language Models
For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
.. note::
For :code:`LLaVA-NeXT-Video` and :code:`Qwen2-VL`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now.
For :code:`Qwen2-VL`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now.
This can be installed by running the following command:
.. code-block:: bash
......
......@@ -11,7 +11,7 @@ Below are the steps to utilize BitsAndBytes with vLLM.
.. code-block:: console
$ pip install bitsandbytes>=0.42.0
$ pip install bitsandbytes>=0.44.0
vLLM reads the model's config file and supports both in-flight quantization and pre-quantized checkpoint.
......
.. _gguf:
GGUF
==================
.. warning::
Please note that GGUF support in vLLM is highly experimental and under-optimized at the moment, it might be incompatible with other features. Currently, you can use GGUF as a way to reduce memory footprint. If you encounter any issues, please report them to the vLLM team.
.. warning::
Currently, vllm only supports loading single-file GGUF models. If you have a multi-files GGUF model, you can use `gguf-split <https://github.com/ggerganov/llama.cpp/pull/6135>`_ tool to merge them to a single-file model.
To run a GGUF model with vLLM, you can download and use the local GGUF model from `TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF <https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF>`_ with the following command:
.. code-block:: console
$ wget https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf
$ # We recommend using the tokenizer from base model to avoid long-time and buggy tokenizer conversion.
$ vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlama-1.1B-Chat-v1.0
You can also add ``--tensor-parallel-size 2`` to enable tensor parallelism inference with 2 GPUs:
.. code-block:: console
$ # We recommend using the tokenizer from base model to avoid long-time and buggy tokenizer conversion.
$ vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlama-1.1B-Chat-v1.0 --tensor-parallel-size 2
.. warning::
We recommend using the tokenizer from base model instead of GGUF model. Because the tokenizer conversion from GGUF is time-consuming and unstable, especially for some models with large vocab size.
You can also use the GGUF model directly through the LLM entrypoint:
.. code-block:: python
from vllm import LLM, SamplingParams
# In this script, we demonstrate how to pass input to the chat method:
conversation = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": "Hello"
},
{
"role": "assistant",
"content": "Hello! How can I assist you today?"
},
{
"role": "user",
"content": "Write an essay about the importance of higher education.",
},
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM.
llm = LLM(model="./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.chat(conversation, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
......@@ -79,23 +79,17 @@ def initialize_engine(model: str, quantization: str,
# It quantizes the model when loading, with some config info from the
# LoRA adapter repo. So need to set the parameter of load_format and
# qlora_adapter_name_or_path as below.
engine_args = EngineArgs(
model=model,
quantization=quantization,
qlora_adapter_name_or_path=lora_repo,
load_format="bitsandbytes",
enable_lora=True,
max_lora_rank=64,
# set it only in GPUs of limited memory
enforce_eager=True)
engine_args = EngineArgs(model=model,
quantization=quantization,
qlora_adapter_name_or_path=lora_repo,
load_format="bitsandbytes",
enable_lora=True,
max_lora_rank=64)
else:
engine_args = EngineArgs(
model=model,
quantization=quantization,
enable_lora=True,
max_loras=4,
# set it only in GPUs of limited memory
enforce_eager=True)
engine_args = EngineArgs(model=model,
quantization=quantization,
enable_lora=True,
max_loras=4)
return LLMEngine.from_engine_args(engine_args)
......
# ruff: noqa
import json
import random
import string
from vllm import LLM
from vllm.sampling_params import SamplingParams
# This script is an offline demo for function calling
#
# If you want to run a server/client setup, please follow this code:
#
# - Server:
#
# ```bash
# vllm serve mistralai/Mistral-7B-Instruct-v0.3 --tokenizer-mode mistral --load-format mistral --config-format mistral
# ```
#
# - Client:
#
# ```bash
# curl --location 'http://<your-node-url>:8000/v1/chat/completions' \
# --header 'Content-Type: application/json' \
# --header 'Authorization: Bearer token' \
# --data '{
# "model": "mistralai/Mistral-7B-Instruct-v0.3"
# "messages": [
# {
# "role": "user",
# "content": [
# {"type" : "text", "text": "Describe this image in detail please."},
# {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}},
# {"type" : "text", "text": "and this one as well. Answer in French."},
# {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}}
# ]
# }
# ]
# }'
# ```
#
# Usage:
# python demo.py simple
# python demo.py advanced
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
# or switch to "mistralai/Mistral-Nemo-Instruct-2407"
# or "mistralai/Mistral-Large-Instruct-2407"
# or any other mistral model with function calling ability
sampling_params = SamplingParams(max_tokens=8192, temperature=0.0)
llm = LLM(model=model_name,
tokenizer_mode="mistral",
config_format="mistral",
load_format="mistral")
def generate_random_id(length=9):
characters = string.ascii_letters + string.digits
random_id = ''.join(random.choice(characters) for _ in range(length))
return random_id
# simulate an API that can be called
def get_current_weather(city: str, state: str, unit: 'str'):
return (f"The weather in {city}, {state} is 85 degrees {unit}. It is "
"partly cloudly, with highs in the 90's.")
tool_funtions = {"get_current_weather": get_current_weather}
tools = [{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to find the weather for, e.g. 'San Francisco'"
},
"state": {
"type":
"string",
"description":
"the two-letter abbreviation for the state that the city is"
" in, e.g. 'CA' which would mean 'California'"
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["city", "state", "unit"]
}
}
}]
messages = [{
"role":
"user",
"content":
"Can you tell me what the temperate will be in Dallas, in fahrenheit?"
}]
outputs = llm.chat(messages, sampling_params=sampling_params, tools=tools)
output = outputs[0].outputs[0].text.strip()
# append the assistant message
messages.append({
"role": "assistant",
"content": output,
})
# let's now actually parse and execute the model's output simulating an API call by using the
# above defined function
tool_calls = json.loads(output)
tool_answers = [
tool_funtions[call['name']](**call['arguments']) for call in tool_calls
]
# append the answer as a tool message and let the LLM give you an answer
messages.append({
"role": "tool",
"content": "\n\n".join(tool_answers),
"tool_call_id": generate_random_id(),
})
outputs = llm.chat(messages, sampling_params, tools=tools)
print(outputs[0].outputs[0].text.strip())
# yields
# 'The weather in Dallas, TX is 85 degrees fahrenheit. '
# 'It is partly cloudly, with highs in the 90's.'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment