Commit e00b0a19 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.3.3

parents ead94d93 3f1166ab
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)
#pragma once
template <int feat_in, int feat_out, typename in_T, typename out_T,
typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale);
// clang-format off
#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \
f(in_T, out_T, W_T, narrow, 128) \
f(in_T, out_T, W_T, narrow, 256) \
f(in_T, out_T, W_T, narrow, 512) \
f(in_T, out_T, W_T, narrow, 1024) \
f(in_T, out_T, W_T, narrow, 1280) \
f(in_T, out_T, W_T, narrow, 1728) \
f(in_T, out_T, W_T, narrow, 1792) \
f(in_T, out_T, W_T, narrow, 2048) \
f(in_T, out_T, W_T, narrow, 2560) \
f(in_T, out_T, W_T, narrow, 2752) \
f(in_T, out_T, W_T, narrow, 3072) \
f(in_T, out_T, W_T, narrow, 3456) \
f(in_T, out_T, W_T, narrow, 3584) \
f(in_T, out_T, W_T, narrow, 4096) \
f(in_T, out_T, W_T, narrow, 5120) \
f(in_T, out_T, W_T, narrow, 5504) \
f(in_T, out_T, W_T, narrow, 5632) \
f(in_T, out_T, W_T, narrow, 6144) \
f(in_T, out_T, W_T, narrow, 6912) \
f(in_T, out_T, W_T, narrow, 7168) \
f(in_T, out_T, W_T, narrow, 8192) \
f(in_T, out_T, W_T, narrow, 9216) \
f(in_T, out_T, W_T, narrow, 10240) \
f(in_T, out_T, W_T, narrow, 11008) \
f(in_T, out_T, W_T, narrow, 12288) \
f(in_T, out_T, W_T, narrow, 13824) \
f(in_T, out_T, W_T, narrow, 14336) \
f(in_T, out_T, W_T, narrow, 16384) \
f(in_T, out_T, W_T, narrow, 20480) \
f(in_T, out_T, W_T, narrow, 24576) \
f(in_T, out_T, W_T, narrow, 28672) \
f(in_T, out_T, W_T, narrow, 32000) \
f(in_T, out_T, W_T, narrow, 32256) \
f(in_T, out_T, W_T, narrow, 32512) \
f(in_T, out_T, W_T, narrow, 32768) \
f(in_T, out_T, W_T, narrow, 33024) \
f(in_T, out_T, W_T, narrow, 36864) \
f(in_T, out_T, W_T, narrow, 49152) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
// clang-format on
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <cooperative_groups.h>
#include <cuda/pipeline>
#include <cuda_runtime.h>
#include <iostream>
#include <stdio.h>
#include "vec_dtypes.cuh"
namespace cg = cooperative_groups;
// nthrs = (32, 4)
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
float scale) {
size_t batch_idx = blockIdx.y;
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
if (idx < 0) {
return;
}
auto block = cg::this_thread_block();
size_t j = blockIdx.x;
constexpr size_t num_pipeline_stages = 2;
constexpr size_t tile_size = tx * ty * vec_size;
__shared__ W_T W_shared[num_pipeline_stages * tile_size];
__shared__ in_T X_shared[num_pipeline_stages * tile_size];
__shared__ float y_warpwise[ty];
size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
auto pipe = cuda::make_pipeline();
// pipeline load W/X and compute WX;
pipe.producer_acquire();
cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
W + (idx * feat_out + j) * feat_in +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
X + (batch_idx * feat_in) +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
pipe.producer_commit();
size_t copy_idx, compute_idx;
float y = 0.f;
vec_t<in_T, vec_size> x_vec;
vec_t<W_T, vec_size> w_vec;
size_t tile_idx;
#pragma unroll
for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
++tile_idx) {
copy_idx = tile_idx % num_pipeline_stages;
// pipeline stage: async copy W fragment
pipe.producer_acquire();
if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size,
W + (idx * feat_out + j) * feat_in +
tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size,
X + (batch_idx * feat_in) + tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
}
pipe.producer_commit();
compute_idx = (tile_idx - 1) % num_pipeline_stages;
// pipeline stage: compute WX
pipe.consumer_wait();
block.sync();
x_vec.load(X_shared + X_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
w_vec.load(W_shared + W_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
}
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += __shfl_down_sync(0xffffffff, sum, offset);
}
y_warpwise[threadIdx.y] = sum;
block.sync();
#pragma unroll
for (size_t i = 0; i < ty; ++i) {
y += y_warpwise[i];
}
block.sync();
pipe.consumer_release();
}
compute_idx = (tile_idx - 1) % num_pipeline_stages;
// final pipeline stage
pipe.consumer_wait();
block.sync();
x_vec.load(X_shared + X_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
w_vec.load(W_shared + W_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
}
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += __shfl_down_sync(0xffffffff, sum, offset);
}
y_warpwise[threadIdx.y] =
((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
? sum
: 0.f;
block.sync();
#pragma unroll
for (size_t i = 0; i < ty; ++i) {
y += y_warpwise[i];
}
block.sync();
pipe.consumer_release();
// write Y;
if (block.thread_rank() == 0) {
Y[batch_idx * full_y_size + y_offset + j] += static_cast<out_T>(y);
}
}
// nthrs = (2, 16, 4)
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
typename in_T, typename out_T, typename W_T>
__global__ void
bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
float scale) {
size_t batch_idx = blockIdx.y;
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
if (idx < 0) {
return;
}
auto block = cg::this_thread_block();
size_t tile_idx = blockIdx.x;
// load X;
vec_t<in_T, vec_size> x_vec;
x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);
// load W;
vec_t<W_T, vec_size> w_vec;
w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in +
block.thread_rank() * vec_size);
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
}
cg::thread_block_tile g = cg::tiled_partition<tx>(block);
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += g.shfl_down(sum, offset);
}
sum = g.shfl(sum, 0);
if (threadIdx.x == 0) {
Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
}
}
template <int feat_in, int feat_out, typename in_T, typename out_T,
typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale) {
constexpr size_t vec_size = 8;
constexpr int tz = 4;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if constexpr (feat_in < feat_out) {
static_assert(feat_in % vec_size == 0);
constexpr int tx = feat_in / vec_size;
static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
(16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
(8 % tx == 0 && feat_out % (8 / tx * tz) == 0));
if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
constexpr int ty = 32 / tx;
dim3 nblks(feat_out / (ty * tz), batch_size);
dim3 nthrs(tx, ty, tz);
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
constexpr int ty = 16 / tx;
dim3 nblks(feat_out / (ty * tz), batch_size);
dim3 nthrs(tx, ty, tz);
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else {
constexpr int ty = 8 / tx;
dim3 nblks(feat_out / (ty * tz), batch_size);
dim3 nthrs(tx, ty, tz);
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
}
} else {
static_assert(feat_in % (vec_size * 32) == 0 ||
feat_in % (vec_size * 16) == 0 ||
feat_in % (vec_size * 8) == 0);
if constexpr (feat_in % (vec_size * 32) == 0) {
constexpr int tx = 32;
constexpr int ty = 4;
dim3 nblks(feat_out, batch_size);
dim3 nthrs(tx, ty);
bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(in_T),
vec_size * sizeof(W_T), tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
constexpr int tx = 32;
constexpr int ty = 4;
dim3 nblks(feat_out, batch_size);
dim3 nthrs(tx, ty);
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
vec_size * sizeof(in_T) / 2,
vec_size * sizeof(W_T) / 2, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
constexpr int tx = 16;
constexpr int ty = 4;
dim3 nblks(feat_out, batch_size);
dim3 nthrs(tx, ty);
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
vec_size * sizeof(in_T) / 2,
vec_size * sizeof(W_T) / 2, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
}
}
}
#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \
template void bgmv_kernel<feat_in, feat_out>( \
out_T * __restrict__ Y, const in_T *__restrict__ X, \
const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
int64_t num_layers, int64_t layer_idx, float scale);
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
INST_BGMV(narrow, wide, in_T, out_T, W_T) \
INST_BGMV(wide, narrow, in_T, out_T, W_T)
DTYPES = ["fp16", "bf16", "fp32"]
DTYPE_MAP = {
"fp16": "nv_half",
"bf16": "nv_bfloat16",
"fp32": "float",
}
TEMPLATE = """
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip()
for input_dtype in DTYPES:
for output_dtype in DTYPES:
for weight_dtype in DTYPES:
if weight_dtype == "fp32":
# FP32 weights are not supported.
continue
kernel_definition = TEMPLATE.format(
input_dtype=DTYPE_MAP[input_dtype],
output_dtype=DTYPE_MAP[output_dtype],
weight_dtype=DTYPE_MAP[weight_dtype])
filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu"
with open(filename, "w") as f:
f.write(kernel_definition)
This diff is collapsed.
This diff is collapsed.
...@@ -22,6 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -22,6 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"silu_and_mul", "silu_and_mul",
&silu_and_mul, &silu_and_mul,
"Activation function used in SwiGLU."); "Activation function used in SwiGLU.");
ops.def(
"gelu_and_mul",
&gelu_and_mul,
"Activation function used in GeGLU.");
ops.def( ops.def(
"gelu_new", "gelu_new",
&gelu_new, &gelu_new,
...@@ -48,13 +52,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -48,13 +52,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&rotary_embedding, &rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
// Quantization ops
#ifndef USE_ROCM #ifndef USE_ROCM
// Quantization ops
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
#endif #endif
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
ops.def(
"moe_align_block_size",
&moe_align_block_size,
"Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
// Cache ops // Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
...@@ -71,9 +82,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -71,9 +82,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&reshape_and_cache, &reshape_and_cache,
"Reshape the key and value tensors and cache them"); "Reshape the key and value tensors and cache them");
cache_ops.def( cache_ops.def(
"gather_cached_kv", "convert_fp8_e5m2",
&gather_cached_kv, &convert_fp8_e5m2,
"Gather key and value from the cache into contiguous QKV tensors"); "Convert the key and value cache to fp8_e5m2 data type");
// Cuda utils // Cuda utils
pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils"); pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
...@@ -81,4 +92,26 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -81,4 +92,26 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"get_device_attribute", "get_device_attribute",
&get_device_attribute, &get_device_attribute,
"Gets the specified device attribute."); "Gets the specified device attribute.");
cuda_utils.def(
"get_max_shared_memory_per_block_device_attribute",
&get_max_shared_memory_per_block_device_attribute,
"Gets the maximum shared memory per block device attribute.");
#ifndef USE_ROCM
// Custom all-reduce kernels
pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar");
custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg");
custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg");
custom_ar.def("dispose", &dispose, "dispose");
custom_ar.def("meta_size", &meta_size, "meta_size");
custom_ar.def("register_buffer", &register_buffer, "register_buffer");
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta,
"get_graph_buffer_ipc_meta");
custom_ar.def("register_graph_buffers", &register_graph_buffers,
"register_graph_buffers");
#endif
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment