Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bf03ff35
Unverified
Commit
bf03ff35
authored
Jul 09, 2025
by
Jacob Manning
Committed by
GitHub
Jul 09, 2025
Browse files
[Kernel] Add Conch backend for mixed-precision linear layer (#19818)
Signed-off-by:
Jacob Manning
<
jmanning+oss@stackav.com
>
parent
47043eb6
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
105 additions
and
1 deletion
+105
-1
requirements/rocm.txt
requirements/rocm.txt
+1
-0
vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py
...r/layers/quantization/kernels/mixed_precision/__init__.py
+4
-1
vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py
...utor/layers/quantization/kernels/mixed_precision/conch.py
+92
-0
vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py
...or/layers/quantization/kernels/mixed_precision/machete.py
+4
-0
vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py
...tor/layers/quantization/kernels/mixed_precision/marlin.py
+4
-0
No files found.
requirements/rocm.txt
View file @
bf03ff35
...
@@ -17,3 +17,4 @@ setuptools>=77.0.3,<80.0.0
...
@@ -17,3 +17,4 @@ setuptools>=77.0.3,<80.0.0
setuptools-scm>=8
setuptools-scm>=8
runai-model-streamer==0.11.0
runai-model-streamer==0.11.0
runai-model-streamer-s3==0.11.0
runai-model-streamer-s3==0.11.0
conch-triton-kernels==1.2.1
vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py
View file @
bf03ff35
...
@@ -8,6 +8,8 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark im
...
@@ -8,6 +8,8 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark im
AllSparkLinearKernel
)
AllSparkLinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas
import
(
# noqa: E501
BitBLASLinearKernel
)
BitBLASLinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.conch
import
(
# noqa: E501
ConchLinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama
import
(
# noqa: E501
ExllamaLinearKernel
)
ExllamaLinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.machete
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.machete
import
(
# noqa: E501
...
@@ -24,6 +26,7 @@ _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
...
@@ -24,6 +26,7 @@ _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
AllSparkLinearKernel
,
AllSparkLinearKernel
,
MarlinLinearKernel
,
MarlinLinearKernel
,
BitBLASLinearKernel
,
BitBLASLinearKernel
,
ConchLinearKernel
,
ExllamaLinearKernel
,
ExllamaLinearKernel
,
]
]
...
@@ -80,4 +83,4 @@ def choose_mp_linear_kernel(
...
@@ -80,4 +83,4 @@ def choose_mp_linear_kernel(
raise
ValueError
(
raise
ValueError
(
"Failed to find a kernel that can implement the "
\
"Failed to find a kernel that can implement the "
\
"WNA16 linear layer. Reasons:
\n
"
"WNA16 linear layer. Reasons:
\n
"
+
'
\n
'
.
join
(
failure_reasons
))
+
'
\n
'
.
join
(
failure_reasons
))
\ No newline at end of file
vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py
0 → 100644
View file @
bf03ff35
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
importlib.util
import
find_spec
from
typing
import
Final
,
Optional
import
torch
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
permute_param_layout_
)
from
vllm.scalar_type
import
scalar_types
from
.MPLinearKernel
import
MPLinearKernel
,
MPLinearLayerConfig
# Quantized weight types that ConchLinearKernel.can_implement accepts.
# Kept as a list because its repr is interpolated into the rejection message.
_CONCH_SUPPORTED_WEIGHT_TYPES: Final = [
    scalar_types.uint4,
    scalar_types.uint8,
    scalar_types.uint4b8,
    scalar_types.uint8b128,
]

# Group sizes that ConchLinearKernel.can_implement accepts.
# NOTE(review): -1 presumably denotes "no grouping" (per-channel scales) --
# confirm against the quantization configs that populate `group_size`.
_CONCH_SUPPORTED_GROUP_SIZES: Final = [-1, 128]
class ConchLinearKernel(MPLinearKernel):
    """Mixed-precision linear kernel backed by the `conch` package.

    The GEMM itself is delegated to
    `conch.ops.quantization.gemm.mixed_precision_gemm`; this class only
    decides whether a layer config is eligible, normalizes the parameter
    layout after loading, and forwards the call.
    """

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value declared for this kernel.

        NOTE(review): 80 presumably encodes compute capability 8.0 --
        confirm against how the kernel chooser interprets this value.
        """
        return 80

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
        """Report whether this kernel can serve the layer described by `c`.

        Checks, in order: the weight type, the group size, and whether the
        `conch` module is importable.

        Args:
            c: Layer configuration; only `weight_type` and `group_size` are
                inspected here.

        Returns:
            `(True, None)` when every check passes, otherwise
            `(False, reason)` with a human-readable explanation of the first
            failing check.
        """
        weight_type_ok = c.weight_type in _CONCH_SUPPORTED_WEIGHT_TYPES
        if not weight_type_ok:
            return False, (f"Weight type ({c.weight_type}) not supported by "
                           "ConchLinearKernel, supported types are: "
                           f"{_CONCH_SUPPORTED_WEIGHT_TYPES}")

        group_size_ok = c.group_size in _CONCH_SUPPORTED_GROUP_SIZES
        if not group_size_ok:
            return False, (f"Group size ({c.group_size}) not supported by "
                           "ConchLinearKernel, supported group sizes are: "
                           f"{_CONCH_SUPPORTED_GROUP_SIZES}")

        # Probe for the optional dependency without importing it.
        conch_spec = find_spec("conch")
        if conch_spec is None:
            return False, ("conch-triton-kernels is not installed, please "
                           "install it via `pip install conch-triton-kernels` "
                           "and try again!")

        return True, None

    # note assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale`  is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Normalize the packed-weight and scale parameters of `layer`.

        Each parameter is permuted in place to the layout stated in the note
        above and its storage is made contiguous. The transforms are handed
        to `self._transform_param` for the parameters named by
        `self.w_q_name` and `self.w_s_name`.

        Args:
            layer: Module holding the quantized weight and scale parameters.
        """

        def _to_contiguous(param, **layout):
            # Shared tail of both transforms: validate the parameter type,
            # permute to the requested layout, then drop any strided view.
            assert isinstance(param, BasevLLMParameter)
            permute_param_layout_(param, **layout)
            param.data = param.data.contiguous()
            return param

        def _normalize_packed_weight(param):
            return _to_contiguous(param,
                                  input_dim=0,
                                  output_dim=1,
                                  packed_dim=0)

        def _normalize_scales(param):
            return _to_contiguous(param, input_dim=0, output_dim=1)

        self._transform_param(layer, self.w_q_name, _normalize_packed_weight)
        self._transform_param(layer, self.w_s_name, _normalize_scales)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the Conch mixed-precision GEMM for `layer` on input `x`.

        Args:
            layer: Module whose weight parameters are fetched through
                `self._get_weight_params`.
            x: Activation tensor passed to the GEMM as `x`.
            bias: Optional tensor added in place to the GEMM output.

        Returns:
            The tensor returned by `mixed_precision_gemm`, with `bias` added
            in place when one is given.
        """
        # Imported lazily so the module loads even when conch is absent;
        # can_implement reports the missing dependency instead.
        from conch.ops.quantization.gemm import mixed_precision_gemm

        # NOTE(review): the fourth returned value is discarded; presumably it
        # is the activation-order index (g_idx) -- confirm it is safe to
        # ignore for every config that can_implement accepts.
        packed_weight, scales, zero_points, _ = self._get_weight_params(layer)

        if zero_points is not None:
            zero_point_data = zero_points.data
        else:
            zero_point_data = None

        weight_type = self.config.weight_type
        result = mixed_precision_gemm(
            x=x,
            w_q_packed=packed_weight.data,
            w_s=scales.data,
            w_zp=zero_point_data,
            weight_size_bits=weight_type.size_bits,
            weight_bias=weight_type.bias,
            group_size=self.config.group_size,
        )

        if bias is not None:
            result.add_(bias)  # In-place add

        return result
vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py
View file @
bf03ff35
...
@@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_quantized_values_into_int32
,
unpack_quantized_values_into_int32
)
pack_quantized_values_into_int32
,
unpack_quantized_values_into_int32
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
permute_param_layout_
)
permute_param_layout_
)
from
vllm.platforms
import
current_platform
from
.MPLinearKernel
import
MPLinearKernel
,
MPLinearLayerConfig
from
.MPLinearKernel
import
MPLinearKernel
,
MPLinearLayerConfig
...
@@ -27,6 +28,9 @@ class MacheteLinearKernel(MPLinearKernel):
...
@@ -27,6 +28,9 @@ class MacheteLinearKernel(MPLinearKernel):
@
classmethod
@
classmethod
def
can_implement
(
cls
,
def
can_implement
(
cls
,
c
:
MPLinearLayerConfig
)
->
tuple
[
bool
,
Optional
[
str
]]:
c
:
MPLinearLayerConfig
)
->
tuple
[
bool
,
Optional
[
str
]]:
# Machete uses CUTLASS, so it can only be compatible with Nvidia
if
not
current_platform
.
is_cuda
():
return
False
,
"Machete only supported on CUDA"
if
c
.
has_g_idx
and
\
if
c
.
has_g_idx
and
\
c
.
partition_weight_shape
[
0
]
!=
c
.
full_weight_shape
[
0
]:
c
.
partition_weight_shape
[
0
]
!=
c
.
full_weight_shape
[
0
]:
...
...
vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py
View file @
bf03ff35
...
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
...
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_zero_points
,
query_marlin_supported_quant_types
,
unpack_cols
)
marlin_zero_points
,
query_marlin_supported_quant_types
,
unpack_cols
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
permute_param_layout_
)
permute_param_layout_
)
from
vllm.platforms
import
current_platform
from
.MPLinearKernel
import
MPLinearKernel
,
MPLinearLayerConfig
from
.MPLinearKernel
import
MPLinearKernel
,
MPLinearLayerConfig
...
@@ -26,6 +27,9 @@ class MarlinLinearKernel(MPLinearKernel):
...
@@ -26,6 +27,9 @@ class MarlinLinearKernel(MPLinearKernel):
@
classmethod
@
classmethod
def
can_implement
(
cls
,
def
can_implement
(
cls
,
c
:
MPLinearLayerConfig
)
->
tuple
[
bool
,
Optional
[
str
]]:
c
:
MPLinearLayerConfig
)
->
tuple
[
bool
,
Optional
[
str
]]:
# Marlin uses inline PTX, so it can only be compatible with Nvidia
if
not
current_platform
.
is_cuda
():
return
False
,
"Marlin only supported on CUDA"
quant_types
=
query_marlin_supported_quant_types
(
c
.
zero_points
)
quant_types
=
query_marlin_supported_quant_types
(
c
.
zero_points
)
if
c
.
weight_type
not
in
quant_types
:
if
c
.
weight_type
not
in
quant_types
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment