Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
37fd47e7
Unverified
Commit
37fd47e7
authored
Aug 16, 2024
by
bnellnm
Committed by
GitHub
Aug 16, 2024
Browse files
[Kernel] fix types used in aqlm and ggml kernels to support dynamo (#7596)
parent
7759ae95
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
39 additions
and
53 deletions
+39
-53
csrc/ops.h
csrc/ops.h
+8
-8
csrc/quantization/aqlm/gemm_kernels.cu
csrc/quantization/aqlm/gemm_kernels.cu
+13
-12
csrc/quantization/gguf/dequantize.cuh
csrc/quantization/gguf/dequantize.cuh
+1
-1
csrc/quantization/gguf/gguf_kernel.cu
csrc/quantization/gguf/gguf_kernel.cu
+4
-4
vllm/_custom_ops.py
vllm/_custom_ops.py
+7
-21
vllm/model_executor/layers/quantization/aqlm.py
vllm/model_executor/layers/quantization/aqlm.py
+5
-7
vllm/model_executor/layers/quantization/gptq.py
vllm/model_executor/layers/quantization/gptq.py
+1
-0
No files found.
csrc/ops.h
View file @
37fd47e7
...
...
@@ -63,12 +63,12 @@ void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
torch
::
Tensor
aqlm_gemm
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
scales
,
const
torch
::
Tensor
&
codebook_partition_sizes
,
const
std
::
vector
<
int64_t
>
&
codebook_partition_sizes
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
);
torch
::
Tensor
aqlm_dequant
(
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
codebook_partition_sizes
);
torch
::
Tensor
aqlm_dequant
(
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
std
::
vector
<
int64_t
>
&
codebook_partition_sizes
);
torch
::
Tensor
awq_gemm
(
torch
::
Tensor
_in_feats
,
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
,
...
...
@@ -107,13 +107,13 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
torch
::
Tensor
awq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
torch
::
Tensor
ggml_dequantize
(
torch
::
Tensor
W
,
int
8
_t
type
,
int64_t
m
,
torch
::
Tensor
ggml_dequantize
(
torch
::
Tensor
W
,
int
64
_t
type
,
int64_t
m
,
int64_t
n
);
torch
::
Tensor
ggml_mul_mat_vec_a8
(
torch
::
Tensor
W
,
torch
::
Tensor
X
,
int8_t
type
,
int64_t
row
);
torch
::
Tensor
ggml_mul_mat_vec_a8
(
torch
::
Tensor
W
,
torch
::
Tensor
X
,
int64_t
type
,
int64_t
row
);
torch
::
Tensor
ggml_mul_mat_a8
(
torch
::
Tensor
W
,
torch
::
Tensor
X
,
int
8
_t
type
,
torch
::
Tensor
ggml_mul_mat_a8
(
torch
::
Tensor
W
,
torch
::
Tensor
X
,
int
64
_t
type
,
int64_t
row
);
torch
::
Tensor
fp8_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
...
...
csrc/quantization/aqlm/gemm_kernels.cu
View file @
37fd47e7
...
...
@@ -496,14 +496,14 @@ torch::Tensor code2x8_matmat(const torch::Tensor& input,
}
// Accumulate the partition sizes.
int4
accumulate_sizes
(
const
torch
::
Tensor
&
codebook_partition_sizes
)
{
int4
accumulate_sizes
(
const
std
::
vector
<
int64_t
>
&
codebook_partition_sizes
)
{
int4
cumulative_sizes
;
auto
cumulative_size
=
&
cumulative_sizes
.
x
;
in
t
i
=
0
;
size_
t
i
=
0
;
int
last
=
0
;
assert
(
codebook_partition_sizes
.
size
(
0
)
<=
4
);
for
(;
i
<
codebook_partition_sizes
.
size
(
0
);
++
i
,
++
cumulative_size
)
{
*
cumulative_size
=
codebook_partition_sizes
[
i
]
.
item
<
int
>
()
+
last
;
assert
(
codebook_partition_sizes
.
size
()
<=
4
);
for
(;
i
<
codebook_partition_sizes
.
size
();
++
i
,
++
cumulative_size
)
{
*
cumulative_size
=
codebook_partition_sizes
[
i
]
+
last
;
last
=
*
cumulative_size
;
}
// fill in the rest with unreachable.
...
...
@@ -519,12 +519,12 @@ int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) {
torch
::
Tensor
aqlm_gemm
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
scales
,
const
torch
::
Tensor
&
codebook_partition_sizes
,
const
std
::
vector
<
int64_t
>
&
codebook_partition_sizes
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
)
{
int4
cumulative_sizes
=
vllm
::
aqlm
::
accumulate_sizes
(
codebook_partition_sizes
);
int
const
nbooks
=
codebooks
.
size
(
0
)
/
codebook_partition_sizes
.
size
(
0
);
int
const
nbooks
=
codebooks
.
size
(
0
)
/
codebook_partition_sizes
.
size
();
int
const
entries
=
codebooks
.
size
(
1
);
if
(
nbooks
==
1
&&
entries
==
(
1
<<
16
))
{
...
...
@@ -541,13 +541,13 @@ torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
return
{};
}
torch
::
Tensor
aqlm_dequant
(
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
codebook_partition_sizes
)
{
torch
::
Tensor
aqlm_dequant
(
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
std
::
vector
<
int64_t
>
&
codebook_partition_sizes
)
{
int4
cumulative_sizes
=
vllm
::
aqlm
::
accumulate_sizes
(
codebook_partition_sizes
);
int
const
nbooks
=
codebooks
.
size
(
0
)
/
codebook_partition_sizes
.
size
(
0
);
int
const
nbooks
=
codebooks
.
size
(
0
)
/
codebook_partition_sizes
.
size
();
int
const
entries
=
codebooks
.
size
(
1
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
codes
));
...
...
@@ -557,7 +557,8 @@ torch::Tensor aqlm_dequant(const torch::Tensor& codes,
auto
in_features
=
codes
.
size
(
1
)
*
8
;
auto
out_features
=
codes
.
size
(
0
);
assert
(
out_features
=
codebook_partition_sizes
.
sum
().
item
<
int
>
());
assert
(
out_features
==
std
::
accumulate
(
codebook_partition_sizes
.
begin
(),
codebook_partition_sizes
.
end
(),
0
));
auto
weights
=
torch
::
empty
({
out_features
,
in_features
},
torch
::
TensorOptions
()
...
...
csrc/quantization/gguf/dequantize.cuh
View file @
37fd47e7
...
...
@@ -487,7 +487,7 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k,
dequantize_block_iq4_xs
<<<
nb
,
32
,
0
,
stream
>>>
(
vx
,
y
);
}
static
to_fp16_cuda_t
ggml_get_to_fp16_cuda
(
int
type
)
{
static
to_fp16_cuda_t
ggml_get_to_fp16_cuda
(
int
64_t
type
)
{
switch
(
type
)
{
case
2
:
return
dequantize_block_cuda
<
QK4_0
,
QR4_0
,
dequantize_q4_0
>
;
...
...
csrc/quantization/gguf/gguf_kernel.cu
View file @
37fd47e7
...
...
@@ -60,7 +60,7 @@ static void quantize_row_q8_1_cuda(const half* x, void* vy, const int kx,
}
torch
::
Tensor
ggml_dequantize
(
torch
::
Tensor
W
,
// quant weight
int
8
_t
type
,
int64_t
m
,
int64_t
n
)
{
int
64
_t
type
,
int64_t
m
,
int64_t
n
)
{
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
W
));
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kFloat16
).
device
(
W
.
device
());
...
...
@@ -73,7 +73,7 @@ torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
torch
::
Tensor
ggml_mul_mat_vec_a8
(
torch
::
Tensor
W
,
// quant weight
torch
::
Tensor
X
,
// input
int
8
_t
type
,
int64_t
row
)
{
int
64
_t
type
,
int64_t
row
)
{
int
col
=
X
.
sizes
()[
1
];
const
int
padded
=
(
col
+
512
-
1
)
/
512
*
512
;
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
X
));
...
...
@@ -172,7 +172,7 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
torch
::
Tensor
ggml_mul_mat_a8
(
torch
::
Tensor
W
,
// quant weight
torch
::
Tensor
X
,
// input
int
8
_t
type
,
int64_t
row
)
{
int
64
_t
type
,
int64_t
row
)
{
int
col
=
X
.
sizes
()[
1
];
int
padded
=
(
col
+
512
-
1
)
/
512
*
512
;
int
batch
=
X
.
sizes
()[
0
];
...
...
@@ -239,4 +239,4 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
break
;
}
return
Y
;
}
\ No newline at end of file
}
vllm/_custom_ops.py
View file @
37fd47e7
...
...
@@ -17,13 +17,7 @@ if not current_platform.is_tpu():
logger
.
warning
(
"Failed to import from vllm._C with %r"
,
e
)
with
contextlib
.
suppress
(
ImportError
):
# ruff: noqa: F401
import
vllm._moe_C
def
is_custom_op_supported
(
op_name
:
str
)
->
bool
:
op
,
overloads
=
torch
.
_C
.
_jit_get_operation
(
op_name
)
return
op
is
not
None
import
vllm._moe_C
# noqa: F401
def
hint_on_error
(
fn
):
...
...
@@ -280,14 +274,14 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
# aqlm
def
aqlm_gemm
(
input
:
torch
.
Tensor
,
codes
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
codebook_partition_sizes
:
torch
.
Tensor
,
codebook_partition_sizes
:
List
[
int
]
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
aqlm_gemm
(
input
,
codes
,
codebooks
,
scales
,
codebook_partition_sizes
,
bias
)
def
aqlm_dequant
(
codes
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
codebook_partition_sizes
:
torch
.
Tensor
)
->
torch
.
Tensor
:
codebook_partition_sizes
:
List
[
int
]
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
aqlm_dequant
(
codes
,
codebooks
,
codebook_partition_sizes
)
...
...
@@ -434,25 +428,17 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# gguf
def
ggml_dequantize
(
W
:
torch
.
Tensor
,
quant_type
:
int
,
m
:
int
,
n
:
int
):
def
ggml_dequantize
(
W
:
torch
.
Tensor
,
quant_type
:
int
,
m
:
int
,
n
:
int
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
ggml_dequantize
(
W
,
quant_type
,
m
,
n
)
def
ggml_mul_mat_vec
(
W
:
torch
.
Tensor
,
X
:
torch
.
Tensor
,
quant_type
:
int
,
row
:
int
,
):
return
torch
.
ops
.
_C
.
ggml_mul_mat_vec
(
W
,
X
,
quant_type
,
row
)
def
ggml_mul_mat_vec_a8
(
W
:
torch
.
Tensor
,
X
:
torch
.
Tensor
,
quant_type
:
int
,
row
:
int
,
):
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
ggml_mul_mat_vec_a8
(
W
,
X
,
quant_type
,
row
)
...
...
@@ -461,7 +447,7 @@ def ggml_mul_mat_a8(
X
:
torch
.
Tensor
,
quant_type
:
int
,
row
:
int
,
):
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
ggml_mul_mat_a8
(
W
,
X
,
quant_type
,
row
)
...
...
vllm/model_executor/layers/quantization/aqlm.py
View file @
37fd47e7
...
...
@@ -95,7 +95,7 @@ def generic_dequantize_gemm(
codebooks
:
torch
.
Tensor
,
# [num_codebooks, codebook_size, out_group_size, in_group_size]
scales
:
torch
.
Tensor
,
# [num_out_groups, 1, 1, 1]
output_partition_sizes
:
torch
.
IntTensor
,
output_partition_sizes
:
List
[
int
]
,
bias
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
output_shape
=
input
.
shape
[:
-
1
]
+
(
scales
.
shape
[
0
],
)
...
...
@@ -133,7 +133,7 @@ def optimized_dequantize_gemm(
codebooks
:
torch
.
Tensor
,
# [num_codebooks, codebook_size, out_group_size, in_group_size]
scales
:
torch
.
Tensor
,
# [num_out_groups, 1, 1, 1]
output_partition_sizes
:
torch
.
IntTensor
,
output_partition_sizes
:
List
[
int
]
,
bias
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
weights
=
ops
.
aqlm_dequant
(
codes
,
codebooks
,
output_partition_sizes
)
...
...
@@ -288,10 +288,8 @@ class AQLMLinearMethod(LinearMethodBase):
codebooks
,
{
# metadata indicates fixed size concatenated along dim 0
"is_metadata"
:
True
,
"output_partition_sizes"
:
torch
.
tensor
(
output_partition_sizes
,
device
=
'cpu'
),
"is_metadata"
:
True
,
"output_partition_sizes"
:
output_partition_sizes
},
)
...
...
@@ -334,7 +332,7 @@ class AQLMLinearMethod(LinearMethodBase):
codes
=
layer
.
codes
scales
=
layer
.
scales
output_partition_sizes
=
getattr
(
codebooks
,
"output_partition_sizes"
,
None
)
[]
)
nbooks
=
codes
.
shape
[
2
]
ingroups
=
codebooks
.
shape
[
3
]
...
...
vllm/model_executor/layers/quantization/gptq.py
View file @
37fd47e7
...
...
@@ -212,6 +212,7 @@ class GPTQLinearMethod(LinearMethodBase):
layer
.
g_idx
.
data
=
torch
.
argsort
(
layer
.
g_idx
).
to
(
torch
.
int
)
else
:
layer
.
g_idx
.
data
=
torch
.
empty
((
0
,
),
dtype
=
torch
.
int
,
device
=
layer
.
g_idx
.
device
)
layer
.
exllama_state
=
ExllamaState
.
READY
ops
.
gptq_shuffle
(
layer
.
qweight
,
layer
.
g_idx
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment