Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7c25fe45
Unverified
Commit
7c25fe45
authored
Nov 23, 2024
by
kliuae
Committed by
GitHub
Nov 22, 2024
Browse files
[AMD] Add support for GGUF quantization on ROCm (#10254)
parent
02a43f82
Changes
11
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
234 additions
and
211 deletions
+234
-211
.buildkite/run-amd-test.sh
.buildkite/run-amd-test.sh
+0
-1
CMakeLists.txt
CMakeLists.txt
+1
-1
csrc/ops.h
csrc/ops.h
+2
-0
csrc/quantization/gguf/ggml-common.h
csrc/quantization/gguf/ggml-common.h
+16
-1
csrc/quantization/gguf/gguf_kernel.cu
csrc/quantization/gguf/gguf_kernel.cu
+4
-2
csrc/quantization/gguf/mmq.cuh
csrc/quantization/gguf/mmq.cuh
+35
-35
csrc/quantization/gguf/mmvq.cuh
csrc/quantization/gguf/mmvq.cuh
+2
-2
csrc/quantization/gguf/vecdotq.cuh
csrc/quantization/gguf/vecdotq.cuh
+143
-143
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+2
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+28
-25
vllm/config.py
vllm/config.py
+1
-1
No files found.
.buildkite/run-amd-test.sh
View file @
7c25fe45
...
...
@@ -85,7 +85,6 @@ if [[ $commands == *" kernels "* ]]; then
--ignore=kernels/test_encoder_decoder_attn.py
\
--ignore=kernels/test_flash_attn.py
\
--ignore=kernels/test_flashinfer.py
\
--ignore=kernels/test_gguf.py
\
--ignore=kernels/test_int8_quant.py
\
--ignore=kernels/test_machete_gemm.py
\
--ignore=kernels/test_mamba_ssm.py
\
...
...
CMakeLists.txt
View file @
7c25fe45
...
...
@@ -196,6 +196,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/torch_bindings.cpp"
)
...
...
@@ -237,7 +238,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
"csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/custom_all_reduce.cu"
"csrc/permute_cols.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
)
...
...
csrc/ops.h
View file @
7c25fe45
...
...
@@ -128,6 +128,7 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
int64_t
thx
,
int64_t
thy
);
torch
::
Tensor
permute_cols
(
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
perm
);
#endif
torch
::
Tensor
ggml_dequantize
(
torch
::
Tensor
W
,
int64_t
type
,
int64_t
m
,
int64_t
n
);
...
...
@@ -138,6 +139,7 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
torch
::
Tensor
ggml_mul_mat_a8
(
torch
::
Tensor
W
,
torch
::
Tensor
X
,
int64_t
type
,
int64_t
row
);
#ifndef USE_ROCM
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
);
void
cutlass_scaled_mm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
...
...
csrc/quantization/gguf/ggml-common.h
View file @
7c25fe45
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h
#define QK_K 256
#define K_QUANTS_PER_ITERATION 2
#define WARP_SIZE 32
#define WARP_SIZE
_GGUF
32
#define K_SCALE_SIZE 12
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
#define CUDA_QUANTIZE_BLOCK_SIZE 256
...
...
@@ -1112,4 +1112,19 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
#endif
return
c
;
}
// Per-byte equality compare of two packed words holding four 8-bit lanes each.
// For every byte lane in which `a` and `b` hold the same value the matching
// byte of the result is 0xff; for every lane in which they differ it is 0x00.
static __device__ __forceinline__ uint32_t __vcmpeq4(const uint32_t a, const uint32_t b) {
    // A byte lane of `diff` is zero exactly when the two input bytes are equal.
    const uint32_t diff = a ^ b;

    const uint32_t lane3 = (diff & 0xff000000u) ? 0u : 0xff000000u;
    const uint32_t lane2 = (diff & 0x00ff0000u) ? 0u : 0x00ff0000u;
    const uint32_t lane1 = (diff & 0x0000ff00u) ? 0u : 0x0000ff00u;
    const uint32_t lane0 = (diff & 0x000000ffu) ? 0u : 0x000000ffu;

    return lane3 | lane2 | lane1 | lane0;
}
// Per-byte subtraction of two packed words holding four 8-bit lanes each.
// Each byte of the result is (byte of a) - (byte of b) reduced modulo 256;
// a borrow in one lane never propagates into the neighbouring lane.
//
// All arithmetic is done in uint32_t. Narrowing a lane to uint8_t and then
// shifting it would first promote it to signed int, so a lane value >= 0x80
// shifted by 24 would land in the sign bit; staying unsigned avoids relying on
// signed-shift / signed-to-unsigned conversion behaviour.
static __device__ __forceinline__ uint32_t __vsub4(const uint32_t a, const uint32_t b) {
    // Isolate each lane, subtract with wrap-around, and keep only the low 8 bits.
    const uint32_t lane3 = (((a >> 24) & 0xffu) - ((b >> 24) & 0xffu)) & 0xffu;
    const uint32_t lane2 = (((a >> 16) & 0xffu) - ((b >> 16) & 0xffu)) & 0xffu;
    const uint32_t lane1 = (((a >>  8) & 0xffu) - ((b >>  8) & 0xffu)) & 0xffu;
    const uint32_t lane0 = (( a        & 0xffu) - ( b        & 0xffu)) & 0xffu;

    // Lanes occupy disjoint bit ranges, so OR-ing them reassembles the word.
    return (lane3 << 24) | (lane2 << 16) | (lane1 << 8) | lane0;
}
#endif // defined(USE_ROCM)
csrc/quantization/gguf/gguf_kernel.cu
View file @
7c25fe45
...
...
@@ -4,6 +4,8 @@
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "ggml-common.h"
#include "vecdotq.cuh"
#include "dequantize.cuh"
...
...
@@ -32,8 +34,8 @@ static __global__ void quantize_q8_1(const half* __restrict__ x,
#pragma unroll
for
(
int
mask
=
16
;
mask
>
0
;
mask
>>=
1
)
{
amax
=
fmaxf
(
amax
,
__shfl_xor_sync
(
0xffffffff
,
amax
,
mask
,
32
));
sum
+=
__shfl_xor_sync
(
0xffffffff
,
sum
,
mask
,
32
);
amax
=
fmaxf
(
amax
,
VLLM_SHFL_XOR_SYNC_WIDTH
(
amax
,
mask
,
32
));
sum
+=
VLLM_SHFL_XOR_SYNC_WIDTH
(
sum
,
mask
,
32
);
}
const
float
d
=
amax
/
127
;
...
...
csrc/quantization/gguf/mmq.cuh
View file @
7c25fe45
...
...
@@ -10,7 +10,7 @@ static __device__ __forceinline__ void mul_mat_q(
const
int
blocks_per_row_x
=
ncols_x
/
qk
;
const
int
blocks_per_col_y
=
nrows_y
/
QK8_1
;
const
int
blocks_per_warp
=
WARP_SIZE
/
qi
;
const
int
blocks_per_warp
=
WARP_SIZE
_GGUF
/
qi
;
const
int
&
ncols_dst
=
ncols_y
;
...
...
@@ -27,10 +27,10 @@ static __device__ __forceinline__ void mul_mat_q(
allocate_tiles
(
&
tile_x_ql
,
&
tile_x_dm
,
&
tile_x_qh
,
&
tile_x_sc
);
__shared__
int
tile_y_qs
[
mmq_x
*
WARP_SIZE
];
__shared__
half2
tile_y_ds
[
mmq_x
*
WARP_SIZE
/
QI8_1
];
__shared__
int
tile_y_qs
[
mmq_x
*
WARP_SIZE
_GGUF
];
__shared__
half2
tile_y_ds
[
mmq_x
*
WARP_SIZE
_GGUF
/
QI8_1
];
float
sum
[
mmq_y
/
WARP_SIZE
][
mmq_x
/
nwarps
]
=
{{
0.0
f
}};
float
sum
[
mmq_y
/
WARP_SIZE
_GGUF
][
mmq_x
/
nwarps
]
=
{{
0.0
f
}};
for
(
int
ib0
=
0
;
ib0
<
blocks_per_row_x
;
ib0
+=
blocks_per_warp
)
{
...
...
@@ -39,26 +39,26 @@ static __device__ __forceinline__ void mul_mat_q(
#pragma unroll
for
(
int
ir
=
0
;
ir
<
qr
;
++
ir
)
{
const
int
kqs
=
ir
*
WARP_SIZE
+
threadIdx
.
x
;
const
int
kqs
=
ir
*
WARP_SIZE
_GGUF
+
threadIdx
.
x
;
const
int
kbxd
=
kqs
/
QI8_1
;
#pragma unroll
for
(
int
i
=
0
;
i
<
mmq_x
;
i
+=
nwarps
)
{
const
int
col_y_eff
=
min
(
col_y_0
+
threadIdx
.
y
+
i
,
ncols_y
-
1
);
// to prevent out-of-bounds memory accesses
const
block_q8_1
*
by0
=
&
y
[
col_y_eff
*
blocks_per_col_y
+
ib0
*
(
qk
/
QK8_1
)
+
kbxd
];
const
int
index_y
=
(
threadIdx
.
y
+
i
)
*
WARP_SIZE
+
kqs
%
WARP_SIZE
;
const
int
index_y
=
(
threadIdx
.
y
+
i
)
*
WARP_SIZE
_GGUF
+
kqs
%
WARP_SIZE
_GGUF
;
tile_y_qs
[
index_y
]
=
get_int_from_int8_aligned
(
by0
->
qs
,
threadIdx
.
x
%
QI8_1
);
}
#pragma unroll
for
(
int
ids0
=
0
;
ids0
<
mmq_x
;
ids0
+=
nwarps
*
QI8_1
)
{
const
int
ids
=
(
ids0
+
threadIdx
.
y
*
QI8_1
+
threadIdx
.
x
/
(
WARP_SIZE
/
QI8_1
))
%
mmq_x
;
const
int
kby
=
threadIdx
.
x
%
(
WARP_SIZE
/
QI8_1
);
const
int
ids
=
(
ids0
+
threadIdx
.
y
*
QI8_1
+
threadIdx
.
x
/
(
WARP_SIZE
_GGUF
/
QI8_1
))
%
mmq_x
;
const
int
kby
=
threadIdx
.
x
%
(
WARP_SIZE
_GGUF
/
QI8_1
);
const
int
col_y_eff
=
min
(
col_y_0
+
ids
,
ncols_y
-
1
);
// if the sum is not needed it's faster to transform the scale to f32 ahead of time
const
half2
*
dsi_src
=
&
y
[
col_y_eff
*
blocks_per_col_y
+
ib0
*
(
qk
/
QK8_1
)
+
ir
*
(
WARP_SIZE
/
QI8_1
)
+
kby
].
ds
;
half2
*
dsi_dst
=
&
tile_y_ds
[
ids
*
(
WARP_SIZE
/
QI8_1
)
+
kby
];
const
half2
*
dsi_src
=
&
y
[
col_y_eff
*
blocks_per_col_y
+
ib0
*
(
qk
/
QK8_1
)
+
ir
*
(
WARP_SIZE
_GGUF
/
QI8_1
)
+
kby
].
ds
;
half2
*
dsi_dst
=
&
tile_y_ds
[
ids
*
(
WARP_SIZE
_GGUF
/
QI8_1
)
+
kby
];
if
(
need_sum
)
{
*
dsi_dst
=
*
dsi_src
;
}
else
{
...
...
@@ -70,12 +70,12 @@ static __device__ __forceinline__ void mul_mat_q(
__syncthreads
();
// #pragma unroll // unrolling this loop causes too much register pressure
for
(
int
k
=
ir
*
WARP_SIZE
/
qr
;
k
<
(
ir
+
1
)
*
WARP_SIZE
/
qr
;
k
+=
vdr
)
{
for
(
int
k
=
ir
*
WARP_SIZE
_GGUF
/
qr
;
k
<
(
ir
+
1
)
*
WARP_SIZE
_GGUF
/
qr
;
k
+=
vdr
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
mmq_x
;
j
+=
nwarps
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
mmq_y
;
i
+=
WARP_SIZE
)
{
sum
[
i
/
WARP_SIZE
][
j
/
nwarps
]
+=
vec_dot
(
for
(
int
i
=
0
;
i
<
mmq_y
;
i
+=
WARP_SIZE
_GGUF
)
{
sum
[
i
/
WARP_SIZE
_GGUF
][
j
/
nwarps
]
+=
vec_dot
(
tile_x_ql
,
tile_x_dm
,
tile_x_qh
,
tile_x_sc
,
tile_y_qs
,
tile_y_ds
,
threadIdx
.
x
+
i
,
threadIdx
.
y
+
j
,
k
);
}
...
...
@@ -93,12 +93,12 @@ static __device__ __forceinline__ void mul_mat_q(
}
#pragma unroll
for
(
int
i
=
0
;
i
<
mmq_y
;
i
+=
WARP_SIZE
)
{
for
(
int
i
=
0
;
i
<
mmq_y
;
i
+=
WARP_SIZE
_GGUF
)
{
const
int
row_dst
=
row_dst_0
+
threadIdx
.
x
+
i
;
if
(
row_dst
>=
nrows_dst
)
{
continue
;
}
dst
[
col_dst
*
nrows_dst
+
row_dst
]
=
__float2half
(
sum
[
i
/
WARP_SIZE
][
j
/
nwarps
]);
dst
[
col_dst
*
nrows_dst
+
row_dst
]
=
__float2half
(
sum
[
i
/
WARP_SIZE
_GGUF
][
j
/
nwarps
]);
}
}
}
...
...
@@ -115,7 +115,7 @@ static __device__ __forceinline__ void mul_mat_q(
template
<
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE
*
NWARPS_Q4_0
,
2
)
__launch_bounds__
(
WARP_SIZE
_GGUF
*
NWARPS_Q4_0
,
2
)
#endif
mul_mat_q4_0
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
half
*
__restrict__
dst
,
...
...
@@ -140,7 +140,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
const
dim3
block_dims
(
WARP_SIZE
_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
...
...
@@ -165,7 +165,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
template
<
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE
*
NWARPS_Q4_1
,
2
)
__launch_bounds__
(
WARP_SIZE
_GGUF
*
NWARPS_Q4_1
,
2
)
#endif
mul_mat_q4_1
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
half
*
__restrict__
dst
,
...
...
@@ -190,7 +190,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
const
dim3
block_dims
(
WARP_SIZE
_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
...
...
@@ -215,7 +215,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
template
<
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE
*
NWARPS_Q5_0
,
2
)
__launch_bounds__
(
WARP_SIZE
_GGUF
*
NWARPS_Q5_0
,
2
)
#endif
mul_mat_q5_0
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
half
*
__restrict__
dst
,
...
...
@@ -240,7 +240,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
const
dim3
block_dims
(
WARP_SIZE
_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
...
...
@@ -265,7 +265,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
template
<
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE
*
NWARPS_Q5_1
,
2
)
__launch_bounds__
(
WARP_SIZE
_GGUF
*
NWARPS_Q5_1
,
2
)
#endif
mul_mat_q5_1
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
half
*
__restrict__
dst
,
...
...
@@ -289,7 +289,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
const
dim3
block_dims
(
WARP_SIZE
_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
...
...
@@ -314,7 +314,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
template
<
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE
*
NWARPS_Q8_0
,
2
)
__launch_bounds__
(
WARP_SIZE
_GGUF
*
NWARPS_Q8_0
,
2
)
#endif
mul_mat_q8_0
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
half
*
__restrict__
dst
,
...
...
@@ -338,7 +338,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
const
dim3
block_dims
(
WARP_SIZE
_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
...
...
@@ -363,7 +363,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
template
<
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE
*
NWARPS_Q2_K
,
2
)
__launch_bounds__
(
WARP_SIZE
_GGUF
*
NWARPS_Q2_K
,
2
)
#endif
mul_mat_q2_K
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
half
*
__restrict__
dst
,
...
...
@@ -387,7 +387,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
const
dim3
block_dims
(
WARP_SIZE
_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
...
...
@@ -412,7 +412,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
template
<
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE
*
NWARPS_Q3_K
,
2
)
__launch_bounds__
(
WARP_SIZE
_GGUF
*
NWARPS_Q3_K
,
2
)
#endif
mul_mat_q3_K
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
half
*
__restrict__
dst
,
...
...
@@ -438,7 +438,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
const
dim3
block_dims
(
WARP_SIZE
_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
...
...
@@ -463,7 +463,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
template
<
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE
*
NWARPS_Q4_K
,
2
)
__launch_bounds__
(
WARP_SIZE
_GGUF
*
NWARPS_Q4_K
,
2
)
#endif
mul_mat_q4_K
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
half
*
__restrict__
dst
,
...
...
@@ -487,7 +487,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
const
dim3
block_dims
(
WARP_SIZE
_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
...
...
@@ -512,7 +512,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
template
<
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE
*
NWARPS_Q5_K
,
2
)
__launch_bounds__
(
WARP_SIZE
_GGUF
*
NWARPS_Q5_K
,
2
)
#endif
mul_mat_q5_K
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
half
*
__restrict__
dst
,
...
...
@@ -537,7 +537,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
const
dim3
block_dims
(
WARP_SIZE
_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
...
...
@@ -562,7 +562,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
template
<
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE
*
NWARPS_Q6_K
,
2
)
__launch_bounds__
(
WARP_SIZE
_GGUF
*
NWARPS_Q6_K
,
2
)
#endif
mul_mat_q6_K
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
half
*
__restrict__
dst
,
...
...
@@ -586,7 +586,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
const
dim3
block_dims
(
WARP_SIZE
_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
...
...
csrc/quantization/gguf/mmvq.cuh
View file @
7c25fe45
...
...
@@ -28,8 +28,8 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
// sum up partial sums and write back result
#pragma unroll
for
(
int
mask
=
16
;
mask
>
0
;
mask
>>=
1
)
{
tmp
+=
__shfl_xor_sync
(
0xffffffff
,
tmp
,
mask
,
32
);
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>
0
;
mask
>>=
1
)
{
tmp
+=
VLLM_SHFL_XOR_SYNC
(
tmp
,
mask
);
}
if
(
threadIdx
.
x
==
0
)
{
...
...
csrc/quantization/gguf/vecdotq.cuh
View file @
7c25fe45
This diff is collapsed.
Click to expand it.
csrc/torch_bindings.cpp
View file @
7c25fe45
...
...
@@ -258,6 +258,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
"SymInt size_n, int num_bits) -> Tensor"
);
// conditionally compiled so impl registrations are in source file
#endif
// Dequantization for GGML.
ops
.
def
(
"ggml_dequantize(Tensor W, int type, SymInt m, SymInt n) -> Tensor"
);
...
...
@@ -274,6 +275,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor"
);
ops
.
impl
(
"ggml_mul_mat_a8"
,
torch
::
kCUDA
,
&
ggml_mul_mat_a8
);
#ifndef USE_ROCM
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
ops
.
def
(
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
...
...
vllm/_custom_ops.py
View file @
7c25fe45
...
...
@@ -344,31 +344,6 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
is_zp_float
:
bool
=
False
)
->
torch
.
Tensor
:
return
torch
.
empty
((
size_m
,
size_n
),
device
=
a
.
device
,
dtype
=
a
.
dtype
)
@
register_fake
(
"_C::ggml_dequantize"
)
def
_ggml_dequantize_fake
(
W
:
torch
.
Tensor
,
quant_type
:
int
,
m
:
torch
.
SymInt
,
n
:
torch
.
SymInt
)
->
torch
.
Tensor
:
return
torch
.
empty
((
m
,
n
),
dtype
=
torch
.
float16
,
device
=
W
.
device
)
@
register_fake
(
"_C::ggml_mul_mat_vec_a8"
)
def
_ggml_mul_mat_vec_a8_fake
(
W
:
torch
.
Tensor
,
X
:
torch
.
Tensor
,
quant_type
:
int
,
row
:
torch
.
SymInt
,
)
->
torch
.
Tensor
:
return
torch
.
empty
((
1
,
row
),
dtype
=
torch
.
float16
,
device
=
W
.
device
)
@
register_fake
(
"_C::ggml_mul_mat_a8"
)
def
_ggml_mul_mat_a8_fake
(
W
:
torch
.
Tensor
,
X
:
torch
.
Tensor
,
quant_type
:
int
,
row
:
torch
.
SymInt
,
)
->
torch
.
Tensor
:
batch
=
X
.
size
(
0
)
return
torch
.
empty
((
batch
,
row
),
dtype
=
torch
.
float16
,
device
=
W
.
device
)
@
register_fake
(
"_C::marlin_qqq_gemm"
)
def
_marlin_qqq_gemm_fake
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
s_tok
:
torch
.
Tensor
,
s_ch
:
torch
.
Tensor
,
...
...
@@ -468,6 +443,34 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
memory_format
=
torch
.
contiguous_format
)
if hasattr(torch.ops._C, "ggml_dequantize"):
    # Fake implementations for the GGML ops, registered only when the compiled
    # extension actually exposes `ggml_dequantize`.

    @register_fake("_C::ggml_dequantize")
    def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int,
                              m: torch.SymInt,
                              n: torch.SymInt) -> torch.Tensor:
        """Return an uninitialized float16 tensor of shape (m, n) on W's device.

        `quant_type` is accepted to match the op signature but is not used.
        """
        out_shape = (m, n)
        return torch.empty(out_shape, dtype=torch.float16, device=W.device)

    @register_fake("_C::ggml_mul_mat_vec_a8")
    def _ggml_mul_mat_vec_a8_fake(
        W: torch.Tensor,
        X: torch.Tensor,
        quant_type: int,
        row: torch.SymInt,
    ) -> torch.Tensor:
        """Return an uninitialized float16 tensor of shape (1, row) on W's device.

        `X` and `quant_type` are accepted to match the op signature but are
        not used.
        """
        out_shape = (1, row)
        return torch.empty(out_shape, dtype=torch.float16, device=W.device)

    @register_fake("_C::ggml_mul_mat_a8")
    def _ggml_mul_mat_a8_fake(
        W: torch.Tensor,
        X: torch.Tensor,
        quant_type: int,
        row: torch.SymInt,
    ) -> torch.Tensor:
        """Return an uninitialized float16 tensor of shape (X.size(0), row).

        The tensor is allocated on W's device; `quant_type` is accepted to
        match the op signature but is not used.
        """
        out_shape = (X.size(0), row)
        return torch.empty(out_shape, dtype=torch.float16, device=W.device)
# cutlass
def
cutlass_scaled_mm_supports_fp8
(
cuda_device_capability
:
int
)
->
bool
:
return
torch
.
ops
.
_C
.
cutlass_scaled_mm_supports_fp8
(
cuda_device_capability
)
...
...
vllm/config.py
View file @
7c25fe45
...
...
@@ -387,7 +387,7 @@ class ModelConfig:
supported_quantization
=
QUANTIZATION_METHODS
rocm_supported_quantization
=
[
"awq"
,
"gptq"
,
"fp8"
,
"compressed_tensors"
,
"compressed-tensors"
,
"fbgemm_fp8"
"fbgemm_fp8"
,
"gguf"
]
optimized_quantization_methods
=
[
"fp8"
,
"marlin"
,
"modelopt"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment