Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5f6d10c1
Unverified
Commit
5f6d10c1
authored
May 22, 2024
by
Michael Goin
Committed by
GitHub
May 22, 2024
Browse files
[CI/Build] Enforce style for C++ and CUDA code with `clang-format` (#4722)
parent
9b9a10d6
Changes
64
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
95 additions
and
46 deletions
+95
-46
csrc/quantization/squeezellm/quant_cuda_kernel.cu
csrc/quantization/squeezellm/quant_cuda_kernel.cu
+27
-36
csrc/reduction_utils.cuh
csrc/reduction_utils.cuh
+11
-9
format.sh
format.sh
+56
-1
requirements-dev.txt
requirements-dev.txt
+1
-0
No files found.
csrc/quantization/squeezellm/quant_cuda_kernel.cu
View file @
5f6d10c1
...
@@ -22,27 +22,23 @@ __device__ inline unsigned int as_unsigned(int i) {
...
@@ -22,27 +22,23 @@ __device__ inline unsigned int as_unsigned(int i) {
// 4-bit matvec kernel (LUT-based)
// 4-bit matvec kernel (LUT-based)
__global__
void
NUQ4MatMulKernel
(
__global__
void
NUQ4MatMulKernel
(
#ifndef USE_ROCM
#ifndef USE_ROCM
const
half2
*
__restrict__
vec
,
const
half2
*
__restrict__
vec
,
#else
#else
const
__half2
*
__restrict__
vec
,
const
__half2
*
__restrict__
vec
,
#endif
#endif
const
int
*
__restrict__
mat
,
const
int
*
__restrict__
mat
,
#ifndef USE_ROCM
#ifndef USE_ROCM
half2
*
__restrict__
mul
,
half2
*
__restrict__
mul
,
#else
#else
float2
*
__restrict__
mul
,
float2
*
__restrict__
mul
,
#endif
#endif
const
__half
*
__restrict__
lookup_table
,
const
__half
*
__restrict__
lookup_table
,
int
height
,
int
width
,
int
batch
,
int
height
,
int
vec_height
)
{
int
width
,
int
batch
,
int
vec_height
)
{
const
int
blockwidth2
=
BLOCKWIDTH
/
2
;
const
int
blockwidth2
=
BLOCKWIDTH
/
2
;
int
row
=
BLOCKHEIGHT4
*
blockIdx
.
x
;
int
row
=
BLOCKHEIGHT4
*
blockIdx
.
x
;
int
col
=
BLOCKWIDTH
*
blockIdx
.
y
+
threadIdx
.
x
;
int
col
=
BLOCKWIDTH
*
blockIdx
.
y
+
threadIdx
.
x
;
#ifndef USE_ROCM
#ifndef USE_ROCM
__shared__
half2
blockvec
[
blockwidth2
];
__shared__
half2
blockvec
[
blockwidth2
];
...
@@ -73,14 +69,16 @@ __global__ void NUQ4MatMulKernel(
...
@@ -73,14 +69,16 @@ __global__ void NUQ4MatMulKernel(
unsigned
int
tmp1
;
unsigned
int
tmp1
;
unsigned
int
lut_index1
,
lut_index2
;
unsigned
int
lut_index1
,
lut_index2
;
for
(
int
b
=
0
;
b
<
batch
;
++
b
){
for
(
int
b
=
0
;
b
<
batch
;
++
b
)
{
i
=
width
*
row
+
col
;
i
=
width
*
row
+
col
;
res
=
__int2half_rd
(
0
);
res
=
__int2half_rd
(
0
);
k
=
0
;
k
=
0
;
__syncthreads
();
__syncthreads
();
if
(
threadIdx
.
x
<
blockwidth2
)
if
(
threadIdx
.
x
<
blockwidth2
)
blockvec
[
threadIdx
.
x
]
=
vec
[
b
*
vec_height
/
2
+
(
row
/
BLOCKHEIGHT4
)
*
blockwidth2
+
threadIdx
.
x
];
blockvec
[
threadIdx
.
x
]
=
vec
[
b
*
vec_height
/
2
+
(
row
/
BLOCKHEIGHT4
)
*
blockwidth2
+
threadIdx
.
x
];
__syncthreads
();
__syncthreads
();
while
(
k
<
blockwidth2
)
{
while
(
k
<
blockwidth2
)
{
...
@@ -143,7 +141,8 @@ __global__ void NUQ4MatMulKernel(
...
@@ -143,7 +141,8 @@ __global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM
#ifndef USE_ROCM
res
=
__hadd
(
__hadd
(
res2
.
x
,
res2
.
y
),
res
);
res
=
__hadd
(
__hadd
(
res2
.
x
,
res2
.
y
),
res
);
#else
#else
res
=
__hadd
(
__hadd
(
__ushort_as_half
(
res2
.
x
),
__ushort_as_half
(
res2
.
y
)),
res
);
res
=
__hadd
(
__hadd
(
__ushort_as_half
(
res2
.
x
),
__ushort_as_half
(
res2
.
y
)),
res
);
#endif
#endif
i
+=
width
;
i
+=
width
;
...
@@ -179,46 +178,38 @@ __global__ void NUQ4MatMulKernel(
...
@@ -179,46 +178,38 @@ __global__ void NUQ4MatMulKernel(
}
}
}
}
}
// namespace squeezellm
}
// namespace squeezellm
}
// namespace vllm
}
// namespace vllm
// 4-bit matvec kernel (LUT-based)
// 4-bit matvec kernel (LUT-based)
void
squeezellm_gemm
(
void
squeezellm_gemm
(
torch
::
Tensor
vec
,
torch
::
Tensor
mat
,
torch
::
Tensor
mul
,
torch
::
Tensor
vec
,
torch
::
Tensor
lookup_table
)
{
torch
::
Tensor
mat
,
torch
::
Tensor
mul
,
torch
::
Tensor
lookup_table
)
{
int
height
=
mat
.
size
(
0
);
int
height
=
mat
.
size
(
0
);
int
width
=
mat
.
size
(
1
);
int
width
=
mat
.
size
(
1
);
int
batch
=
vec
.
size
(
0
);
int
batch
=
vec
.
size
(
0
);
int
vec_height
=
vec
.
size
(
1
);
int
vec_height
=
vec
.
size
(
1
);
dim3
blocks
(
dim3
blocks
((
height
+
BLOCKHEIGHT4
-
1
)
/
BLOCKHEIGHT4
,
(
height
+
BLOCKHEIGHT4
-
1
)
/
BLOCKHEIGHT4
,
(
width
+
BLOCKWIDTH
-
1
)
/
BLOCKWIDTH
);
(
width
+
BLOCKWIDTH
-
1
)
/
BLOCKWIDTH
);
dim3
threads
(
BLOCKWIDTH
);
dim3
threads
(
BLOCKWIDTH
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
vec
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
vec
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
vllm
::
squeezellm
::
NUQ4MatMulKernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
vllm
::
squeezellm
::
NUQ4MatMulKernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
#ifndef USE_ROCM
#ifndef USE_ROCM
(
half2
*
)
vec
.
data
<
at
::
Half
>
(),
(
half2
*
)
vec
.
data
<
at
::
Half
>
(),
#else
#else
(
__half2
*
)
vec
.
data_ptr
<
at
::
Half
>
(),
(
__half2
*
)
vec
.
data_ptr
<
at
::
Half
>
(),
#endif
#endif
mat
.
data_ptr
<
int
>
(),
mat
.
data_ptr
<
int
>
(),
#ifndef USE_ROCM
#ifndef USE_ROCM
(
half2
*
)
mul
.
data
<
at
::
Half
>
(),
(
half2
*
)
mul
.
data
<
at
::
Half
>
(),
(
__half
*
)
lookup_table
.
data
<
at
::
Half
>
(),
(
__half
*
)
lookup_table
.
data
<
at
::
Half
>
(),
#else
#else
(
float2
*
)
mul
.
data_ptr
<
float
>
(),
(
float2
*
)
mul
.
data_ptr
<
float
>
(),
(
__half
*
)
lookup_table
.
data_ptr
<
at
::
Half
>
(),
(
__half
*
)
lookup_table
.
data_ptr
<
at
::
Half
>
(),
#endif
#endif
height
,
width
,
batch
,
vec_height
height
,
width
,
batch
,
vec_height
);
);
}
}
#undef BLOCKWIDTH
#undef BLOCKWIDTH
...
...
csrc/reduction_utils.cuh
View file @
5f6d10c1
/*
/*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
*
...
@@ -20,12 +21,12 @@
...
@@ -20,12 +21,12 @@
#include "cuda_compat.h"
#include "cuda_compat.h"
namespace
vllm
{
namespace
vllm
{
template
<
typename
T
,
int
numLanes
=
WARP_SIZE
>
template
<
typename
T
,
int
numLanes
=
WARP_SIZE
>
__inline__
__device__
T
warpReduceSum
(
T
val
)
{
__inline__
__device__
T
warpReduceSum
(
T
val
)
{
static_assert
(
numLanes
>
0
&&
(
numLanes
&
(
numLanes
-
1
))
==
0
,
static_assert
(
numLanes
>
0
&&
(
numLanes
&
(
numLanes
-
1
))
==
0
,
"numLanes is not a positive power of 2!"
);
"numLanes is not a positive power of 2!"
);
static_assert
(
numLanes
<=
WARP_SIZE
);
static_assert
(
numLanes
<=
WARP_SIZE
);
#pragma unroll
#pragma unroll
for
(
int
mask
=
numLanes
>>
1
;
mask
>
0
;
mask
>>=
1
)
for
(
int
mask
=
numLanes
>>
1
;
mask
>
0
;
mask
>>=
1
)
val
+=
VLLM_SHFL_XOR_SYNC
(
val
,
mask
);
val
+=
VLLM_SHFL_XOR_SYNC
(
val
,
mask
);
return
val
;
return
val
;
...
@@ -38,22 +39,23 @@ static constexpr int _nextPow2(unsigned int num) {
...
@@ -38,22 +39,23 @@ static constexpr int _nextPow2(unsigned int num) {
}
}
/* Calculate the sum of all elements in a block */
/* Calculate the sum of all elements in a block */
template
<
typename
T
,
int
maxBlockSize
=
1024
>
template
<
typename
T
,
int
maxBlockSize
=
1024
>
__inline__
__device__
T
blockReduceSum
(
T
val
)
{
__inline__
__device__
T
blockReduceSum
(
T
val
)
{
static_assert
(
maxBlockSize
<=
1024
);
static_assert
(
maxBlockSize
<=
1024
);
if
constexpr
(
maxBlockSize
>
WARP_SIZE
)
{
if
constexpr
(
maxBlockSize
>
WARP_SIZE
)
{
val
=
warpReduceSum
<
T
>
(
val
);
val
=
warpReduceSum
<
T
>
(
val
);
// Calculates max number of lanes that need to participate in the last warpReduce
// Calculates max number of lanes that need to participate in the last
// warpReduce
constexpr
int
maxActiveLanes
=
(
maxBlockSize
+
WARP_SIZE
-
1
)
/
WARP_SIZE
;
constexpr
int
maxActiveLanes
=
(
maxBlockSize
+
WARP_SIZE
-
1
)
/
WARP_SIZE
;
static
__shared__
T
shared
[
maxActiveLanes
];
static
__shared__
T
shared
[
maxActiveLanes
];
int
lane
=
threadIdx
.
x
%
WARP_SIZE
;
int
lane
=
threadIdx
.
x
%
WARP_SIZE
;
int
wid
=
threadIdx
.
x
/
WARP_SIZE
;
int
wid
=
threadIdx
.
x
/
WARP_SIZE
;
if
(
lane
==
0
)
if
(
lane
==
0
)
shared
[
wid
]
=
val
;
shared
[
wid
]
=
val
;
__syncthreads
();
__syncthreads
();
val
=
(
threadIdx
.
x
<
blockDim
.
x
/
float
(
WARP_SIZE
))
?
shared
[
lane
]
:
(
T
)(
0.0
f
);
val
=
(
threadIdx
.
x
<
blockDim
.
x
/
float
(
WARP_SIZE
))
?
shared
[
lane
]
:
(
T
)(
0.0
f
);
val
=
warpReduceSum
<
T
,
_nextPow2
(
maxActiveLanes
)
>
(
val
);
val
=
warpReduceSum
<
T
,
_nextPow2
(
maxActiveLanes
)
>
(
val
);
}
else
{
}
else
{
// A single warpReduce is equal to blockReduce
// A single warpReduce is equal to blockReduce
...
@@ -62,4 +64,4 @@ __inline__ __device__ T blockReduceSum(T val) {
...
@@ -62,4 +64,4 @@ __inline__ __device__ T blockReduceSum(T val) {
return
val
;
return
val
;
}
}
}
// namespace vllm
}
// namespace vllm
format.sh
View file @
5f6d10c1
...
@@ -26,6 +26,7 @@ RUFF_VERSION=$(ruff --version | awk '{print $2}')
...
@@ -26,6 +26,7 @@ RUFF_VERSION=$(ruff --version | awk '{print $2}')
MYPY_VERSION
=
$(
mypy
--version
|
awk
'{print $2}'
)
MYPY_VERSION
=
$(
mypy
--version
|
awk
'{print $2}'
)
CODESPELL_VERSION
=
$(
codespell
--version
)
CODESPELL_VERSION
=
$(
codespell
--version
)
ISORT_VERSION
=
$(
isort
--vn
)
ISORT_VERSION
=
$(
isort
--vn
)
CLANGFORMAT_VERSION
=
$(
clang-format
--version
|
awk
'{print $3}'
)
# # params: tool name, tool version, required version
# # params: tool name, tool version, required version
tool_version_check
()
{
tool_version_check
()
{
...
@@ -40,6 +41,7 @@ tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt |
...
@@ -40,6 +41,7 @@ tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt |
tool_version_check
"mypy"
"
$MYPY_VERSION
"
"
$(
grep
mypy requirements-dev.txt |
cut
-d
'='
-f3
)
"
tool_version_check
"mypy"
"
$MYPY_VERSION
"
"
$(
grep
mypy requirements-dev.txt |
cut
-d
'='
-f3
)
"
tool_version_check
"isort"
"
$ISORT_VERSION
"
"
$(
grep
isort requirements-dev.txt |
cut
-d
'='
-f3
)
"
tool_version_check
"isort"
"
$ISORT_VERSION
"
"
$(
grep
isort requirements-dev.txt |
cut
-d
'='
-f3
)
"
tool_version_check
"codespell"
"
$CODESPELL_VERSION
"
"
$(
grep
codespell requirements-dev.txt |
cut
-d
'='
-f3
)
"
tool_version_check
"codespell"
"
$CODESPELL_VERSION
"
"
$(
grep
codespell requirements-dev.txt |
cut
-d
'='
-f3
)
"
tool_version_check
"clang-format"
"
$CLANGFORMAT_VERSION
"
"
$(
grep
clang-format requirements-dev.txt |
cut
-d
'='
-f3
)
"
YAPF_FLAGS
=(
YAPF_FLAGS
=(
'--recursive'
'--recursive'
...
@@ -179,7 +181,6 @@ lint_changed() {
...
@@ -179,7 +181,6 @@ lint_changed() {
}
}
# Run Ruff
# Run Ruff
echo
'vLLM ruff:'
### This flag lints individual files. --files *must* be the first command line
### This flag lints individual files. --files *must* be the first command line
### arg to use this option.
### arg to use this option.
if
[[
"
$1
"
==
'--files'
]]
;
then
if
[[
"
$1
"
==
'--files'
]]
;
then
...
@@ -192,6 +193,7 @@ else
...
@@ -192,6 +193,7 @@ else
# Format only the files that changed in last commit.
# Format only the files that changed in last commit.
lint_changed
lint_changed
fi
fi
echo
'vLLM ruff: Done'
# check spelling of specified files
# check spelling of specified files
isort_check
()
{
isort_check
()
{
...
@@ -233,6 +235,59 @@ else
...
@@ -233,6 +235,59 @@ else
fi
fi
echo
'vLLM isort: Done'
echo
'vLLM isort: Done'
# Clang-format section
# Exclude some files for formatting because they are vendored
# NOTE: Keep up to date with .github/workflows/clang-format.yml
CLANG_FORMAT_EXCLUDES
=(
'csrc/moe/topk_softmax_kernels.cu'
'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu'
'csrc/punica/bgmv/bgmv_config.h'
'csrc/punica/bgmv/bgmv_impl.cuh'
'csrc/punica/bgmv/vec_dtypes.cuh'
'csrc/punica/punica_ops.cu'
'csrc/punica/type_convert.h'
)
# Format specified files with clang-format
clang_format
()
{
clang-format
-i
"
$@
"
}
# Format files that differ from main branch with clang-format.
clang_format_changed
()
{
# The `if` guard ensures that the list of filenames is not empty, which
# could cause clang-format to receive 0 positional arguments, making it hang
# waiting for STDIN.
#
# `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
# exist on both branches.
MERGEBASE
=
"
$(
git merge-base origin/main HEAD
)
"
# Get the list of changed files, excluding the specified ones
changed_files
=
$(
git diff
--name-only
--diff-filter
=
ACM
"
$MERGEBASE
"
--
'*.h'
'*.cpp'
'*.cu'
'*.cuh'
|
grep
-vFf
<
(
printf
"%s
\n
"
"
${
CLANG_FORMAT_EXCLUDES
[@]
}
"
)
)
if
[
-n
"
$changed_files
"
]
;
then
echo
"
$changed_files
"
| xargs
-P
5 clang-format
-i
fi
}
# Format all files with clang-format
clang_format_all
()
{
find csrc/
\(
-name
'*.h'
-o
-name
'*.cpp'
-o
-name
'*.cu'
-o
-name
'*.cuh'
\)
-print
\
|
grep
-vFf
<
(
printf
"%s
\n
"
"
${
CLANG_FORMAT_EXCLUDES
[@]
}
"
)
\
| xargs clang-format
-i
}
# Run clang-format
if
[[
"
$1
"
==
'--files'
]]
;
then
clang_format
"
${
@
:2
}
"
elif
[[
"
$1
"
==
'--all'
]]
;
then
clang_format_all
else
clang_format_changed
fi
echo
'vLLM clang-format: Done'
if
!
git diff
--quiet
&>/dev/null
;
then
if
!
git diff
--quiet
&>/dev/null
;
then
echo
'Reformatted files. Please review and stage the changes.'
echo
'Reformatted files. Please review and stage the changes.'
echo
'Changes not staged for commit:'
echo
'Changes not staged for commit:'
...
...
requirements-dev.txt
View file @
5f6d10c1
...
@@ -5,6 +5,7 @@ tomli==2.0.1
...
@@ -5,6 +5,7 @@ tomli==2.0.1
ruff==0.1.5
ruff==0.1.5
codespell==2.2.6
codespell==2.2.6
isort==5.13.2
isort==5.13.2
clang-format==18.1.5
# type checking
# type checking
mypy==1.9.0
mypy==1.9.0
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment