Unverified Commit 5f6d10c1 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[CI/Build] Enforce style for C++ and CUDA code with `clang-format` (#4722)

parent 9b9a10d6
......@@ -22,27 +22,23 @@ __device__ inline unsigned int as_unsigned(int i) {
// 4-bit matvec kernel (LUT-based)
__global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM
const half2* __restrict__ vec,
const half2* __restrict__ vec,
#else
const __half2* __restrict__ vec,
const __half2* __restrict__ vec,
#endif
const int* __restrict__ mat,
const int* __restrict__ mat,
#ifndef USE_ROCM
half2* __restrict__ mul,
half2* __restrict__ mul,
#else
float2* __restrict__ mul,
float2* __restrict__ mul,
#endif
const __half* __restrict__ lookup_table,
int height,
int width,
int batch,
int vec_height
) {
const __half* __restrict__ lookup_table, int height, int width, int batch,
int vec_height) {
const int blockwidth2 = BLOCKWIDTH / 2;
int row = BLOCKHEIGHT4 * blockIdx.x;
int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
#ifndef USE_ROCM
__shared__ half2 blockvec[blockwidth2];
......@@ -73,14 +69,16 @@ __global__ void NUQ4MatMulKernel(
unsigned int tmp1;
unsigned int lut_index1, lut_index2;
for (int b = 0; b < batch; ++b){
for (int b = 0; b < batch; ++b) {
i = width * row + col;
res = __int2half_rd(0);
k = 0;
__syncthreads();
if (threadIdx.x < blockwidth2)
blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x];
blockvec[threadIdx.x] =
vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 +
threadIdx.x];
__syncthreads();
while (k < blockwidth2) {
......@@ -143,7 +141,8 @@ __global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM
res = __hadd(__hadd(res2.x, res2.y), res);
#else
res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res);
res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)),
res);
#endif
i += width;
......@@ -179,46 +178,38 @@ __global__ void NUQ4MatMulKernel(
}
}
} // namespace squeezellm
} // namespace vllm
} // namespace squeezellm
} // namespace vllm
// 4-bit matvec kernel (LUT-based)
void squeezellm_gemm(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor lookup_table
) {
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor lookup_table) {
int height = mat.size(0);
int width = mat.size(1);
int batch = vec.size(0);
int vec_height = vec.size(1);
dim3 blocks(
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
);
dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH);
dim3 threads(BLOCKWIDTH);
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
#ifndef USE_ROCM
(half2*) vec.data<at::Half>(),
(half2*)vec.data<at::Half>(),
#else
(__half2*) vec.data_ptr<at::Half>(),
(__half2*)vec.data_ptr<at::Half>(),
#endif
mat.data_ptr<int>(),
mat.data_ptr<int>(),
#ifndef USE_ROCM
(half2*) mul.data<at::Half>(),
(__half*) lookup_table.data<at::Half>(),
(half2*)mul.data<at::Half>(), (__half*)lookup_table.data<at::Half>(),
#else
(float2*) mul.data_ptr<float>(),
(__half*) lookup_table.data_ptr<at::Half>(),
(float2*)mul.data_ptr<float>(),
(__half*)lookup_table.data_ptr<at::Half>(),
#endif
height, width, batch, vec_height
);
height, width, batch, vec_height);
}
#undef BLOCKWIDTH
......
/*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
......@@ -20,12 +21,12 @@
#include "cuda_compat.h"
namespace vllm {
template<typename T, int numLanes = WARP_SIZE>
template <typename T, int numLanes = WARP_SIZE>
__inline__ __device__ T warpReduceSum(T val) {
static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
"numLanes is not a positive power of 2!");
static_assert(numLanes <= WARP_SIZE);
#pragma unroll
#pragma unroll
for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
val += VLLM_SHFL_XOR_SYNC(val, mask);
return val;
......@@ -38,22 +39,23 @@ static constexpr int _nextPow2(unsigned int num) {
}
/* Calculate the sum of all elements in a block */
template<typename T, int maxBlockSize = 1024>
template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceSum(T val) {
static_assert(maxBlockSize <= 1024);
if constexpr (maxBlockSize > WARP_SIZE) {
val = warpReduceSum<T>(val);
// Calculates max number of lanes that need to participate in the last warpReduce
// Calculates max number of lanes that need to participate in the last
// warpReduce
constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
static __shared__ T shared[maxActiveLanes];
int lane = threadIdx.x % WARP_SIZE;
int wid = threadIdx.x / WARP_SIZE;
if (lane == 0)
shared[wid] = val;
if (lane == 0) shared[wid] = val;
__syncthreads();
val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] : (T)(0.0f);
val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane]
: (T)(0.0f);
val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val);
} else {
// A single warpReduce is equal to blockReduce
......@@ -62,4 +64,4 @@ __inline__ __device__ T blockReduceSum(T val) {
return val;
}
} // namespace vllm
} // namespace vllm
......@@ -26,6 +26,7 @@ RUFF_VERSION=$(ruff --version | awk '{print $2}')
MYPY_VERSION=$(mypy --version | awk '{print $2}')
CODESPELL_VERSION=$(codespell --version)
ISORT_VERSION=$(isort --vn)
CLANGFORMAT_VERSION=$(clang-format --version | awk '{print $3}')
# # params: tool name, tool version, required version
tool_version_check() {
......@@ -40,6 +41,7 @@ tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt |
tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "isort" "$ISORT_VERSION" "$(grep isort requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "clang-format" "$CLANGFORMAT_VERSION" "$(grep clang-format requirements-dev.txt | cut -d'=' -f3)"
YAPF_FLAGS=(
'--recursive'
......@@ -179,7 +181,6 @@ lint_changed() {
}
# Run Ruff
echo 'vLLM ruff:'
### This flag lints individual files. --files *must* be the first command line
### arg to use this option.
if [[ "$1" == '--files' ]]; then
......@@ -192,6 +193,7 @@ else
# Format only the files that changed in last commit.
lint_changed
fi
echo 'vLLM ruff: Done'
# check spelling of specified files
isort_check() {
......@@ -233,6 +235,59 @@ else
fi
echo 'vLLM isort: Done'
# Clang-format section
# Exclude some files for formatting because they are vendored
# NOTE: Keep up to date with .github/workflows/clang-format.yml
CLANG_FORMAT_EXCLUDES=(
'csrc/moe/topk_softmax_kernels.cu'
'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu'
'csrc/punica/bgmv/bgmv_config.h'
'csrc/punica/bgmv/bgmv_impl.cuh'
'csrc/punica/bgmv/vec_dtypes.cuh'
'csrc/punica/punica_ops.cu'
'csrc/punica/type_convert.h'
)
# Format specified files with clang-format
clang_format() {
clang-format -i "$@"
}
# Format files that differ from main branch with clang-format.
clang_format_changed() {
# The `if` guard ensures that the list of filenames is not empty, which
# could cause clang-format to receive 0 positional arguments, making it hang
# waiting for STDIN.
#
# `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
# exist on both branches.
MERGEBASE="$(git merge-base origin/main HEAD)"
# Get the list of changed files, excluding the specified ones
changed_files=$(git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.h' '*.cpp' '*.cu' '*.cuh' | grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}"))
if [ -n "$changed_files" ]; then
echo "$changed_files" | xargs -P 5 clang-format -i
fi
}
# Format all files with clang-format
clang_format_all() {
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
| grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}") \
| xargs clang-format -i
}
# Run clang-format
if [[ "$1" == '--files' ]]; then
clang_format "${@:2}"
elif [[ "$1" == '--all' ]]; then
clang_format_all
else
clang_format_changed
fi
echo 'vLLM clang-format: Done'
if ! git diff --quiet &>/dev/null; then
echo 'Reformatted files. Please review and stage the changes.'
echo 'Changes not staged for commit:'
......
......@@ -5,6 +5,7 @@ tomli==2.0.1
ruff==0.1.5
codespell==2.2.6
isort==5.13.2
clang-format==18.1.5
# type checking
mypy==1.9.0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment