Commit 1e77d04e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.5.0-dtk24.04.1' into v0.5.2-dtk24.04.1

# Conflicts:
#	csrc/attention/attention_kernels.cu
#	csrc/attention/attention_utils.cuh
#	csrc/layernorm_kernels.cu
#	vllm/model_executor/layers/linear.py
#	vllm/model_executor/models/baichuan.py
#	vllm/model_executor/models/llama.py
parents 6fa22430 c62f8e9a
......@@ -5,11 +5,13 @@ project(vllm_extensions LANGUAGES CXX)
# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM")
set(CMAKE_BUILD_TYPE "Release")
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
add_compile_options(-w)
#
# Supported python versions. These versions will be searched in order, the
# first match will be selected. These should be kept in sync with setup.py.
......@@ -116,7 +118,7 @@ endif()
override_gpu_arches(VLLM_GPU_ARCHES
${VLLM_GPU_LANG}
"${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}")
#
# Query torch for additional GPU compilation flags for the given
# `VLLM_GPU_LANG`.
......@@ -143,8 +145,10 @@ set(VLLM_EXT_SRC
"csrc/cache_kernels.cu"
"csrc/attention/attention_kernels.cu"
"csrc/pos_encoding_kernels.cu"
"csrc/pos_encoding_tgi_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/transpose_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
......
......@@ -15,12 +15,17 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
| LlamaForCausalLM | LLaMA-3 | Yes | Yes |
| LlamaForCausalLM | Codellama | Yes | Yes |
| QWenLMHeadModel | QWen | Yes | Yes |
| Qwen2ForCausalLM | QWen1.5 | Yes | Yes |
| Qwen2ForCausalLM | CodeQwen1.5 | Yes | Yes |
| Qwen2ForCausalLM | QWen2 | Yes | Yes |
| ChatGLMModel | chatglm2 | Yes | Yes |
| ChatGLMModel | chatglm3 | Yes | Yes |
| BaiChuanForCausalLM | Baichuan-7B | Yes | Yes |
| BaiChuanForCausalLM | Baichuan2-7B | Yes | Yes |
| ChatGLMModel | chatglm2-6b | Yes | Yes |
| ChatGLMModel | chatglm3-6b | Yes | Yes |
| InternLMForCausalLM | InternLM | Yes | Yes |
| InternLM2ForCausalLM | InternLM2 | Yes | Yes |
| LlamaForCausalLM | deepseek | Yes | Yes |
| DeepseekV2ForCausalLM | DeepSeek-V2 | Yes | Yes |
| LlamaForCausalLM | Yi | Yes | Yes |
| MixtralForCausalLM | Mixtral-8x7B | Yes | Yes |
......@@ -56,9 +61,15 @@ git clone http://developer.hpccube.com/codes/OpenDAS/vllm.git # 根据需要的
VLLM_INSTALL_PUNICA_KERNELS=1 python setup.py bdist_wheel
cd dist
pip install vllm*
cd csrc/quantization/gptq
python setup.py bdist_wheel
cd dist
pip install gptq_kernel
2. 源码编译安装
VLLM_INSTALL_PUNICA_KERNELS=1 python3 setup.py install
cd csrc/quantization/gptq
python setup.py install
```
#### 运行基础环境准备
......
......@@ -117,6 +117,10 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
"import torch.utils.cpp_extension as t; print(';'.join(t.COMMON_HIP_FLAGS + t.COMMON_HIPCC_FLAGS))"
"Failed to determine torch nvcc compiler flags")
list(REMOVE_ITEM GPU_FLAGS
"-DUSE_ROCM=1"
)
list(APPEND GPU_FLAGS
"-DUSE_ROCM"
# "-DENABLE_FP8"
......@@ -124,7 +128,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
"-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc"
"--gpu-max-threads-per-block=1024")
message(STATUS "${GPU_FLAGS}")
endif()
set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE)
endfunction()
......
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
......@@ -39,6 +20,8 @@ typedef __hip_bfloat16 __nv_bfloat16;
#define WARP_SIZE warpSize
#endif
#include "static_switch.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
......@@ -86,7 +69,9 @@ inline __device__ float block_sum(float* red_smem, float sum) {
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE,
int PARTITION_SIZE = 0> // Zero means no partitioning.
int REUSE_KV_TIMES = 1,
bool odd_nheads = false,
int PARTITION_SIZE = 0,std::enable_if_t<!std::is_same<scalar_t, uint16_t>::value, int> = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
......@@ -98,7 +83,39 @@ __device__ void paged_attention_kernel(
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, // [num_heads]
const int num_heads, // [num_heads]
const int num_kv_heads, // [num_kv_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {}
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE,
int REUSE_KV_TIMES = 1,
bool odd_nheads = false,
int PARTITION_SIZE = 0,std::enable_if_t<std::is_same<scalar_t, uint16_t>::value, int> = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_heads, // [num_heads]
const int num_kv_heads, // [num_kv_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
......@@ -108,9 +125,9 @@ __device__ void paged_attention_kernel(
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.y;
const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.z;
const int seq_idx = blockIdx.z;
const int partition_idx = blockIdx.y;
const int max_num_partitions = gridDim.y;
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
const int seq_len = seq_lens[seq_idx];
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
......@@ -121,7 +138,7 @@ __device__ void paged_attention_kernel(
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition =
USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE;
// [start_block_idx, end_block_idx) is the range of blocks to process.
const int start_block_idx =
USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
......@@ -144,22 +161,38 @@ __device__ void paged_attention_kernel(
DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x;
const int warp_idx = thread_idx / WARP_SIZE;
// const int warp_idx_vec = thread_idx / WARP_SIZE;
// int warp_idx =0;
// asm volatile("v_readfirstlane_b32 %0,%1"
// : "=s"(warp_idx)
// : "v"(warp_idx_vec)
// :);
// // const int warp_idx = thread_idx / WARP_SIZE;
// const int lane = thread_idx % WARP_SIZE;
//const int warp_idx = thread_idx / WARP_SIZE;
const int lane = thread_idx % WARP_SIZE;
const int head_idx = blockIdx.x;
const int num_heads = gridDim.x;
int warp_id_vec = threadIdx.x / WARP_SIZE; //warp id in a block
int warp_idx =0;
asm volatile("v_readfirstlane_b32 %0,%1"
: "=s"(warp_idx)
: "v"(warp_id_vec)
:);
// const int head_idx = blockIdx.x;
// const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / num_kv_heads;
const int kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope =
alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
// const float alibi_slope =
// alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread
// group fetch or compute 16 bytes at a time. For example, if the size of a
// thread group is 4 and the data type is half, then the vector size is 16 /
// (4 * sizeof(half)) == 2.
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
constexpr int VEC_SIZE = MAX(32 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
......@@ -176,61 +209,89 @@ __device__ void paged_attention_kernel(
// the group has 0, 4, 8, ... th vectors of the query, and the second thread
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[thread_group_offset][i] =
*reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
// memory wall right before we use q_vecs
// const scalar_t* q_ptr = q + seq_idx * q_stride;
const scalar_t* q_ptr_offset = q + seq_idx * q_stride;
__shared__ Q_vec q_vecs[REUSE_KV_TIMES * THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
// #pragma unroll
// for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
// i += NUM_THREAD_GROUPS) {
// const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
// q_vecs[thread_group_offset][i] =
// *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
// }
// __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
// // memory wall right before we use q_vecs
// Memory planning.
extern __shared__ char shared_mem[];
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
float* logits = reinterpret_cast<float*>(shared_mem);
// Workspace for reduction.
__shared__ float red_smem[2 * NUM_WARPS];
__shared__ float red_smem[REUSE_KV_TIMES][2 * NUM_WARPS];
// float (*red_smem)[2 * NUM_WARPS] = reinterpret_cast<float(*)[2 * NUM_WARPS]>(&shared_mem[10*1024]);
// __shared__ char shared_mem[12 * 1024];
// float* logits = reinterpret_cast<float*>(shared_mem);
// __shared__ float red_smem[REUSE_KV_TIMES][2 * NUM_WARPS];
// x == THREAD_GROUP_SIZE * VEC_SIZE
// Each thread group fetches x elements from the key at a time.
constexpr int x = 16 / sizeof(cache_t);
float qk_max = -FLT_MAX;
float qk_max[REUSE_KV_TIMES];
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
qk_max[reuse_kv_idx] = -FLT_MAX;
}
const int num_blocks_per_kv = ((num_queries_per_kv + REUSE_KV_TIMES -1) / REUSE_KV_TIMES);
const int head_idx_soffset = (blockIdx.x / num_blocks_per_kv) * num_queries_per_kv + (blockIdx.x % num_blocks_per_kv) * REUSE_KV_TIMES;
const int kv_head_idx = head_idx_soffset / num_queries_per_kv;
const int q_boundary = (kv_head_idx + 1)* num_queries_per_kv;
#pragma unroll
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = head_idx_soffset + reuse_kv_idx;//blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
const scalar_t* q_ptr = q_ptr_offset + head_idx * HEAD_SIZE;
#pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[reuse_kv_idx*THREAD_GROUP_SIZE + thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}
}
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
// Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes
// dot product with the query.
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
// blocksparse specific vars
int bs_block_offset;
int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset =
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
}
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) {
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = head_idx_soffset + reuse_kv_idx;//blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
if(!odd_nheads || head_idx < q_boundary) {
// blocksparse specific vars
int bs_block_offset;
int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset =
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
}
if constexpr (IS_BLOCK_SPARSE) {
const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
const bool is_remote =
......@@ -254,8 +315,8 @@ __device__ void paged_attention_kernel(
continue;
}
}
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
// Load a key to registers.
// Each thread in a thread group has a different part of the key.
......@@ -263,99 +324,104 @@ __device__ void paged_attention_kernel(
// the group has 0, 4, 8, ... th vectors of the key, and the second thread
// has 1, 5, 9, ... th vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const cache_t* k_ptr =
k_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
k_vecs[j] = *reinterpret_cast<const K_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
} else {
// Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, kv_scale);
if(reuse_kv_idx == 0) {
#pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const cache_t* k_ptr =
k_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
k_vecs[j] = *reinterpret_cast<const K_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
} else {
// Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, kv_scale);
}
}
}
__builtin_amdgcn_sched_barrier(0);
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
q_vecs[thread_group_offset], k_vecs);
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[reuse_kv_idx*THREAD_GROUP_SIZE + thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
__builtin_amdgcn_sched_barrier(0);
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= seq_len;
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
logits[(reuse_kv_idx * partition_size) + (token_idx - start_token_idx)] = mask ? 0.f : qk;
// Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
qk_max[reuse_kv_idx] = mask ? qk_max[reuse_kv_idx] : fmaxf(qk_max[reuse_kv_idx], qk);
}
}
}
}
}
// Get the sum of the exp values.
float exp_sum[REUSE_KV_TIMES] = {0.f};
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
if (lane == 0) {
red_smem[warp_idx] = qk_max;
}
__syncthreads();
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
// Broadcast the max qk value to all threads.
qk_max = VLLM_SHFL_SYNC(qk_max, 0);
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = head_idx_soffset + reuse_kv_idx;
if(!odd_nheads || head_idx < q_boundary) {
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max[reuse_kv_idx] = fmaxf(qk_max[reuse_kv_idx], VLLM_SHFL_XOR_SYNC(qk_max[reuse_kv_idx], mask));
}
if (lane == 0) {
red_smem[reuse_kv_idx][warp_idx] = qk_max[reuse_kv_idx];
}
__syncthreads();
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
qk_max[reuse_kv_idx] = lane < NUM_WARPS ? red_smem[reuse_kv_idx][lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max[reuse_kv_idx] = fmaxf(qk_max[reuse_kv_idx], VLLM_SHFL_XOR_SYNC(qk_max[reuse_kv_idx], mask));
}
// Broadcast the max qk value to all threads.
qk_max[reuse_kv_idx] = VLLM_SHFL_SYNC(qk_max[reuse_kv_idx], 0);
// Get the sum of the exp values.
float exp_sum = 0.f;
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(logits[i] - qk_max);
logits[i] = val;
exp_sum += val;
}
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(logits[(reuse_kv_idx * partition_size) + i] - qk_max[reuse_kv_idx]);
logits[(reuse_kv_idx * partition_size) + i] = val;
exp_sum[reuse_kv_idx] += val;
}
exp_sum[reuse_kv_idx] = block_sum<NUM_WARPS>(&red_smem[reuse_kv_idx][NUM_WARPS], exp_sum[reuse_kv_idx]);
// Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
logits[i] *= inv_sum;
}
__syncthreads();
// Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum[reuse_kv_idx] + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
logits[(reuse_kv_idx * partition_size) + i] *= inv_sum;
}
__syncthreads();
// If partitioning is enabled, store the max logit and exp_sum.
if (USE_PARTITIONING && thread_idx == 0) {
float* max_logits_ptr = max_logits +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
*max_logits_ptr = qk_max;
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
*exp_sums_ptr = exp_sum;
// If partitioning is enabled, store the max logit and exp_sum.
if (USE_PARTITIONING && thread_idx == 0) {
float* max_logits_ptr = max_logits +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
*max_logits_ptr = qk_max[reuse_kv_idx];
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
*exp_sums_ptr = exp_sum[reuse_kv_idx];
}
}
}
// Each thread will fetch 16 bytes from the value cache at a time.
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
......@@ -369,44 +435,74 @@ __device__ void paged_attention_kernel(
DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[i] = 0.f;
}
float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD];
#pragma unroll
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[reuse_kv_idx][i] = 0.f;
}
}
scalar_t zero_value;
zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
if constexpr (IS_BLOCK_SPARSE) {
int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
!((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
continue;
}
}
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
start_token_idx));
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
V_vec v_vec;
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
// blocksparse specific vars
const int head_idx = head_idx_soffset + reuse_kv_idx;
int bs_block_offset;
int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset =
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
}
if constexpr (IS_BLOCK_SPARSE) {
int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
!((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
continue;
}
}
if(!odd_nheads || head_idx < q_boundary) {
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+ kv_head_idx * kv_head_stride;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + (reuse_kv_idx * partition_size) + token_idx - start_token_idx));
// scalar_t* logits_vec_ptr = reinterpret_cast<scalar_t*>(&logits_vec);
// for(int i=0;i<8;++i){
// from_float(*(logits_vec_ptr+i), 1000);
// }
if(reuse_kv_idx==0) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
V_vec v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
......@@ -428,20 +524,41 @@ __device__ void paged_attention_kernel(
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
}
}
accs[i] += dot(logits_vec, v_vec);
// if(threadIdx.x==0){
// scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
// scalar_t* logits_vec_ptr = reinterpret_cast<scalar_t*>(&logits_vec);
// for(int i=0;i<8;++i){
// printf("v_vec[%d] = %f\n",i, half_to_float(v_vec_ptr[i]));
// // from_float(*(v_vec_ptr + i), 1000);
// }
// for(int i=0;i<8;++i){
// printf("logits_vec[%d] = %f\n",i,half_to_float(logits_vec_ptr[i]));
// // from_float(*(logits_vec_ptr + i), 1000);
// }
// }
// accs[reuse_kv_idx][i] += dot(logits_vec, v_vec);
}
}
accs[reuse_kv_idx][i] += dot(logits_vec, v_vec);
}
}
}
}
// Perform reduction within each warp.
#pragma unroll
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
int head_idx = head_idx_soffset + reuse_kv_idx;
if(!odd_nheads || head_idx < q_boundary) {
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[i];
float acc = accs[reuse_kv_idx][i];
#pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += VLLM_SHFL_XOR_SYNC(acc, mask);
}
accs[i] = acc;
accs[reuse_kv_idx][i] = acc;
}
// NOTE(woosuk): A barrier is required because the shared memory space for
......@@ -455,12 +572,12 @@ __device__ void paged_attention_kernel(
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
float* dst = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE) + (warp_idx - mid) * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
dst[row_idx] = accs[i];
dst[row_idx] = accs[reuse_kv_idx][i];
}
}
}
......@@ -468,12 +585,12 @@ __device__ void paged_attention_kernel(
// Lower warps update the output.
if (warp_idx < mid) {
const float* src = &out_smem[warp_idx * HEAD_SIZE];
const float* src = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE) + warp_idx * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
accs[i] += src[row_idx];
accs[reuse_kv_idx][i] += src[row_idx];
}
}
}
......@@ -489,23 +606,29 @@ __device__ void paged_attention_kernel(
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
from_float(*(out_ptr + row_idx), accs[i]);
from_float(*(out_ptr + row_idx), accs[reuse_kv_idx][i]);
}
}
}
}
}
}
// Grid: (num_heads, num_seqs, 1).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE>
__global__ void paged_attention_v1_kernel(
int REUSE_KV_TIMES = 1,
bool IS_BLOCK_SPARSE,
bool odd_nheads = false>
__global__ __launch_bounds__(256,1) void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_heads, // [num_heads]
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
......@@ -516,22 +639,24 @@ __global__ void paged_attention_v1_kernel(
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
v_cache, num_heads, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
}
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE,
int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
int REUSE_KV_TIMES,
int PARTITION_SIZE,
bool odd_nheads = false>
__global__ __launch_bounds__(256,1) void paged_attention_v2_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
......@@ -542,7 +667,8 @@ __global__ void paged_attention_v2_kernel(
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, // [num_heads]
const int num_heads, // [num_heads]
const int num_kv_heads, // [num_kv_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
......@@ -552,19 +678,19 @@ __global__ void paged_attention_v2_kernel(
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, kv_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_heads, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, kv_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
}
// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel(
__global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions]
......@@ -674,22 +800,32 @@ __global__ void paged_attention_v2_reduce_kernel(
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \
BLOCK_SIZE, NUM_THREADS, \
KV_DTYPE, IS_BLOCK_SPARSE>), \
KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>), \
shared_mem_size); \
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE> \
<<<grid, block, shared_mem_size, stream>>>( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
hipLaunchKernelGGL(( vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>) \
, dim3(grid), dim3(block), shared_mem_size, stream, \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
kv_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
// vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
// NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads> \
// <<<dim3(grid), dim3(block)>>>( \
// out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
// scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
// alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
// kv_scale, tp_rank, blocksparse_local_blocks, \
// blocksparse_vert_stride, blocksparse_block_size, \
// blocksparse_head_sliding_step);
// TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
int NUM_THREADS = 128>
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE>
void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
......@@ -705,7 +841,10 @@ void paged_attention_v1_launcher(
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
int num_threads = 128;
if(num_heads!=num_kv_heads){
num_threads =256;
}
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
......@@ -722,48 +861,31 @@ void paged_attention_v1_launcher(
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_seq_len =
DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
int logits_size = padded_max_seq_len * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here!
int shared_mem_size = std::max(logits_size, outputs_size);
dim3 grid(num_heads, num_seqs, 1);
dim3 block(NUM_THREADS);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (head_size) {
// NOTE(woosuk): To reduce the compilation time, we only compile for the
// head sizes that we use in the model. However, we can easily extend this
// to support any head size which is a multiple of 16.
case 64:
LAUNCH_PAGED_ATTENTION_V1(64);
break;
case 80:
LAUNCH_PAGED_ATTENTION_V1(80);
break;
case 96:
LAUNCH_PAGED_ATTENTION_V1(96);
break;
case 112:
LAUNCH_PAGED_ATTENTION_V1(112);
break;
case 128:
LAUNCH_PAGED_ATTENTION_V1(128);
break;
case 192:
LAUNCH_PAGED_ATTENTION_V1(192);
break;
case 256:
LAUNCH_PAGED_ATTENTION_V1(256);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
REUSEKV_SWITCH_V1(num_heads * num_seqs , [&] {
BOOL_SWITCH((num_heads/num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] {
HEADSIZE_SWITCH(head_size, [&] {
NUM_THREADS_SWITCH(num_threads, [&] {
OPT_SWITCH(num_heads == num_kv_heads, [&] {
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES*padded_max_seq_len * sizeof(float);
int outputs_size = REUSE_KV_TIMES*(NUM_WARPS / 2) * head_size * sizeof(float);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here!
int shared_mem_size = ::max(logits_size, outputs_size);
if(num_heads == num_kv_heads) shared_mem_size = ::max(12 * 1024, shared_mem_size);
// int shared_mem_size = ::max(31*1024, ::max(logits_size, outputs_size));
// std::cout<<"shared_mem_size = "<<shared_mem_size<<std::endl;
dim3 grid((num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads, 1, num_seqs);
dim3 block(NUM_THREADS);
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query));
const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE);
});
});
});
});
});
}
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
......@@ -788,20 +910,25 @@ void paged_attention_v1_launcher(
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
// // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// // 1, 2, 4, 64, 128, 256.
// #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
// switch (block_size) { \
// case 16: \
// CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
// break; \
// TORCH_CHECK(false, "Unsupported block size: ", block_size); \
// break; \
// }
void paged_attention_v1(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
......@@ -826,19 +953,19 @@ void paged_attention_v1(
}
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
hipLaunchKernelGGL(( vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
PARTITION_SIZE> \
<<<grid, block, shared_mem_size, stream>>>( \
REUSE_KV_TIMES, PARTITION_SIZE, odd_nheads>) \
, dim3(grid), dim3(block), shared_mem_size, stream, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, kv_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE> \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>) \
, dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions);
......@@ -883,48 +1010,33 @@ void paged_attention_v2_launcher(
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
int logits_size = PARTITION_SIZE * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
// For paged attention v2 kernel.
dim3 grid(num_heads, num_seqs, max_num_partitions);
int shared_mem_size = std::max(logits_size, outputs_size);
// For paged attention v2 reduce kernel.
dim3 reduce_grid(num_heads, num_seqs);
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
dim3 block(NUM_THREADS);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (head_size) {
// NOTE(woosuk): To reduce the compilation time, we only compile for the
// head sizes that we use in the model. However, we can easily extend this
// to support any head size which is a multiple of 16.
case 64:
LAUNCH_PAGED_ATTENTION_V2(64);
break;
case 80:
LAUNCH_PAGED_ATTENTION_V2(80);
break;
case 96:
LAUNCH_PAGED_ATTENTION_V2(96);
break;
case 112:
LAUNCH_PAGED_ATTENTION_V2(112);
break;
case 128:
LAUNCH_PAGED_ATTENTION_V2(128);
break;
case 192:
LAUNCH_PAGED_ATTENTION_V2(192);
break;
case 256:
LAUNCH_PAGED_ATTENTION_V2(256);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
REUSEKV_SWITCH(num_heads * max_num_partitions * num_seqs , [&] {
BOOL_SWITCH((num_heads/num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] {
HEADSIZE_SWITCH(head_size, [&] {
OPT_SWITCH(num_heads == num_kv_heads, [&] {
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * sizeof(float);
int outputs_size = REUSE_KV_TIMES*(NUM_WARPS / 2) * head_size * sizeof(float);
// For paged attention v2 kernel.
// dim3 grid(num_heads, max_num_partitions, num_seqs);
dim3 grid;
grid.x = (num_heads/num_kv_heads + REUSE_KV_TIMES -1)/REUSE_KV_TIMES * num_kv_heads;
grid.y = max_num_partitions;
grid.z = num_seqs;
// int shared_mem_size = ::max(1024*32, ::max(logits_size, outputs_size));
int shared_mem_size = ::max(logits_size, outputs_size);
// For paged attention v2 reduce kernel.
dim3 reduce_grid(num_heads, num_seqs);
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
dim3 block(NUM_THREADS);
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query));
const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE);
});
});
});
});
}
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
......@@ -949,20 +1061,25 @@ void paged_attention_v2_launcher(
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
// // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// // 1, 2, 4, 64, 128, 256.
// #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
// switch (block_size) { \
// case 16: \
// CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
// break; \
// TORCH_CHECK(false, "Unsupported block size: ", block_size); \
// break; \
// }
void paged_attention_v2(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
......
......@@ -26,17 +26,104 @@
namespace vllm {
// Q*K^T operation.
inline __device__ void v_dot2_f32_f16(float& a, const uint32_t & b,const uint32_t & c) {
asm volatile("v_dot2_f32_f16 %0, %1, %2, %0;": "=v"(a): "v"(b), "v"(c), "0"(a));
}
inline __device__ void v_pk_fma_f16(uint32_t& a, const uint32_t & b,const uint32_t & c){
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;": "=v"(a) : "v"(b), "v"(c), "v"(a));
}
inline __device__ void ds_read_b128(uint4& a, uint32_t offset){
asm volatile("ds_read_b128 %0 %1;": "=v" (a): "v" (offset));
}
inline __device__ void ds_read_b128_sync(uint4& a, uint32_t offset){
asm volatile("ds_read_b128 %0 %1\ns_waitcnt lgkmcnt(1);": "=v" (a): "v" (offset));
}
inline __device__ void lgkmcnt0(){
asm volatile("s_waitcnt lgkmcnt(0);");
}
__device__ inline size_t __nv_cvta_generic_to_shared_impl(const void *__ptr) {
return (size_t)(void __attribute__((address_space(3))) *)__ptr;
}
inline __device__ void v_dot2_f32_f16(float& a,const uint2 & b,const uint2 & c) {
v_dot2_f32_f16(a, b.x, c.x);
v_dot2_f32_f16(a, b.y, c.y);
}
inline __device__ void v_dot2_f32_f16(float& a,const uint4 & b,const uint4 & c) {
v_dot2_f32_f16(a, b.x, c.x);
v_dot2_f32_f16(a, b.y, c.y);
v_dot2_f32_f16(a, b.z, c.z);
v_dot2_f32_f16(a, b.w, c.w);
}
inline __device__ float add_half2(uint32_t a){
union {
uint32_t u32;
half u16[2];
} tmp;
tmp.u32=a;
return static_cast<float>(tmp.u16[0]+tmp.u16[1]);
}
inline __device__ void v_pk_fma_f16x8(float& a,const uint4 & b,const uint4 & c) {
uint32_t tmp = mul<uint32_t, uint32_t, uint32_t>(b.x,c.x);
v_pk_fma_f16(tmp,b.y,c.y);
v_pk_fma_f16(tmp,b.z,c.z);
v_pk_fma_f16(tmp,b.w,c.w);
a+=add_half2(tmp);
}
// Q*K^T operation. fp16
// template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<std::is_same<scalar_t, uint16_t>::value, int> = 0>
template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
float qk =0;
// uint32_t offset = __nv_cvta_generic_to_shared_impl(q);
// const uint4 *k_ptr= reinterpret_cast<const uint4 *>(k);
// // Compute the parallel products for Q*K^T (treat vector lanes separately).
// constexpr int loop=N*sizeof(Vec)/16/2;
// uint4 qt[2];
// #pragma unroll
// for (int ii = 0; ii < loop; ++ii) {
// ds_read_b128(qt[0],offset+16*ii*2);
// ds_read_b128_sync(qt[1],offset+16*(ii*2+1));
// v_dot2_f32_f16(qk,qt[0],k_ptr[ii*2]);
// // v_pk_fma_f16x8(qk,qt[0],k_ptr[ii*2]);
// lgkmcnt0();
// v_dot2_f32_f16(qk,qt[1],k_ptr[ii*2+1]);
// // v_pk_fma_f16x8(qk,qt[1],k_ptr[ii*2+1]);
// }
#pragma unroll
for (int ii = 0; ii < N; ++ii) {
v_dot2_f32_f16(qk,q[ii],k[ii]);
}
// Finalize the reduction across lanes.
#pragma unroll
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
qk += VLLM_SHFL_XOR_SYNC(qk, mask);
}
return qk;
}
// Q*K^T operation. //bf16
// template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<!std::is_same<scalar_t, uint16_t>::value, int> = 0>
template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_vpack_(const Vec (&q)[N], const Vec (&k)[N]) {
using A_vec = typename FloatVec<Vec>::Type;
// Compute the parallel products for Q*K^T (treat vector lanes separately).
A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
#pragma unroll
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec);
}
// Finalize the reduction across lanes.
float qk = sum(qk_vec);
#pragma unroll
......@@ -46,12 +133,17 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
return qk;
}
template <typename T, int THREAD_GROUP_SIZE>
struct Qk_dot {
template <typename Vec, int N>
static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
return qk_dot_<THREAD_GROUP_SIZE>(q, k);
}
// template <typename Vec, int N>
// static inline __device__ float qk_dot_vpack(const Vec (&q)[N], const Vec (&k)[N]) {
// return qk_dot_vpack_<THREAD_GROUP_SIZE>(q, k);
// }
};
} // namespace vllm
\ No newline at end of file
} // namespace vllm
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
#define OPT_SWITCH(COND, ...) \
[&] { \
if (COND) { \
constexpr static int opt = 1; \
return __VA_ARGS__(); \
} else { \
constexpr static int opt = 2; \
return __VA_ARGS__(); \
} \
}()
#define NUM_THREADS_SWITCH(NUM_THREAD, ...) \
[&] { \
if (NUM_THREAD == 256) { \
constexpr static int NUM_THREADS = 256; \
return __VA_ARGS__(); \
} else { \
constexpr static int NUM_THREADS = 128; \
return __VA_ARGS__(); \
} \
}()
// #define HEADSIZE_SWITCH(HEADDIM, ...) \
// [&] { \
// if (HEADDIM == 64) { \
// constexpr static int HEAD_SIZE = 64; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 80) { \
// constexpr static int HEAD_SIZE = 80; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 96) { \
// constexpr static int HEAD_SIZE = 96; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 112) { \
// constexpr static int HEAD_SIZE = 112; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 128) { \
// constexpr static int HEAD_SIZE = 128; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 256) { \
// constexpr static int HEAD_SIZE = 256; \
// return __VA_ARGS__(); \
// } \
// else { \
// TORCH_CHECK(false, "Unsupported head size: ", HEADDIM);\
// } \
// }()
#define HEADSIZE_SWITCH(HEADDIM, ...) \
[&] { \
if (HEADDIM == 128) { \
constexpr static int HEAD_SIZE = 128; \
return __VA_ARGS__(); \
} else { \
TORCH_CHECK(false, "Unsupported head size: ", HEADDIM);\
} \
}()
#define REUSEKV_SWITCH(num_blocks , ...) \
[&] { \
if (num_heads % 2 == 0 && num_heads / num_kv_heads >= 4 && num_blocks >= 1200){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
} else if (num_heads / num_kv_heads >= 2 && num_blocks >= 1200){\
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
} else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#define REUSEKV_SWITCH_V1(num_blocks , ...) \
[&] { \
if (num_heads > num_kv_heads && num_blocks >= 1200){ \
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
} else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/AccumulateType.h>
#include <THC/THCDeviceUtils.cuh>
#include "dispatch_utils.h"
#include "reduction_utils.cuh"
#ifndef USE_ROCM
......@@ -288,22 +291,150 @@ fused_add_rms_norm_kernel(
} // namespace vllm
template <typename T,int reducesize=C10_WARP_SIZE>
__inline__ __device__ T WarpReduceSum_NEW(T val) {
#pragma unroll
for (int offset = reducesize/2; offset > 0; offset >>= 1) {
val += WARP_SHFL_DOWN(val, offset);
}
return val;
}
template <typename T,int block_size=512>
__inline__ __device__ T BlockReduceSum_NEW(T val, T* shared) {
constexpr int share_size=block_size/C10_WARP_SIZE;
val = WarpReduceSum_NEW<T>(val);
if constexpr(block_size==C10_WARP_SIZE)
{
return val;
}
else{
const int lid = threadIdx.x % C10_WARP_SIZE;
const int wid = threadIdx.x / C10_WARP_SIZE;
__syncthreads();
if (lid == 0&&wid<share_size) {
shared[wid] = val;
}
__syncthreads();
if (wid == 0&&lid<share_size) {
val = WarpReduceSum_NEW<T,share_size>(shared[lid]);
}
return val;
}
}
template <typename scalar_t,typename T_ACC,int Vec=4,int block_size=512>
__global__ void fused_add_rms_kernel_eval(scalar_t* input,scalar_t* residual,scalar_t* gamma,int cols,T_ACC eps)
{
constexpr int share_size=block_size/C10_WARP_SIZE;
__shared__ T_ACC val_shared[share_size];
__shared__ T_ACC s_rstd;
T_ACC val=0;
int i=blockIdx.x;
int j=threadIdx.x;
int tcol=cols/Vec;
if(j>=tcol)return;
using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>;
scalar_t intput_vec[Vec];
scalar_t residual_vec[Vec];
T_ACC trstd;
int idx = i * tcol + j;
idx*=Vec;
*(LoadT*)intput_vec = *(LoadT*)(input+idx);
*(LoadT*)residual_vec = *(LoadT*)(residual+idx);
#pragma unroll
for (int ii = 0; ii < Vec; ii++) {
residual_vec[ii]+=intput_vec[ii];
val += static_cast<T_ACC>(residual_vec[ii])*static_cast<T_ACC>(residual_vec[ii]);
}
val = BlockReduceSum_NEW<T_ACC,block_size>(val,val_shared);
if (j == 0) s_rstd=c10::cuda::compat::rsqrt(val/cols + eps);
__syncthreads();
trstd=s_rstd;
#pragma unroll
for(int ii=0;ii<Vec;ii++){
int jj=j*Vec+ii;
intput_vec[ii] = static_cast<T_ACC>(residual_vec[ii]) * trstd * static_cast<T_ACC>(gamma[jj]);
}
*(LoadT*)(residual+idx)=*(LoadT*)residual_vec;
*(LoadT*)(input+idx)=*(LoadT*)intput_vec;
}
template <typename scalar_t,typename T_ACC,int Vec=4,int block_size=512>
__global__ void fused_rms_kernel_eval(scalar_t* input,scalar_t* output,scalar_t* gamma,int cols,T_ACC eps)
{
constexpr int share_size=block_size/C10_WARP_SIZE;
__shared__ T_ACC val_shared[share_size];
__shared__ T_ACC s_rstd;
T_ACC val=0;
int i=blockIdx.x;
int j=threadIdx.x;
int tcol=cols/Vec;
if(j>=tcol)return;
using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>;
scalar_t intput_vec[Vec];
T_ACC trstd;
int idx = i * tcol + j;
idx*=Vec;
*(LoadT*)intput_vec = *(LoadT*)(input+idx);
#pragma unroll
for (int ii = 0; ii < Vec; ii++) {
val += static_cast<T_ACC>(intput_vec[ii])*static_cast<T_ACC>(intput_vec[ii]);
}
val = BlockReduceSum_NEW<T_ACC,block_size>(val,val_shared);
if (j == 0) s_rstd=c10::cuda::compat::rsqrt(val/cols + eps);
__syncthreads();
trstd=s_rstd;
#pragma unroll
for(int ii=0;ii<Vec;ii++){
int jj=j*Vec+ii;
intput_vec[ii] = static_cast<T_ACC>(intput_vec[ii]) * trstd * static_cast<T_ACC>(gamma[jj]);
}
*(LoadT*)(output+idx)=*(LoadT*)intput_vec;
}
void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
});
if(hidden_size%16==0&&hidden_size>=2048&&hidden_size<=8192){
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"fused_add_rms_norm_kernel",
[&] {
using T_ACC = at::acc_type<scalar_t, true>;
T_ACC eps = epsilon;
scalar_t* self_data = input.data_ptr<scalar_t>();
scalar_t* out_data =out.data_ptr<scalar_t>();
scalar_t* weight_data=weight.data_ptr<scalar_t>();
if(hidden_size==2048){
fused_rms_kernel_eval<scalar_t,T_ACC,2,1024><<<num_tokens, 1024, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
else if(hidden_size<=4096){
fused_rms_kernel_eval<scalar_t,T_ACC,4,1024><<<num_tokens, 1024, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
else{
fused_rms_kernel_eval<scalar_t,T_ACC,8,1024><<<num_tokens, 1024, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
});
}
else{
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
});
}
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
......@@ -316,37 +447,64 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
num_tokens, hidden_size); \
});
void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 block(std::min(hidden_size, max_block_size));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
/*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16/BF16s
since we can load at most 128 bits at once in a global memory op.
However, this requires each tensor's data to be aligned to 16
bytes.
*/
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
} else {
LAUNCH_FUSED_ADD_RMS_NORM(0);
if(hidden_size%16==0&&hidden_size>=2048&&hidden_size<=8192){
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"fused_add_rms_norm_kernel",
[&] {
using T_ACC = at::acc_type<scalar_t, true>;
T_ACC eps = epsilon;
scalar_t* self_data = input.data_ptr<scalar_t>();
scalar_t* other_data =residual.data_ptr<scalar_t>();
scalar_t* weight_data=weight.data_ptr<scalar_t>();
if(hidden_size==2048){
fused_add_rms_kernel_eval<scalar_t,T_ACC,2,1024><<<num_tokens, 1024, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
else if(hidden_size<=4096){
fused_add_rms_kernel_eval<scalar_t,T_ACC,4,1024><<<num_tokens, 1024, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
else{
fused_add_rms_kernel_eval<scalar_t,T_ACC,8,1024><<<num_tokens, 1024, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
});
}
}
\ No newline at end of file
else{
dim3 grid(num_tokens);
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 block(std::min(hidden_size, max_block_size));
/*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16/BF16s
since we can load at most 128 bits at once in a global memory op.
However, this requires each tensor's data to be aligned to 16
bytes.
*/
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
} else {
LAUNCH_FUSED_ADD_RMS_NORM(0);
}
}
}
......@@ -39,6 +39,13 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& cos_sin_cache, bool is_neox,
int64_t rot_dim,
torch::Tensor& cos_sin_cache_offsets);
void rotary_embedding_tgi(
torch::Tensor& query,
torch::Tensor& key,
int64_t head_size,
torch::Tensor& cos_cache,
torch::Tensor& sin_cache,
bool is_neox);
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
......@@ -123,6 +130,8 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
void trans_w16_gemm(torch::Tensor dst, torch::Tensor src, int64_t row, int64_t col);
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& scale);
......
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "dispatch_utils.h"
namespace vllm {
template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding_tgi(
scalar_t* __restrict__ arr,
const float* __restrict__ cos_ptr,
const float* __restrict__ sin_ptr,
int rot_offset,
int embed_dim)
{
int x_index, y_index;
float cos, sin;
if (IS_NEOX) {
// GPT-NeoX style rotary embedding.
x_index = rot_offset;
y_index = embed_dim + rot_offset;
cos = VLLM_LDG(cos_ptr + x_index);
sin = VLLM_LDG(sin_ptr + x_index);
} else {
// GPT-J style rotary embedding.
x_index = 2 * rot_offset;
y_index = 2 * rot_offset + 1;
cos = VLLM_LDG(cos_ptr + x_index / 2);
sin = VLLM_LDG(sin_ptr + x_index / 2);
}
const scalar_t x = arr[x_index];
const scalar_t y = arr[y_index];
arr[x_index] = x * cos - y * sin;
arr[y_index] = y * cos + x * sin;
}
template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding_tgi(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const float* __restrict__ cos_ptr, // [max_position, 1, rot_dim]
const float* __restrict__ sin_ptr, // [max_position, 1, rot_dim]
const int head_size,
const int num_heads,
const int num_kv_heads,
const int rot_dim,
const int token_idx,
const int64_t query_stride,
const int64_t key_stride)
{
const int nq = num_heads * rot_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / rot_dim;
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % rot_dim;
apply_token_rotary_embedding_tgi<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
sin_ptr, rot_offset, rot_dim);
}
const int nk = num_kv_heads * rot_dim;
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / rot_dim;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % rot_dim;
apply_token_rotary_embedding_tgi<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
sin_ptr, rot_offset, rot_dim);
}
}
template<typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_tgi_kernel(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const float* __restrict__ cos_cache, // [max_position, 1, rot_dim]
const float* __restrict__ sin_cache, // [max_position, 1, rot_dim]
const int rot_dim,
const int64_t query_stride,
const int64_t key_stride,
const int num_heads,
const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
const float* cos_ptr = cos_cache + token_idx * rot_dim;
const float* sin_ptr = sin_cache + token_idx * rot_dim;
apply_rotary_embedding_tgi<scalar_t, IS_NEOX>(query, key, cos_ptr, sin_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
}
} // namespace vllm
void rotary_embedding_tgi(
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
int64_t head_size,
torch::Tensor& cos_cache,
torch::Tensor& sin_cache,
bool is_neox) {
int num_tokens = query.size(0);
int rot_dim = cos_cache.size(2);
int num_heads = query.size(1);
int num_kv_heads = key.size(1);
int query_stride = query.stride(0);
int key_stride = key.stride(0);
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
query.scalar_type(),
"rotary_embedding_tgi",
[&] {
if (is_neox) {
vllm::rotary_embedding_tgi_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_cache.data_ptr<float>(),
sin_cache.data_ptr<float>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
} else {
vllm::rotary_embedding_tgi_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_cache.data_ptr<float>(),
sin_cache.data_ptr<float>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
}
});
}
......@@ -1542,6 +1542,7 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
}
}
__global__ void shuffle_4bit_kernel(uint32_t* __restrict__ b_q_weight,
const int size_k, const int size_n) {
int n = blockIdx.x * THREADS_X + threadIdx.x;
......@@ -1847,6 +1848,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
return c;
}
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
vllm::gptq::shuffle_exllama_weight(
......
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch
# Compiler flags.
CXX_FLAGS = ["-g", "-O3", "-std=c++17"]
NVCC_FLAGS = ["-O3", "-std=c++17","-DUSE_ROCM","-U__HIP_NO_HALF_CONVERSIONS__","-U__HIP_NO_HALF_OPERATORS__"]
#--gpu-max-threads-per-block=1024编译会导致GPTQ多batch性能下降。
# NVCC_FLAGS = ["-O3", "-std=c++17","-DUSE_ROCM","--gpu-max-threads-per-block=1024","-U__HIP_NO_HALF_CONVERSIONS__","-U__HIP_NO_HALF_OPERATORS__"]
ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
}
setup(
name="gptq_kernels",
ext_modules=[
CUDAExtension(
name="gptq_kernels",
sources=[
"./torch_bindings.cpp",
"./q_gemm.cu",
],
extra_compile_args=extra_compile_args,
)
],
cmdclass={"build_ext": BuildExtension},
)
#include <torch/extension.h>
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
bool use_exllama, int64_t bit);
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
// Bindings
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("gptq_gemm", &gptq_gemm, "make_q_matrix");
m.def("gptq_shuffle", &gptq_shuffle, "gemm_half_q_half");
}
......@@ -93,6 +93,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
// Rotary embedding TGI for TGI
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
"rotary_embedding_tgi(Tensor! query, Tensor! key,"
" int head_size, Tensor cos_cache,"
" Tensor sin_cache, bool is_neox) -> ()");
// ops.def("rotary_embedding_tgi",&rotary_embedding_tgi);
ops.impl("rotary_embedding_tgi", torch::kCUDA, &rotary_embedding_tgi);
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key
// (supports multiple loras).
ops.def(
......@@ -164,6 +173,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
// trans w16
ops.def("trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()");
ops.impl("trans_w16_gemm", torch::kCUDA, &trans_w16_gemm);
// Quantized GEMM for SqueezeLLM.
ops.def(
"squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
......
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
namespace vllm {
template <typename T>
__global__ void trans_w16_gemm_cudakernel(int64_t num_kernels,T* dst,const T* src,int64_t row,int64_t col)
{
int64_t id = blockIdx.x * blockDim.x + threadIdx.x;
if(id >= num_kernels) return;
int64_t j=id%row;
int64_t i=id/row;
dst[i*row+j]=src[j*col+i];
}
void trans_w16_gemm_cuda(half* dst,const half* src,int64_t row,int64_t col){
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int64_t num_kernels=row*col;
int block_size=256;
trans_w16_gemm_cudakernel<<<(num_kernels+block_size-1)/block_size,block_size, 0, stream>>>(num_kernels,dst,src,row,col);
}
} // namespace vllm
void trans_w16_gemm(torch::Tensor dst,torch::Tensor src,int64_t row,int64_t col){
const at::cuda::OptionalCUDAGuard device_guard(device_of(src));
vllm::trans_w16_gemm_cuda(
(half*)dst.data_ptr(),
(const half*)src.data_ptr(),
row,
col
);
}
\ No newline at end of file
......@@ -3,6 +3,10 @@ import functools
from typing import List, Optional, Tuple, Type
import torch
try:
import gptq_kernels
except ImportError as e:
raise RuntimeError("Failed to import gptq_kernel with, Please install gptq_kernels from csrc/quantization/gptq ")
from vllm.logger import init_logger
......@@ -182,14 +186,21 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor, use_exllama: bool,
bit: int) -> torch.Tensor:
return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
return gptq_kernels.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
b_g_idx, use_exllama, bit)
# return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
# b_g_idx, use_exllama, bit)
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
bit: int) -> None:
torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
gptq_kernels.gptq_shuffle(q_weight, q_perm, bit)
# torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# trans_w16
def trans_w16_gemm(dst: torch.Tensor, src: torch.Tensor,
row:int, col:int) -> None :
torch.ops._C.trans_w16_gemm(dst,src,row,col)
# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
......
......@@ -142,7 +142,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# flag to control if vllm should use triton flash attention
"VLLM_USE_TRITON_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "False").lower() in
("true", "1")),
# local rank of the process in the distributed setting, used to determine
......
......@@ -14,6 +14,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
import os
logger = init_logger(__name__)
......@@ -65,34 +66,6 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
return param[shard_id], loaded_weight
def pad_weight(weight: torch.Tensor, num_pad: int, pad_dim: int = 0):
if weight.dim() == 1:
padding = torch.zeros(num_pad, dtype=weight.dtype, device=weight.device)
padded_weight = torch.cat([weight, padding], dim=0)
elif weight.dim() == 2:
if pad_dim == 0:
padding = torch.zeros(num_pad, weight.shape[1], dtype=weight.dtype, device=weight.device)
padded_weight = torch.cat([weight, padding], dim=0)
elif pad_dim == 1:
padding = torch.zeros(weight.shape[0], num_pad, dtype=weight.dtype, device=weight.device)
padded_weight = torch.cat([weight, padding], dim=1)
else:
raise ValueError("pad_dim must be 0 or 1")
else:
raise ValueError("Weight tensor must be 1D or 2D")
return padded_weight
def gemm_bank_conf(weight):
is_mul_of_2048 = weight % 2048 == 0
is_power_of_two = (weight & (weight - 1)) == 0 and weight != 0
if is_mul_of_2048 and is_power_of_two:
return True
else:
return False
class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods."""
......@@ -133,7 +106,6 @@ class UnquantizedLinearMethod(LinearMethodBase):
def __init__(self):
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
......@@ -151,20 +123,19 @@ class UnquantizedLinearMethod(LinearMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.use_llama_nn:
layer.weight = layer.weight.reshape(layer.weight.shape[1], -1)
if bias is not None:
return torch.matmul(x, layer.weight) + bias
else:
if gemm_bank_conf(layer.weight.shape[1] - 32) and os.environ['GEMM_PAD'] == '1':
return torch.matmul(x, layer.weight[:,:-32])
if len(x.shape) == 2:
return torch.addmm(bias, x, layer.weight)
else:
return torch.matmul(x, layer.weight)
return torch.matmul(x, layer.weight) + bias
else:
return torch.matmul(x, layer.weight)
else:
return F.linear(x, layer.weight, bias)
class LinearBase(torch.nn.Module):
"""Base linear layer.
......@@ -321,7 +292,6 @@ class ColumnParallelLinear(LinearBase):
})
else:
self.register_parameter("bias", None)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
......@@ -339,9 +309,6 @@ class ColumnParallelLinear(LinearBase):
loaded_weight = loaded_weight.reshape(1)
assert param_data.shape == loaded_weight.shape
if self.use_llama_nn:
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(param_data.shape[0],-1)
param_data.copy_(loaded_weight)
def forward(self, input_):
......@@ -406,8 +373,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self,
param: Parameter,
......@@ -469,21 +434,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
if use_bitsandbytes:
shard_size = loaded_weight.shape[output_dim]
shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id
if self.use_llama_nn:
param_data_ = param_data.narrow(output_dim, shard_offset,
shard_size)
else:
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
......@@ -506,16 +465,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
"Loading a weight without `output_dim` attribute in "
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.")
if self.use_llama_nn:
assert param_data_.shape == loaded_weight.shape
param_data_.copy_(loaded_weight)
if loaded_shard_id == 1 and len(param_data.shape) == 2:
param_data = param_data.transpose(0, 1)
param.data = param_data.reshape(param_data.shape[1], -1)
else:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
class QKVParallelLinear(ColumnParallelLinear):
......@@ -575,6 +527,7 @@ class QKVParallelLinear(ColumnParallelLinear):
self.num_kv_heads * self.head_size * tp_size, # k_proj
self.num_kv_heads * self.head_size * tp_size, # v_proj
]
super().__init__(input_size=input_size,
output_size=output_size,
bias=bias,
......@@ -582,8 +535,6 @@ class QKVParallelLinear(ColumnParallelLinear):
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def weight_loader(self,
param: Parameter,
......@@ -675,14 +626,9 @@ class QKVParallelLinear(ColumnParallelLinear):
}
shard_size, shard_offset = adjust_bitsandbytes_shard(
param, orig_qkv_offsets, loaded_shard_id)
if self.use_llama_nn:
param_data_ = param_data.narrow(output_dim, shard_offset,
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
else:
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
if loaded_shard_id == "q":
shard_id = tp_rank
else:
......@@ -708,20 +654,9 @@ class QKVParallelLinear(ColumnParallelLinear):
"Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same "
"for all partitions.")
if self.use_llama_nn:
assert param_data_.shape == loaded_weight.shape
param_data_.copy_(loaded_weight)
if loaded_shard_id == "v" and len(param_data.shape) == 2:
if self.use_fa_pad and param_data.shape[0]== 12288:
param_data = pad_weight(param.data, 32)
param_data = param_data.transpose(0, 1)
param.data = param_data.reshape(param_data.shape[1], -1)
if self.use_fa_pad and param_data.shape[0]== 12288 and loaded_shard_id == "v" and len(param_data.shape) == 1:
param.data = pad_weight(param.data, 32)
else:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
class RowParallelLinear(LinearBase):
......@@ -790,8 +725,6 @@ class RowParallelLinear(LinearBase):
})
else:
self.register_parameter("bias", None)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
......@@ -809,18 +742,7 @@ class RowParallelLinear(LinearBase):
loaded_weight = loaded_weight.reshape(1)
assert param_data.shape == loaded_weight.shape
# if self.use_llama_nn:
# loaded_weight = loaded_weight.transpose(0, 1)
# loaded_weight=loaded_weight.reshape(param_data.shape[0],-1)
# param_data.copy_(loaded_weight)
param_data.copy_(loaded_weight)
if self.use_llama_nn:
if gemm_bank_conf(param.data.shape[0]) and self.use_gemm_pad:
param.data = pad_weight(param.data, 32)
param.data = param.data.transpose(0, 1)
param.data=param.data.reshape(param.data.shape[1],-1)
def forward(self, input_):
if self.input_is_parallel:
......@@ -853,4 +775,4 @@ class RowParallelLinear(LinearBase):
s += f", bias={self.bias is not None}"
s += f", tp_size={self.tp_size}"
s += f", reduce_results={self.reduce_results}"
return s
return s
\ No newline at end of file
......@@ -22,13 +22,11 @@ def set_default_torch_dtype(dtype: torch.dtype):
def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
if architectures == ['LlamaForCausalLM'] or architectures == ['Qwen2ForCausalLM'] or architectures == ['ChatGLMModel'] or architectures == ['BaichuanForCausalLM']:
if architectures == ['LlamaForCausalLM'] or architectures == ['QWenLMHeadModel'] or architectures == ['Qwen2ForCausalLM'] or architectures == ['ChatGLMModel'] or architectures == ['BaichuanForCausalLM']:
if os.getenv('LLAMA_NN') != '0':
os.environ['LLAMA_NN'] = '1'
if os.getenv('GEMM_PAD') != '0':
os.environ['GEMM_PAD'] = '1'
if os.getenv('FA_PAD') != '1':
os.environ['FA_PAD'] = '0'
else:
os.environ['LLAMA_NN'] = '0'
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None
......
......@@ -25,6 +25,7 @@ import torch
from torch import nn
from transformers import PretrainedConfig
import os
import re
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
......@@ -48,6 +49,9 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsLoRA
from vllm import _custom_ops as ops
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
......@@ -181,8 +185,6 @@ class BaiChuanAttention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI":
q, k = self.rotary_emb(positions, q, k)
......@@ -336,6 +338,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward(
self,
......@@ -404,6 +407,26 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if self.use_llama_nn:
lay_key_words = [
"self_attn.W_pack.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight"
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1], -1)
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
......
......@@ -8,6 +8,7 @@ import torch
from torch import nn
from torch.nn import LayerNorm
import os
import re
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
......@@ -28,6 +29,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig
from vllm import _custom_ops as ops
from .interfaces import SupportsLoRA
......@@ -104,8 +106,6 @@ class GLMAttention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
context_layer = self.attn(
......@@ -362,6 +362,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
self.lm_head = self.transformer.output_layer
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward(
self,
......@@ -403,3 +404,23 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if self.use_llama_nn:
lay_key_words = [
"self_attention.query_key_value.weight",
"self_attention.dense.weight",
"mlp.dense_h_to_4h.weight",
"mlp.dense_4h_to_h.weight"
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1], -1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment