Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
539aa992
Commit
539aa992
authored
Sep 27, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.2' into v0.6.2-dev
parents
93872128
7193774b
Changes
383
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
657 additions
and
292 deletions
+657
-292
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+2
-3
vllm/model_executor/layers/quantization/gguf.py
vllm/model_executor/layers/quantization/gguf.py
+4
-1
vllm/model_executor/layers/quantization/gptq.py
vllm/model_executor/layers/quantization/gptq.py
+1
-0
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+49
-93
vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py
...el_executor/layers/quantization/kernels/MPLinearKernel.py
+83
-0
vllm/model_executor/layers/quantization/kernels/__init__.py
vllm/model_executor/layers/quantization/kernels/__init__.py
+72
-0
vllm/model_executor/layers/quantization/kernels/machete.py
vllm/model_executor/layers/quantization/kernels/machete.py
+118
-0
vllm/model_executor/layers/quantization/kernels/marlin.py
vllm/model_executor/layers/quantization/kernels/marlin.py
+132
-0
vllm/model_executor/layers/quantization/qqq.py
vllm/model_executor/layers/quantization/qqq.py
+1
-1
vllm/model_executor/layers/quantization/utils/__init__.py
vllm/model_executor/layers/quantization/utils/__init__.py
+3
-0
vllm/model_executor/layers/quantization/utils/layer_utils.py
vllm/model_executor/layers/quantization/utils/layer_utils.py
+37
-0
vllm/model_executor/layers/quantization/utils/machete_utils.py
...model_executor/layers/quantization/utils/machete_utils.py
+30
-0
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+24
-15
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
...el_executor/layers/quantization/utils/marlin_utils_fp8.py
+1
-2
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+43
-0
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+32
-24
vllm/model_executor/layers/rejection_sampler.py
vllm/model_executor/layers/rejection_sampler.py
+1
-8
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+12
-105
vllm/model_executor/layers/spec_decode_base_sampler.py
vllm/model_executor/layers/spec_decode_base_sampler.py
+1
-14
vllm/model_executor/layers/typical_acceptance_sampler.py
vllm/model_executor/layers/typical_acceptance_sampler.py
+11
-26
No files found.
vllm/model_executor/layers/quantization/fp8.py
View file @
539aa992
...
@@ -120,9 +120,8 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -120,9 +120,8 @@ class Fp8LinearMethod(LinearMethodBase):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
# kernel for fast weight-only FP8 quantization
capability
=
current_platform
.
get_device_capability
()
self
.
use_marlin
=
(
not
current_platform
.
has_device_capability
(
89
)
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
or
envs
.
VLLM_TEST_FORCE_FP8_MARLIN
)
self
.
use_marlin
=
capability
<
89
or
envs
.
VLLM_TEST_FORCE_FP8_MARLIN
# Disable marlin for rocm
# Disable marlin for rocm
if
is_hip
():
if
is_hip
():
self
.
use_marlin
=
False
self
.
use_marlin
=
False
...
...
vllm/model_executor/layers/quantization/gguf.py
View file @
539aa992
...
@@ -55,7 +55,10 @@ class GGUFConfig(QuantizationConfig):
...
@@ -55,7 +55,10 @@ class GGUFConfig(QuantizationConfig):
def
_fuse_mul_mat
(
x
:
torch
.
Tensor
,
qweight
:
torch
.
Tensor
,
def
_fuse_mul_mat
(
x
:
torch
.
Tensor
,
qweight
:
torch
.
Tensor
,
qweight_type
:
int
)
->
torch
.
Tensor
:
qweight_type
:
int
)
->
torch
.
Tensor
:
# use dequantize mulmat for IQmatrix, mmq for k-quants
# use dequantize mulmat for IQmatrix, mmq for k-quants
if
qweight_type
>=
16
:
if
x
.
shape
[
0
]
==
1
:
# enable mmvq in contiguous batching
y
=
ops
.
ggml_mul_mat_vec_a8
(
qweight
,
x
,
qweight_type
,
qweight
.
shape
[
0
])
elif
qweight_type
>=
16
:
block_size
,
type_size
=
gguf
.
GGML_QUANT_SIZES
[
qweight_type
]
block_size
,
type_size
=
gguf
.
GGML_QUANT_SIZES
[
qweight_type
]
shape
=
(
qweight
.
shape
[
0
],
qweight
.
shape
[
1
]
//
type_size
*
block_size
)
shape
=
(
qweight
.
shape
[
0
],
qweight
.
shape
[
1
]
//
type_size
*
block_size
)
weight
=
ops
.
ggml_dequantize
(
qweight
,
qweight_type
,
*
shape
)
weight
=
ops
.
ggml_dequantize
(
qweight
,
qweight_type
,
*
shape
)
...
...
vllm/model_executor/layers/quantization/gptq.py
View file @
539aa992
...
@@ -217,6 +217,7 @@ class GPTQLinearMethod(LinearMethodBase):
...
@@ -217,6 +217,7 @@ class GPTQLinearMethod(LinearMethodBase):
layer
.
qzeros
=
Parameter
(
layer
.
qzeros
.
data
,
requires_grad
=
False
)
layer
.
qzeros
=
Parameter
(
layer
.
qzeros
.
data
,
requires_grad
=
False
)
layer
.
qweight
=
Parameter
(
layer
.
qweight
.
data
,
requires_grad
=
False
)
layer
.
qweight
=
Parameter
(
layer
.
qweight
.
data
,
requires_grad
=
False
)
layer
.
g_idx
=
Parameter
(
layer
.
g_idx
.
data
,
requires_grad
=
False
)
layer
.
g_idx
=
Parameter
(
layer
.
g_idx
.
data
,
requires_grad
=
False
)
layer
.
scales
=
Parameter
(
layer
.
scales
.
data
,
requires_grad
=
False
)
# exllama needs to shuffle the weight after the weight is loaded
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass
# here we do the shuffle on first forward pass
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
539aa992
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Union
import
torch
import
torch
from
torch.nn
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -11,12 +10,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
...
@@ -11,12 +10,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs
)
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.kernels
import
(
MPLinearLayerConfig
,
choose_mp_linear_kernel
)
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_gptq_marlin_linear
,
check_marlin_supported
,
marlin_is_k_full
,
check_marlin_supported
,
marlin_moe_permute_scales
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_moe_permute_scales
,
marlin_repeat_scales_on_all_ranks
,
verify_marlin_supported
)
marlin_permute_scales
,
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
GroupQuantScaleParameter
,
...
@@ -132,10 +131,10 @@ class GPTQMarlinConfig(QuantizationConfig):
...
@@ -132,10 +131,10 @@ class GPTQMarlinConfig(QuantizationConfig):
def
is_gptq_marlin_compatible
(
cls
,
quant_config
:
Dict
[
str
,
Any
]):
def
is_gptq_marlin_compatible
(
cls
,
quant_config
:
Dict
[
str
,
Any
]):
# Extract data from quant config.
# Extract data from quant config.
quant_method
=
quant_config
.
get
(
"quant_method"
,
""
).
lower
()
quant_method
=
quant_config
.
get
(
"quant_method"
,
""
).
lower
()
num_bits
=
quant_config
.
get
(
"bits"
,
None
)
num_bits
=
quant_config
.
get
(
"bits"
)
group_size
=
quant_config
.
get
(
"group_size"
,
None
)
group_size
=
quant_config
.
get
(
"group_size"
)
sym
=
quant_config
.
get
(
"sym"
,
None
)
sym
=
quant_config
.
get
(
"sym"
)
desc_act
=
quant_config
.
get
(
"desc_act"
,
None
)
desc_act
=
quant_config
.
get
(
"desc_act"
)
if
quant_method
!=
"gptq"
:
if
quant_method
!=
"gptq"
:
return
False
return
False
...
@@ -159,6 +158,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -159,6 +158,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
quant_config: The GPTQ Marlin quantization config.
quant_config: The GPTQ Marlin quantization config.
"""
"""
_kernel_backends_being_used
:
Set
[
str
]
=
set
()
def
__init__
(
self
,
quant_config
:
GPTQMarlinConfig
)
->
None
:
def
__init__
(
self
,
quant_config
:
GPTQMarlinConfig
)
->
None
:
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
...
@@ -176,25 +177,34 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -176,25 +177,34 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
params_dtype
:
torch
.
dtype
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
**
extra_weight_attrs
,
)
->
None
:
)
->
None
:
del
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
is_row_parallel
=
input_size
!=
input_size_per_partition
is_row_parallel
=
input_size
!=
input_size_per_partition
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
mp_linear_kernel_config
=
MPLinearLayerConfig
(
full_weight_shape
=
(
input_size
,
output_size
),
partition_weight_shape
=
\
(
input_size_per_partition
,
output_size_per_partition
),
weight_type
=
self
.
quant_config
.
quant_type
,
act_type
=
params_dtype
,
group_size
=
self
.
quant_config
.
group_size
,
zero_points
=
False
,
has_g_idx
=
self
.
quant_config
.
desc_act
)
kernel_type
=
choose_mp_linear_kernel
(
mp_linear_kernel_config
)
if
kernel_type
.
__name__
not
in
self
.
_kernel_backends_being_used
:
logger
.
info
(
"Using %s for GPTQMarlinLinearMethod"
,
kernel_type
.
__name__
)
self
.
_kernel_backends_being_used
.
add
(
kernel_type
.
__name__
)
# Normalize group_size
# Normalize group_size
if
self
.
quant_config
.
group_size
!=
-
1
:
if
self
.
quant_config
.
group_size
!=
-
1
:
group_size
=
self
.
quant_config
.
group_size
group_size
=
self
.
quant_config
.
group_size
else
:
else
:
group_size
=
input_size
group_size
=
input_size
verify_marlin_supports_shape
(
output_size_per_partition
=
output_size_per_partition
,
input_size_per_partition
=
input_size_per_partition
,
input_size
=
input_size
,
group_size
=
group_size
,
)
# Determine sharding
# Determine sharding
if
marlin_repeat_scales_on_all_ranks
(
self
.
quant_config
.
desc_act
,
if
marlin_repeat_scales_on_all_ranks
(
self
.
quant_config
.
desc_act
,
self
.
quant_config
.
group_size
,
self
.
quant_config
.
group_size
,
...
@@ -275,57 +285,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -275,57 +285,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"g_idx"
,
g_idx
)
layer
.
register_parameter
(
"g_idx"
,
g_idx
)
layer
.
register_parameter
(
"scales"
,
scales
)
layer
.
register_parameter
(
"scales"
,
scales
)
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
input_size
=
input_size
layer
.
is_k_full
=
marlin_is_k_full
(
self
.
quant_config
.
desc_act
,
is_row_parallel
)
# Checkpoints are serialized in AutoGPTQ format, which is different from the
# marlin format. This function is called after the weights are loaded.
# Here, we handle the repacking, including the activation reordering case.
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
device
=
layer
.
qweight
.
device
# required by torch.compile
self
.
kernel
=
kernel_type
(
mp_linear_kernel_config
,
layer
.
qweight
=
Parameter
(
layer
.
qweight
.
data
,
requires_grad
=
False
)
w_q_param_name
=
"qweight"
,
layer
.
scales
=
Parameter
(
layer
.
scales
.
data
,
requires_grad
=
False
)
w_s_param_name
=
"scales"
,
w_zp_param_name
=
"qzeros"
,
w_gidx_param_name
=
"g_idx"
)
# Allocate marlin workspace
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
layer
.
workspace
=
marlin_make_workspace
(
self
.
kernel
.
process_weights_after_loading
(
layer
)
layer
.
output_size_per_partition
,
device
)
# Handle sorting for activation reordering if needed.
if
self
.
quant_config
.
desc_act
:
g_idx
,
g_idx_sort_indices
=
marlin_sort_g_idx
(
layer
.
g_idx
)
layer
.
g_idx_sort_indices
=
g_idx_sort_indices
replace_tensor
(
layer
,
"g_idx"
,
g_idx
)
else
:
layer
.
g_idx
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
# No zero-point
layer
.
zp
=
marlin_make_empty_g_idx
(
device
)
# Repack weights from autogptq format to marlin format.
marlin_qweight
=
ops
.
gptq_marlin_repack
(
layer
.
qweight
,
perm
=
layer
.
g_idx_sort_indices
,
size_k
=
layer
.
input_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
num_bits
=
self
.
quant_config
.
quant_type
.
size_bits
,
)
replace_tensor
(
layer
,
"qweight"
,
marlin_qweight
)
# Permute scales from autogptq format to marlin format.
marlin_scales
=
marlin_permute_scales
(
layer
.
scales
,
size_k
=
(
layer
.
input_size
if
self
.
quant_config
.
desc_act
else
layer
.
input_size_per_partition
),
size_n
=
layer
.
output_size_per_partition
,
group_size
=
self
.
quant_config
.
group_size
,
)
replace_tensor
(
layer
,
"scales"
,
marlin_scales
)
def
apply
(
def
apply
(
self
,
self
,
...
@@ -333,20 +301,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -333,20 +301,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
apply_gptq_marlin_linear
(
return
self
.
kernel
.
apply_weights
(
layer
,
x
,
bias
)
input
=
x
,
weight
=
layer
.
qweight
,
weight_scale
=
layer
.
scales
,
weight_zp
=
layer
.
zp
,
g_idx
=
layer
.
g_idx
,
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
workspace
=
layer
.
workspace
,
wtype
=
self
.
quant_config
.
quant_type
,
output_size_per_partition
=
layer
.
output_size_per_partition
,
input_size_per_partition
=
layer
.
input_size_per_partition
,
is_k_full
=
layer
.
is_k_full
,
bias
=
bias
,
)
class
GPTQMarlinMoEMethod
(
FusedMoEMethodBase
):
class
GPTQMarlinMoEMethod
(
FusedMoEMethodBase
):
...
@@ -506,12 +461,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -506,12 +461,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
w13_g_idx_sort_indices
[
e
]]
w13_g_idx_sort_indices
[
e
]]
w2_sorted_g_idx
[
e
]
=
layer
.
w2_g_idx
[
e
][
w2_sorted_g_idx
[
e
]
=
layer
.
w2_g_idx
[
e
][
w2_g_idx_sort_indices
[
e
]]
w2_g_idx_sort_indices
[
e
]]
replace_
tenso
r
(
layer
,
"w13_g_idx"
,
w13_sorted_g_idx
)
replace_
paramete
r
(
layer
,
"w13_g_idx"
,
w13_sorted_g_idx
)
replace_
tenso
r
(
layer
,
"w2_g_idx"
,
w2_sorted_g_idx
)
replace_
paramete
r
(
layer
,
"w2_g_idx"
,
w2_sorted_g_idx
)
replace_
tenso
r
(
layer
,
"w13_g_idx_sort_indices"
,
replace_
paramete
r
(
layer
,
"w13_g_idx_sort_indices"
,
w13_g_idx_sort_indices
)
w13_g_idx_sort_indices
)
replace_
tenso
r
(
layer
,
"w2_g_idx_sort_indices"
,
replace_
paramete
r
(
layer
,
"w2_g_idx_sort_indices"
,
w2_g_idx_sort_indices
)
w2_g_idx_sort_indices
)
else
:
else
:
# Reset g_idx related tensors
# Reset g_idx related tensors
num_experts
=
layer
.
w13_g_idx
.
shape
[
0
]
num_experts
=
layer
.
w13_g_idx
.
shape
[
0
]
...
@@ -544,7 +499,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -544,7 +499,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer
.
w13_qweight
.
shape
[
2
],
layer
.
w13_qweight
.
shape
[
2
],
self
.
quant_config
.
quant_type
.
size_bits
,
self
.
quant_config
.
quant_type
.
size_bits
,
)
)
replace_
tenso
r
(
layer
,
"w13_qweight"
,
marlin_w13_qweight
)
replace_
paramete
r
(
layer
,
"w13_qweight"
,
marlin_w13_qweight
)
marlin_w2_qweight
=
ops
.
gptq_marlin_moe_repack
(
marlin_w2_qweight
=
ops
.
gptq_marlin_moe_repack
(
layer
.
w2_qweight
,
layer
.
w2_qweight
,
layer
.
w2_g_idx_sort_indices
,
layer
.
w2_g_idx_sort_indices
,
...
@@ -552,7 +507,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -552,7 +507,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer
.
w2_qweight
.
shape
[
2
],
layer
.
w2_qweight
.
shape
[
2
],
self
.
quant_config
.
quant_type
.
size_bits
,
self
.
quant_config
.
quant_type
.
size_bits
,
)
)
replace_
tenso
r
(
layer
,
"w2_qweight"
,
marlin_w2_qweight
)
replace_
paramete
r
(
layer
,
"w2_qweight"
,
marlin_w2_qweight
)
# Repack scales
# Repack scales
marlin_w13_scales
=
marlin_moe_permute_scales
(
marlin_w13_scales
=
marlin_moe_permute_scales
(
s
=
layer
.
w13_scales
,
s
=
layer
.
w13_scales
,
...
@@ -560,14 +515,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -560,14 +515,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
size_n
=
layer
.
w13_scales
.
shape
[
2
],
size_n
=
layer
.
w13_scales
.
shape
[
2
],
group_size
=
self
.
quant_config
.
group_size
,
group_size
=
self
.
quant_config
.
group_size
,
)
)
replace_
tenso
r
(
layer
,
"w13_scales"
,
marlin_w13_scales
)
replace_
paramete
r
(
layer
,
"w13_scales"
,
marlin_w13_scales
)
marlin_w2_scales
=
marlin_moe_permute_scales
(
marlin_w2_scales
=
marlin_moe_permute_scales
(
s
=
layer
.
w2_scales
,
s
=
layer
.
w2_scales
,
size_k
=
layer
.
w2_scales
.
shape
[
1
]
*
self
.
quant_config
.
pack_factor
,
size_k
=
layer
.
w2_scales
.
shape
[
1
]
*
self
.
quant_config
.
pack_factor
,
size_n
=
layer
.
w2_scales
.
shape
[
2
],
size_n
=
layer
.
w2_scales
.
shape
[
2
],
group_size
=
self
.
quant_config
.
group_size
,
group_size
=
self
.
quant_config
.
group_size
,
)
)
replace_
tenso
r
(
layer
,
"w2_scales"
,
marlin_w2_scales
)
replace_
paramete
r
(
layer
,
"w2_scales"
,
marlin_w2_scales
)
def
apply
(
def
apply
(
self
,
self
,
...
@@ -611,4 +566,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -611,4 +566,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
topk_ids
,
topk_ids
,
w1_scale
=
layer
.
w13_scales
,
w1_scale
=
layer
.
w13_scales
,
w2_scale
=
layer
.
w2_scales
,
w2_scale
=
layer
.
w2_scales
,
num_bits
=
self
.
quant_config
.
quant_type
.
size_bits
,
).
to
(
orig_dtype
)
).
to
(
orig_dtype
)
vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py
0 → 100644
View file @
539aa992
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
typing
import
Callable
,
Optional
,
Tuple
import
torch
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.scalar_type
import
ScalarType
@
dataclass
class
MPLinearLayerConfig
:
full_weight_shape
:
Tuple
[
int
,
int
]
# [in, out]
partition_weight_shape
:
Tuple
[
int
,
int
]
weight_type
:
ScalarType
act_type
:
torch
.
dtype
group_size
:
int
zero_points
:
bool
has_g_idx
:
bool
class
MPLinearKernel
(
ABC
):
@
classmethod
@
abstractmethod
def
get_min_capability
(
cls
)
->
int
:
raise
NotImplementedError
@
classmethod
@
abstractmethod
def
can_implement
(
cls
,
c
:
MPLinearLayerConfig
)
->
Tuple
[
bool
,
Optional
[
str
]]:
raise
NotImplementedError
def
__init__
(
self
,
c
:
MPLinearLayerConfig
,
w_q_param_name
:
str
,
w_s_param_name
:
str
,
w_zp_param_name
:
Optional
[
str
]
=
None
,
w_gidx_param_name
:
Optional
[
str
]
=
None
)
->
None
:
assert
self
.
can_implement
(
c
)
self
.
config
=
c
self
.
w_q_name
=
w_q_param_name
self
.
w_s_name
=
w_s_param_name
self
.
w_zp_name
=
w_zp_param_name
self
.
w_gidx_name
=
w_gidx_param_name
@
abstractmethod
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
raise
NotImplementedError
@
abstractmethod
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
raise
NotImplementedError
def
_transform_param
(
self
,
layer
:
torch
.
nn
.
Module
,
name
:
Optional
[
str
],
fn
:
Callable
)
->
None
:
if
name
is
not
None
and
getattr
(
layer
,
name
,
None
)
is
not
None
:
old_param
=
getattr
(
layer
,
name
)
new_param
=
fn
(
old_param
)
# replace the parameter with torch.nn.Parameter for TorchDynamo
# compatibility
replace_parameter
(
layer
,
name
,
torch
.
nn
.
Parameter
(
new_param
.
data
,
requires_grad
=
False
))
def
_get_weight_params
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Tuple
[
torch
.
Tensor
,
# w_q
torch
.
Tensor
,
# w_s
Optional
[
torch
.
Tensor
],
# w_zp,
Optional
[
torch
.
Tensor
]
# w_gidx
]:
return
(
getattr
(
layer
,
self
.
w_q_name
),
getattr
(
layer
,
self
.
w_s_name
),
getattr
(
layer
,
self
.
w_zp_name
or
""
,
None
),
getattr
(
layer
,
self
.
w_gidx_name
or
""
,
None
),
)
vllm/model_executor/layers/quantization/kernels/__init__.py
0 → 100644
View file @
539aa992
import
os
from
typing
import
List
,
Optional
,
Type
from
vllm.model_executor.layers.quantization.kernels.machete
import
(
MacheteLinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.marlin
import
(
MarlinLinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.MPLinearKernel
import
(
MPLinearKernel
,
MPLinearLayerConfig
)
from
vllm.platforms
import
current_platform
# in priority/performance order (when available)
_POSSIBLE_KERNELS
:
List
[
Type
[
MPLinearKernel
]]
=
[
MacheteLinearKernel
,
MarlinLinearKernel
,
]
def
choose_mp_linear_kernel
(
config
:
MPLinearLayerConfig
,
compute_capability
:
Optional
[
int
]
=
None
)
->
Type
[
MPLinearKernel
]:
"""
Choose an MPLinearKernel that can implement the given config for the given
compute capability. Attempts to choose the best kernel in terms of
performance.
Args:
config (MPLinearLayerConfig): Description of the linear layer to be
implemented.
compute_capability (Optional[int], optional): The compute capability of
the target device, if None uses `current_platform` to get the compute
capability. Defaults to None.
Raises:
ValueError: If no kernel can implement the given config.
Returns:
Type[MPLinearKernel]: Chosen kernel.
"""
if
compute_capability
is
None
:
if
current_platform
is
None
:
raise
ValueError
(
"Cannot determine compute capability"
)
_cc
=
current_platform
.
get_device_capability
()
compute_capability
=
_cc
[
0
]
*
10
+
_cc
[
1
]
failure_reasons
=
[]
for
kernel
in
_POSSIBLE_KERNELS
:
if
kernel
.
__name__
in
os
.
environ
.
get
(
"VLLM_DISABLED_KERNELS"
,
""
)
\
.
split
(
","
):
failure_reasons
.
append
(
f
'
{
kernel
.
__name__
}
disabled by environment variable'
)
continue
if
kernel
.
get_min_capability
()
>
compute_capability
:
failure_reasons
.
append
(
f
"
{
kernel
.
__name__
}
requires capability "
f
"
{
kernel
.
get_min_capability
()
}
, current compute capability "
f
"is
{
compute_capability
}
"
)
continue
can_implement
,
failure_reason
=
kernel
.
can_implement
(
config
)
if
can_implement
:
return
kernel
else
:
failure_reasons
.
append
(
f
'
{
kernel
.
__name__
}
cannot implement due to:
{
failure_reason
}
'
)
raise
ValueError
(
"Failed to find a kernel that can implement the "
\
"WNA16 linear layer. Reasons:
\n
"
+
'
\n
'
.
join
(
failure_reasons
))
vllm/model_executor/layers/quantization/kernels/machete.py
0 → 100644
View file @
539aa992
from
functools
import
partial
from
typing
import
Optional
,
Tuple
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.machete_utils
import
(
MACHETE_SUPPORTED_GROUP_SIZES
,
check_machete_supports_shape
,
query_machete_supported_quant_types
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
pack_weights_into_int32
,
unpack_weights_into_int32
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
permute_param_layout_
)
from
.MPLinearKernel
import
MPLinearKernel
,
MPLinearLayerConfig
class
MacheteLinearKernel
(
MPLinearKernel
):
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
90
@
classmethod
def
can_implement
(
cls
,
c
:
MPLinearLayerConfig
)
->
Tuple
[
bool
,
Optional
[
str
]]:
if
c
.
has_g_idx
and
\
c
.
partition_weight_shape
[
0
]
!=
c
.
full_weight_shape
[
0
]:
return
False
,
"Act reordering currently not supported by Machete, "
\
"when the input features are partitioned across "
\
"devices"
if
c
.
zero_points
:
return
False
,
"Zero points currently not supported by "
\
" Compressed Tensors + Machete. (Kernel supports it"
\
" but CompressedTensorsWNA16 does not so support has"
\
" not been added to MacheteWNA16Kernel yet"
if
c
.
weight_type
not
in
query_machete_supported_quant_types
(
c
.
zero_points
):
return
False
,
f
"Quant type (
{
c
.
weight_type
}
) not supported by "
\
"Machete, supported types are: "
\
f
"
{
query_machete_supported_quant_types
(
c
.
zero_points
)
}
"
if
c
.
group_size
not
in
MACHETE_SUPPORTED_GROUP_SIZES
:
return
False
,
f
"Group size (
{
c
.
group_size
}
) not supported by "
\
"Machete, supported group sizes are: "
\
f
"
{
MACHETE_SUPPORTED_GROUP_SIZES
}
"
return
check_machete_supports_shape
(
c
.
partition_weight_shape
[
0
],
c
.
partition_weight_shape
[
1
])
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
):
c
=
self
.
config
if
c
.
has_g_idx
:
assert
self
.
w_gidx_name
is
not
None
perm
=
torch
.
argsort
(
getattr
(
layer
,
self
.
w_gidx_name
))
\
.
to
(
torch
.
int
)
self
.
act_perm
=
lambda
x
:
x
[:,
perm
]
# use `ops.permute_cols` if possible
if
c
.
act_type
in
[
torch
.
float16
,
torch
.
bfloat16
]
\
and
c
.
partition_weight_shape
[
0
]
%
8
==
0
:
self
.
act_perm
=
partial
(
ops
.
permute_cols
,
perm
=
perm
)
def
transform_w_q
(
x
):
assert
isinstance
(
x
,
BasevLLMParameter
)
permute_param_layout_
(
x
,
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
0
)
if
c
.
has_g_idx
:
x_unpacked
=
unpack_weights_into_int32
(
x
.
data
,
c
.
weight_type
,
packed_dim
=
0
)
x_perm
=
x_unpacked
[
perm
,
:]
x
.
data
=
pack_weights_into_int32
(
x_perm
,
c
.
weight_type
,
packed_dim
=
0
)
x
.
data
=
ops
.
machete_prepack_B
(
x
.
data
.
t
().
contiguous
().
t
(),
self
.
config
.
weight_type
)
return
x
def
transform_w_s
(
x
):
assert
isinstance
(
x
,
BasevLLMParameter
)
permute_param_layout_
(
x
,
input_dim
=
0
,
output_dim
=
1
)
x
.
data
=
x
.
data
.
contiguous
()
return
x
# Repack weights and scales for Machete
self
.
_transform_param
(
layer
,
self
.
w_q_name
,
transform_w_q
)
self
.
_transform_param
(
layer
,
self
.
w_s_name
,
transform_w_s
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
c
=
self
.
config
w_q
,
w_s
,
_
,
_
=
self
.
_get_weight_params
(
layer
)
x_2d
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
out_shape
=
x
.
shape
[:
-
1
]
+
(
c
.
partition_weight_shape
[
1
],
)
if
c
.
has_g_idx
:
x_2d
=
self
.
act_perm
(
x_2d
)
output
=
ops
.
machete_gemm
(
a
=
x_2d
,
b_q
=
w_q
,
b_type
=
c
.
weight_type
,
b_zeros
=
None
,
b_scales
=
w_s
,
b_group_size
=
c
.
group_size
)
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
return
output
.
reshape
(
out_shape
)
vllm/model_executor/layers/quantization/kernels/marlin.py
0 → 100644
View file @
539aa992
from
typing
import
Optional
,
Tuple
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
MARLIN_SUPPORTED_GROUP_SIZES
,
apply_gptq_marlin_linear
,
check_marlin_supports_shape
,
marlin_is_k_full
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
marlin_sort_g_idx
,
query_marlin_supported_quant_types
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
permute_param_layout_
)
from
.MPLinearKernel
import
MPLinearKernel
,
MPLinearLayerConfig
class
MarlinLinearKernel
(
MPLinearKernel
):
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
80
@
classmethod
def
can_implement
(
cls
,
c
:
MPLinearLayerConfig
)
->
Tuple
[
bool
,
Optional
[
str
]]:
if
c
.
zero_points
:
return
False
,
"Zero points currently not supported by "
\
" MarlinLinearKernel. Will be added when AWQMarlin "
\
"is migrated over to using MPLinearKernel backend"
quant_types
=
query_marlin_supported_quant_types
(
c
.
zero_points
)
if
c
.
weight_type
not
in
quant_types
:
return
False
,
f
"Quant type (
{
c
.
weight_type
}
) not supported by"
\
f
" Marlin, supported types are:
{
quant_types
}
"
if
c
.
group_size
not
in
MARLIN_SUPPORTED_GROUP_SIZES
:
return
False
,
f
"Group size (
{
c
.
group_size
}
) not supported by "
\
"Marlin, supported group sizes are: "
\
f
"
{
MARLIN_SUPPORTED_GROUP_SIZES
}
"
return
check_marlin_supports_shape
(
c
.
partition_weight_shape
[
0
],
c
.
partition_weight_shape
[
1
],
c
.
full_weight_shape
[
1
],
c
.
group_size
)
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
device
=
getattr
(
layer
,
self
.
w_q_name
).
device
c
=
self
.
config
row_parallel
=
(
c
.
partition_weight_shape
[
0
]
!=
c
.
full_weight_shape
[
0
])
self
.
is_k_full
=
marlin_is_k_full
(
c
.
has_g_idx
,
row_parallel
)
# Allocate marlin workspace.
self
.
workspace
=
marlin_make_workspace
(
c
.
partition_weight_shape
[
1
],
device
)
# Default names since marlin requires empty parameters for these,
# TODO: remove this requirement from marlin (allow optional tensors)
if
self
.
w_gidx_name
is
None
:
self
.
w_gidx_name
=
"g_idx"
if
self
.
w_zp_name
is
None
:
self
.
w_zp_name
=
"w_zp"
if
c
.
has_g_idx
:
g_idx
,
g_idx_sort_indices
=
marlin_sort_g_idx
(
getattr
(
layer
,
self
.
w_gidx_name
))
self
.
_transform_param
(
layer
,
self
.
w_gidx_name
,
lambda
_
:
g_idx
)
layer
.
g_idx_sort_indices
=
g_idx_sort_indices
else
:
setattr
(
layer
,
self
.
w_gidx_name
,
marlin_make_empty_g_idx
(
device
))
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
if
c
.
zero_points
:
pass
# TODO (lucas): add the following when AWQMarlin is migrated over to
# using MPLinearKernel backend
# self._transform_param(layer, self.w_zp_name, lambda x: \
# marlin_zero_points(
# x,
# size_k=c.partition_weight_shape[0],
# size_n=c.partition_weight_shape[1],
# num_bits=c.weight_type.size_bits))
else
:
setattr
(
layer
,
self
.
w_zp_name
,
marlin_make_empty_g_idx
(
device
))
def
transform_w_q
(
x
):
assert
isinstance
(
x
,
BasevLLMParameter
)
permute_param_layout_
(
x
,
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
0
)
x
.
data
=
ops
.
gptq_marlin_repack
(
x
.
data
.
contiguous
(),
perm
=
layer
.
g_idx_sort_indices
,
size_k
=
c
.
partition_weight_shape
[
0
],
size_n
=
c
.
partition_weight_shape
[
1
],
num_bits
=
c
.
weight_type
.
size_bits
)
return
x
def
transform_w_s
(
x
):
assert
isinstance
(
x
,
BasevLLMParameter
)
permute_param_layout_
(
x
,
input_dim
=
0
,
output_dim
=
1
)
x
.
data
=
marlin_permute_scales
(
x
.
data
.
contiguous
(),
size_k
=
c
.
partition_weight_shape
[
0
],
size_n
=
c
.
partition_weight_shape
[
1
],
group_size
=
c
.
group_size
)
return
x
self
.
_transform_param
(
layer
,
self
.
w_q_name
,
transform_w_q
)
self
.
_transform_param
(
layer
,
self
.
w_s_name
,
transform_w_s
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
c
=
self
.
config
w_q
,
w_s
,
w_zp
,
w_gidx
=
self
.
_get_weight_params
(
layer
)
# `process_weights_after_loading` will ensure w_zp and w_gidx are not
# None for marlin
return
apply_gptq_marlin_linear
(
input
=
x
,
weight
=
w_q
,
weight_scale
=
w_s
,
weight_zp
=
w_zp
,
# type: ignore
g_idx
=
w_gidx
,
# type: ignore
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
workspace
=
self
.
workspace
,
wtype
=
c
.
weight_type
,
input_size_per_partition
=
c
.
partition_weight_shape
[
0
],
output_size_per_partition
=
c
.
partition_weight_shape
[
1
],
is_k_full
=
self
.
is_k_full
,
bias
=
bias
)
vllm/model_executor/layers/quantization/qqq.py
View file @
539aa992
...
@@ -260,7 +260,7 @@ class QQQLinearMethod(LinearMethodBase):
...
@@ -260,7 +260,7 @@ class QQQLinearMethod(LinearMethodBase):
size_k
=
x_2d
.
shape
[
1
]
size_k
=
x_2d
.
shape
[
1
]
size_n
=
s_ch
.
shape
[
1
]
size_n
=
s_ch
.
shape
[
1
]
x_int8
,
s_tok
=
ops
.
scaled_int8_quant
(
x_2d
)
x_int8
,
s_tok
,
_
=
ops
.
scaled_int8_quant
(
x_2d
)
output_2d
=
ops
.
marlin_qqq_gemm
(
x_int8
,
qweight
,
s_tok
,
s_ch
,
s_group
,
output_2d
=
ops
.
marlin_qqq_gemm
(
x_int8
,
qweight
,
s_tok
,
s_ch
,
s_group
,
workspace
,
size_m
,
size_n
,
size_k
)
workspace
,
size_m
,
size_n
,
size_k
)
...
...
vllm/model_executor/layers/quantization/utils/__init__.py
View file @
539aa992
from
.layer_utils
import
replace_parameter
,
update_tensor_inplace
__all__
=
[
'update_tensor_inplace'
,
'replace_parameter'
]
vllm/model_executor/layers/quantization/utils/layer_utils.py
0 → 100644
View file @
539aa992
from
typing
import
Union
import
torch
def
update_tensor_inplace
(
dst
:
torch
.
Tensor
,
src
:
torch
.
Tensor
):
assert
dst
.
dtype
==
src
.
dtype
,
"Tensors must have the same dtype"
# update tensor shape and stride
dst
.
as_strided_
(
src
.
shape
,
src
.
stride
())
# If not the same underlying storage move tensor data
if
dst
.
data_ptr
()
!=
src
.
data_ptr
():
dst
.
copy_
(
src
)
del
src
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def
replace_parameter
(
mod
:
torch
.
nn
.
Module
,
name
:
str
,
new
:
Union
[
torch
.
Tensor
,
torch
.
nn
.
Parameter
])
->
None
:
old
=
getattr
(
mod
,
name
)
if
type
(
old
)
is
type
(
new
)
and
old
.
dtype
==
new
.
dtype
and
\
old
.
untyped_storage
().
nbytes
()
==
new
.
untyped_storage
().
nbytes
():
# If we can just update in-place to avoid re-registering
# can be faster if the underlying storage is the same
update_tensor_inplace
(
old
,
new
)
else
:
# Fallback re-register parameter, convert to Parameter if necessary
# this not only ensures we don't register a tensor as a parameter, but
# also ensures that all parameter subclasses get re-registered as
# parameters for `torch.compile` compatibility
if
not
isinstance
(
new
,
torch
.
nn
.
Parameter
):
new
=
torch
.
nn
.
Parameter
(
new
,
requires_grad
=
False
)
mod
.
register_parameter
(
name
,
torch
.
nn
.
Parameter
(
new
,
requires_grad
=
False
))
vllm/model_executor/layers/quantization/utils/machete_utils.py
0 → 100644
View file @
539aa992
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
vllm.scalar_type
import
ScalarType
,
scalar_types
MACHETE_SUPPORTED_GROUP_SIZES
=
[
-
1
,
128
]
MACHETE_PREPACKED_BLOCK_SHAPE
=
[
64
,
128
]
def
query_machete_supported_quant_types
(
zero_points
:
bool
)
->
List
[
ScalarType
]:
if
zero_points
:
return
[
scalar_types
.
uint4
,
scalar_types
.
uint8
]
else
:
return
[
scalar_types
.
uint4b8
,
scalar_types
.
uint8b128
]
def
query_machete_supported_act_types
(
zero_points
:
bool
)
->
List
[
ScalarType
]:
return
[
torch
.
float16
,
torch
.
bfloat16
]
def
check_machete_supports_shape
(
in_features
:
int
,
out_featrues
:
int
)
\
->
Tuple
[
bool
,
Optional
[
str
]]:
if
in_features
%
MACHETE_PREPACKED_BLOCK_SHAPE
[
0
]
!=
0
:
return
False
,
"Input features size must be divisible by "
\
f
"
{
MACHETE_PREPACKED_BLOCK_SHAPE
[
0
]
}
"
if
out_featrues
%
MACHETE_PREPACKED_BLOCK_SHAPE
[
1
]
!=
0
:
return
False
,
"Output features size must be divisible by "
\
f
"
{
MACHETE_PREPACKED_BLOCK_SHAPE
[
1
]
}
"
return
True
,
None
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
539aa992
...
@@ -29,8 +29,9 @@ def query_marlin_supported_quant_types(has_zp: bool,
...
@@ -29,8 +29,9 @@ def query_marlin_supported_quant_types(has_zp: bool,
device_capability
:
Optional
[
int
]
=
None
device_capability
:
Optional
[
int
]
=
None
):
):
if
device_capability
is
None
:
if
device_capability
is
None
:
major
,
minor
=
current_platform
.
get_device_capability
()
capability_tuple
=
current_platform
.
get_device_capability
()
device_capability
=
major
*
10
+
minor
device_capability
=
(
-
1
if
capability_tuple
is
None
else
capability_tuple
.
to_int
())
if
device_capability
<
80
:
if
device_capability
<
80
:
return
[]
return
[]
...
@@ -52,8 +53,9 @@ def _check_marlin_supported(
...
@@ -52,8 +53,9 @@ def _check_marlin_supported(
device_capability
:
Optional
[
int
]
=
None
)
->
Tuple
[
bool
,
Optional
[
str
]]:
device_capability
:
Optional
[
int
]
=
None
)
->
Tuple
[
bool
,
Optional
[
str
]]:
if
device_capability
is
None
:
if
device_capability
is
None
:
major
,
minor
=
current_platform
.
get_device_capability
()
capability_tuple
=
current_platform
.
get_device_capability
()
device_capability
=
major
*
10
+
minor
device_capability
=
(
-
1
if
capability_tuple
is
None
else
capability_tuple
.
to_int
())
supported_types
=
query_marlin_supported_quant_types
(
supported_types
=
query_marlin_supported_quant_types
(
has_zp
,
device_capability
)
has_zp
,
device_capability
)
...
@@ -118,6 +120,19 @@ def verify_marlin_supports_shape(output_size_per_partition: int,
...
@@ -118,6 +120,19 @@ def verify_marlin_supports_shape(output_size_per_partition: int,
"with --quantization gptq."
)
"with --quantization gptq."
)
def
check_marlin_supports_shape
(
output_size_per_partition
:
int
,
input_size_per_partition
:
int
,
input_size
:
int
,
group_size
:
int
)
\
->
Tuple
[
bool
,
Optional
[
str
]]:
try
:
verify_marlin_supports_shape
(
output_size_per_partition
,
input_size_per_partition
,
input_size
,
group_size
)
except
ValueError
as
e
:
return
False
,
e
.
__str__
()
return
True
,
None
def
marlin_make_workspace
(
output_size_per_partition
:
int
,
def
marlin_make_workspace
(
output_size_per_partition
:
int
,
device
:
torch
.
device
)
->
torch
.
Tensor
:
device
:
torch
.
device
)
->
torch
.
Tensor
:
max_workspace_size
=
(
output_size_per_partition
//
max_workspace_size
=
(
output_size_per_partition
//
...
@@ -146,6 +161,11 @@ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
...
@@ -146,6 +161,11 @@ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
requires_grad
=
False
)
requires_grad
=
False
)
def
marlin_make_empty_zp
(
device
:
torch
.
device
)
->
torch
.
Tensor
:
return
torch
.
nn
.
Parameter
(
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
device
),
requires_grad
=
False
)
def
marlin_sort_g_idx
(
def
marlin_sort_g_idx
(
g_idx
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
g_idx
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
g_idx_sort_indices
=
torch
.
argsort
(
g_idx
).
to
(
torch
.
int
)
g_idx_sort_indices
=
torch
.
argsort
(
g_idx
).
to
(
torch
.
int
)
...
@@ -238,17 +258,6 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
...
@@ -238,17 +258,6 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
return
marlin_zp
return
marlin_zp
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def
replace_tensor
(
layer
:
torch
.
nn
.
Module
,
name
:
str
,
new_t
:
torch
.
Tensor
)
->
None
:
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr
(
layer
,
name
).
resize_
(
new_t
.
shape
)
getattr
(
layer
,
name
).
copy_
(
new_t
)
del
new_t
def
apply_gptq_marlin_linear
(
def
apply_gptq_marlin_linear
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
View file @
539aa992
...
@@ -10,8 +10,7 @@ from .marlin_utils import marlin_make_workspace, marlin_permute_scales
...
@@ -10,8 +10,7 @@ from .marlin_utils import marlin_make_workspace, marlin_permute_scales
def
is_fp8_marlin_supported
():
def
is_fp8_marlin_supported
():
capability
=
current_platform
.
get_device_capability
()
return
current_platform
.
has_device_capability
(
80
)
return
capability
[
0
]
>=
8
def
apply_fp8_marlin_linear
(
def
apply_fp8_marlin_linear
(
...
...
vllm/model_executor/layers/quantization/utils/quant_utils.py
View file @
539aa992
...
@@ -20,6 +20,49 @@ FUSED_LAYER_NAME_MAPPING = {
...
@@ -20,6 +20,49 @@ FUSED_LAYER_NAME_MAPPING = {
}
}
def
pack_weights_into_int32
(
w_q
:
torch
.
Tensor
,
wtype
:
ScalarType
,
packed_dim
:
int
=
0
):
# move dim to pack to the end
perm
=
(
*
[
i
for
i
in
range
(
len
(
w_q
.
shape
))
if
i
!=
packed_dim
],
packed_dim
)
inv_perm
=
tuple
(
perm
.
index
(
i
)
for
i
in
range
(
len
(
perm
)))
w_q_perm
=
w_q
.
permute
(
perm
)
pack_factor
=
32
//
wtype
.
size_bits
mask
=
(
1
<<
wtype
.
size_bits
)
-
1
new_shape_perm
=
list
(
w_q_perm
.
shape
)
assert
w_q_perm
.
shape
[
-
1
]
%
pack_factor
==
0
new_shape_perm
[
-
1
]
//=
pack_factor
res
=
torch
.
zeros
(
new_shape_perm
,
dtype
=
torch
.
int32
,
device
=
w_q
.
device
)
for
i
in
range
(
pack_factor
):
res
|=
(
w_q_perm
[...,
i
::
pack_factor
]
&
mask
)
<<
wtype
.
size_bits
*
i
return
res
.
permute
(
inv_perm
)
def
unpack_weights_into_int32
(
w_q
:
torch
.
Tensor
,
wtype
:
ScalarType
,
packed_dim
:
int
=
0
):
# move dim to pack to the end
perm
=
(
*
[
i
for
i
in
range
(
len
(
w_q
.
shape
))
if
i
!=
packed_dim
],
packed_dim
)
inv_perm
=
tuple
(
perm
.
index
(
i
)
for
i
in
range
(
len
(
perm
)))
w_q_perm
=
w_q
.
permute
(
perm
)
pack_factor
=
32
//
wtype
.
size_bits
mask
=
(
1
<<
wtype
.
size_bits
)
-
1
new_shape_perm
=
list
(
w_q_perm
.
shape
)
new_shape_perm
[
-
1
]
*=
pack_factor
res
=
torch
.
zeros
(
new_shape_perm
,
dtype
=
torch
.
int32
,
device
=
w_q
.
device
)
for
i
in
range
(
pack_factor
):
res
[...,
i
::
pack_factor
]
=
(
w_q_perm
>>
wtype
.
size_bits
*
i
)
&
mask
return
res
.
permute
(
inv_perm
)
def
is_layer_skipped
(
prefix
:
str
,
ignored_layers
:
List
[
str
])
->
bool
:
def
is_layer_skipped
(
prefix
:
str
,
ignored_layers
:
List
[
str
])
->
bool
:
# prefix: model.layers.0.self_attn.q_proj
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj
# proj_name: q_proj
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
539aa992
...
@@ -6,19 +6,18 @@ from vllm import _custom_ops as ops
...
@@ -6,19 +6,18 @@ from vllm import _custom_ops as ops
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_hip
from
vllm.utils
import
is_hip
# scaled_mm in pytorch on rocm has a bug that requires always
# Input scaling factors are no longer optional in _scaled_mm starting
# providing scaling factor for result. This value is created
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
# as global value to avoid multiple tensor allocations, and
TORCH_DEVICE_IDENTITY
=
torch
.
ones
(
1
).
cuda
()
if
is_hip
()
else
None
# can be removed once pytorch fixes the bug.
TORCH_SCALED_MM_SCALE_RESULT
=
torch
.
ones
(
1
).
cuda
()
if
is_hip
()
else
None
def
cutlass_fp8_supported
()
->
bool
:
def
cutlass_fp8_supported
()
->
bool
:
# cutlass is not supported on Rocm
# cutlass is not supported on Rocm
if
is_hip
():
if
is_hip
():
return
False
return
False
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
capability_tuple
=
current_platform
.
get_device_capability
()
capability
=
-
1
if
capability_tuple
is
None
else
capability_tuple
.
to_int
()
return
ops
.
cutlass_scaled_mm_supports_fp8
(
capability
)
return
ops
.
cutlass_scaled_mm_supports_fp8
(
capability
)
...
@@ -130,19 +129,17 @@ def apply_fp8_linear(
...
@@ -130,19 +129,17 @@ def apply_fp8_linear(
if
per_tensor_weights
and
per_tensor_activations
:
if
per_tensor_weights
and
per_tensor_activations
:
# Fused GEMM_DQ
# Fused GEMM_DQ
output
=
torch
.
_scaled_mm
(
output
=
torch
.
_scaled_mm
(
qinput
,
qinput
,
weight
,
weight
,
out_dtype
=
input
.
dtype
,
out_dtype
=
input
.
dtype
,
scale_a
=
x_scale
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
scale_b
=
weight_scale
,
bias
=
bias
)
scale_result
=
TORCH_SCALED_MM_SCALE_RESULT
,
# A fix for discrepancy in scaled_mm which returns tuple
bias
=
bias
)
# for torch < 2.5 and a single value in torch >= 2.5
# Since in torch 2.5, scaled_mm only returns single value
if
type
(
output
)
is
tuple
and
len
(
output
)
==
2
:
# This should be removed when vllm-nvidia also moves to 2.5
return
torch
.
narrow
(
output
[
0
],
0
,
0
,
input
.
shape
[
0
])
if
is_hip
():
return
torch
.
narrow
(
output
,
0
,
0
,
input
.
shape
[
0
])
return
torch
.
narrow
(
output
,
0
,
0
,
input
.
shape
[
0
])
return
torch
.
narrow
(
output
[
0
],
0
,
0
,
input
.
shape
[
0
])
else
:
else
:
# Fallback for channelwise case, where we use unfused DQ
# Fallback for channelwise case, where we use unfused DQ
...
@@ -160,12 +157,23 @@ def apply_fp8_linear(
...
@@ -160,12 +157,23 @@ def apply_fp8_linear(
# For the scaled_mm fallback case, we break this down, since it
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# does not support s_w being a vector.
# Making sure the dummy tensor is on the same device as the weight
global
TORCH_DEVICE_IDENTITY
if
TORCH_DEVICE_IDENTITY
.
device
!=
weight
.
device
:
TORCH_DEVICE_IDENTITY
=
TORCH_DEVICE_IDENTITY
.
to
(
weight
.
device
)
# GEMM
# GEMM
# This computes C = (X * W).
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
# Output in fp32 to allow subsequent ops to happen in-place
output
,
_
=
torch
.
_scaled_mm
(
qinput
,
output
=
torch
.
_scaled_mm
(
qinput
,
weight
,
weight
,
out_dtype
=
torch
.
float32
)
scale_a
=
TORCH_DEVICE_IDENTITY
,
scale_b
=
TORCH_DEVICE_IDENTITY
,
out_dtype
=
torch
.
float32
)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if
type
(
output
)
is
tuple
and
len
(
output
)
==
2
:
output
=
output
[
0
]
# Unpad (undo num_token_padding)
# Unpad (undo num_token_padding)
output
=
torch
.
narrow
(
output
,
0
,
0
,
input
.
shape
[
0
])
output
=
torch
.
narrow
(
output
,
0
,
0
,
input
.
shape
[
0
])
x_scale
=
torch
.
narrow
(
x_scale
,
0
,
0
,
input
.
shape
[
0
])
x_scale
=
torch
.
narrow
(
x_scale
,
0
,
0
,
input
.
shape
[
0
])
...
@@ -188,7 +196,7 @@ def apply_int8_linear(
...
@@ -188,7 +196,7 @@ def apply_int8_linear(
# ops.scaled_int8_quant supports both dynamic and static quant.
# ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
# * static, layer.input_scale is scalar and x_scale is input_scale.
x_q
,
x_scale
=
ops
.
scaled_int8_quant
(
input
,
input_scale
)
x_q
,
x_scale
,
_
=
ops
.
scaled_int8_quant
(
input
,
input_scale
)
return
ops
.
cutlass_scaled_mm
(
x_q
,
return
ops
.
cutlass_scaled_mm
(
x_q
,
weight
,
weight
,
...
...
vllm/model_executor/layers/rejection_sampler.py
View file @
539aa992
...
@@ -31,15 +31,11 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -31,15 +31,11 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
disable_bonus_tokens
:
bool
=
True
,
strict_mode
:
bool
=
False
,
strict_mode
:
bool
=
False
,
use_flashinfer
:
Optional
[
bool
]
=
None
):
use_flashinfer
:
Optional
[
bool
]
=
None
):
"""Create a rejection sampler.
"""Create a rejection sampler.
Args:
Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
during sampling. This catches correctness issues but adds
nontrivial latency.
nontrivial latency.
...
@@ -48,8 +44,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -48,8 +44,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
None, we will use the default value from the environment variable.
None, we will use the default value from the environment variable.
This parameter is only used for testing purposes.
This parameter is only used for testing purposes.
"""
"""
super
().
__init__
(
disable_bonus_tokens
=
disable_bonus_tokens
,
super
().
__init__
(
strict_mode
=
strict_mode
)
strict_mode
=
strict_mode
)
if
use_flashinfer
is
None
:
if
use_flashinfer
is
None
:
self
.
use_flashinfer
=
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
(
self
.
use_flashinfer
=
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
(
chain_speculative_sampling
is
not
None
)
chain_speculative_sampling
is
not
None
)
...
@@ -57,8 +52,6 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -57,8 +52,6 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
self
.
use_flashinfer
=
use_flashinfer
self
.
use_flashinfer
=
use_flashinfer
if
self
.
use_flashinfer
:
if
self
.
use_flashinfer
:
assert
not
disable_bonus_tokens
,
\
"flashinfer will enable bonus token by default"
logger
.
info
(
"Use flashinfer for rejection sampling."
)
logger
.
info
(
"Use flashinfer for rejection sampling."
)
else
:
else
:
logger
.
info
(
"Use pytorch for rejection sampling."
)
logger
.
info
(
"Use pytorch for rejection sampling."
)
...
...
vllm/model_executor/layers/sampler.py
View file @
539aa992
...
@@ -10,19 +10,15 @@ import msgspec
...
@@ -10,19 +10,15 @@ import msgspec
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
from
vllm.triton_utils
import
HAS_TRITON
if
HAS_TRITON
:
from
vllm.model_executor.layers.ops.sample
import
sample
as
sample_triton
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.model_executor.sampling_metadata
import
(
SamplingMetadata
,
from
vllm.model_executor.sampling_metadata
import
(
SamplingMetadata
,
SamplingTensors
,
SamplingTensors
,
SequenceGroupToSample
)
SequenceGroupToSample
)
from
vllm.sampling_params
import
SamplingType
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
CompletionSequenceGroupOutput
,
Logprob
,
PromptLogprobs
,
SampleLogprobs
,
SequenceOutput
)
PromptLogprobs
,
SampleLogprobs
,
SequenceOutput
)
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
find_spec
(
"flashinfer"
):
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
find_spec
(
"flashinfer"
):
import
flashinfer.sampling
import
flashinfer.sampling
...
@@ -438,12 +434,9 @@ def _apply_top_k_top_p(
...
@@ -438,12 +434,9 @@ def _apply_top_k_top_p(
logits_sort
.
masked_fill_
(
top_p_mask
,
-
float
(
"inf"
))
logits_sort
.
masked_fill_
(
top_p_mask
,
-
float
(
"inf"
))
# Re-sort the probabilities.
# Re-sort the probabilities.
src
=
torch
.
arange
(
logits_idx
.
shape
[
-
1
],
logits
=
torch
.
empty_like
(
logits_sort
).
scatter_
(
dim
=-
1
,
device
=
logits_idx
.
device
).
expand_as
(
logits_idx
)
index
=
logits_idx
,
logits_idx_inv
=
torch
.
empty_like
(
logits_idx
).
scatter_
(
dim
=-
1
,
src
=
logits_sort
)
index
=
logits_idx
,
src
=
src
)
logits
=
torch
.
gather
(
logits_sort
,
dim
=-
1
,
index
=
logits_idx_inv
)
return
logits
return
logits
...
@@ -740,7 +733,7 @@ def _sample_with_torch(
...
@@ -740,7 +733,7 @@ def _sample_with_torch(
)
->
SampleReturnType
:
)
->
SampleReturnType
:
'''Torch-oriented _sample() implementation.
'''Torch-oriented _sample() implementation.
Single-step scheduling:
Single-step scheduling:
* Perform GPU-side sampling computation
* Perform GPU-side sampling computation
* Immediately Pythonize sampling result
* Immediately Pythonize sampling result
...
@@ -767,17 +760,17 @@ def _sample_with_torch(
...
@@ -767,17 +760,17 @@ def _sample_with_torch(
# Create output tensor for sampled token ids.
# Create output tensor for sampled token ids.
if
include_gpu_probs_tensor
:
if
include_gpu_probs_tensor
:
sampled_token_ids_tensor
=
torch
.
empty
(
logprobs
.
shape
[
0
],
sampled_token_ids_tensor
=
torch
.
full
(
(
logprobs
.
shape
[
0
],
1
),
1
,
VLLM_INVALID_TOKEN_ID
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
logprobs
.
device
)
device
=
logprobs
.
device
)
else
:
else
:
sampled_token_ids_tensor
=
None
sampled_token_ids_tensor
=
None
# Counterintiutively, having two loops here is actually faster.
# Counterintiutively, having two loops here is actually faster.
# The first loop can run without waiting on GPU<->CPU sync.
# The first loop can run without waiting on GPU<->CPU sync.
for
sampling_type
in
SamplingType
:
for
sampling_type
in
SamplingType
:
sample_indices
=
categorized_sample_indices
[
sampling_type
]
[:,
0
]
sample_indices
=
categorized_sample_indices
[
sampling_type
]
num_tokens
=
len
(
sample_indices
)
num_tokens
=
len
(
sample_indices
)
if
num_tokens
==
0
:
if
num_tokens
==
0
:
continue
continue
...
@@ -863,88 +856,6 @@ def _sample_with_torch(
...
@@ -863,88 +856,6 @@ def _sample_with_torch(
)
)
def
_sample_with_triton_kernel
(
probs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_tensors
:
SamplingTensors
,
)
->
SampleResultType
:
categorized_seq_group_ids
:
Dict
[
SamplingType
,
List
[
int
]]
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
sampling_params
=
seq_group
.
sampling_params
sampling_type
=
sampling_params
.
sampling_type
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
sample_results_dict
:
Dict
[
int
,
Tuple
[
List
[
int
],
List
[
int
]]]
=
{}
sample_metadata
:
Dict
[
SamplingType
,
Tuple
[
List
[
int
],
List
[
SequenceGroupToSample
],
torch
.
Tensor
,
torch
.
Tensor
]]
=
{}
max_best_of_in_batch
=
1
# Counterintiutively, having two loops here is actually faster.
# The first loop can run without waiting on GPU<->CPU sync.
for
sampling_type
in
SamplingType
:
sample_indices
=
categorized_sample_indices
[
sampling_type
][:,
0
]
sampled_token_indices
=
categorized_sample_indices
[
sampling_type
][:,
1
]
num_tokens
=
len
(
sample_indices
)
if
num_tokens
==
0
:
continue
seq_group_id
=
categorized_seq_group_ids
[
sampling_type
]
seq_groups
=
[
sampling_metadata
.
seq_groups
[
i
]
for
i
in
seq_group_id
]
sample_metadata
[
sampling_type
]
=
(
seq_group_id
,
seq_groups
,
sample_indices
,
sampled_token_indices
)
if
sampling_type
in
(
SamplingType
.
GREEDY
,
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
for
seq_group
in
seq_groups
:
if
seq_group
.
is_prompt
:
sampling_params
=
seq_group
.
sampling_params
max_best_of_in_batch
=
max
(
max_best_of_in_batch
,
sampling_params
.
best_of
)
elif
sampling_type
==
SamplingType
.
BEAM
:
beam_search_logprobs
=
logprobs
[
sample_indices
]
else
:
raise
ValueError
(
f
"Unsupported sampling type:
{
sampling_type
}
"
)
sampled_tokens
,
_
,
_
=
sample_triton
(
probs
=
probs
,
seeds
=
sampling_tensors
.
sampling_seeds
,
max_best_of
=
max_best_of_in_batch
,
sample_indices
=
sampling_tensors
.
sample_indices
,
logprobs
=
logprobs
,
# don't save logprobs because we have logic for that below
# TODO: use this instead of the CPU-based logic below
save_logprobs
=
False
,
)
# GPU<->CPU sync happens in the loop below.
for
sampling_type
in
SamplingType
:
if
sampling_type
not
in
sample_metadata
:
continue
(
seq_group_id
,
seq_groups
,
sample_indices
,
sampled_token_indices
)
=
sample_metadata
[
sampling_type
]
if
sampling_type
==
SamplingType
.
GREEDY
:
sample_results
=
_greedy_sample
(
seq_groups
,
sampled_tokens
[
sampled_token_indices
][:,
0
])
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
sample_results
=
_random_sample
(
seq_groups
,
sampled_tokens
[
sampled_token_indices
])
elif
sampling_type
==
SamplingType
.
BEAM
:
sample_results
=
_beam_search_sample
(
seq_groups
,
beam_search_logprobs
)
sample_results_dict
.
update
(
zip
(
seq_group_id
,
sample_results
))
sample_results
=
[
sample_results_dict
.
get
(
i
,
([],
[]))
for
i
in
range
(
len
(
sampling_metadata
.
seq_groups
))
]
return
sample_results
def
_sample
(
def
_sample
(
probs
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
...
@@ -974,10 +885,6 @@ def _sample(
...
@@ -974,10 +885,6 @@ def _sample(
modify_greedy_probs
=
modify_greedy_probs
,
modify_greedy_probs
=
modify_greedy_probs
,
)
)
# TODO: Enable once Triton kernel & associated code is faster.
# return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
# sampling_tensors)
def
_get_ranks
(
x
:
torch
.
Tensor
,
indices
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_get_ranks
(
x
:
torch
.
Tensor
,
indices
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
"""
...
...
vllm/model_executor/layers/spec_decode_base_sampler.py
View file @
539aa992
...
@@ -11,20 +11,14 @@ class SpecDecodeBaseSampler(nn.Module):
...
@@ -11,20 +11,14 @@ class SpecDecodeBaseSampler(nn.Module):
step.
step.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
strict_mode
:
bool
=
False
):
disable_bonus_tokens
:
bool
=
True
,
strict_mode
:
bool
=
False
):
"""Base class constructor.
"""Base class constructor.
Args:
Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
during sampling. This catches correctness issues but adds
nontrivial latency.
nontrivial latency.
"""
"""
super
().
__init__
()
super
().
__init__
()
self
.
_disable_bonus_tokens
=
disable_bonus_tokens
self
.
_strict_mode
=
strict_mode
self
.
_strict_mode
=
strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are
# NOTE: A "bonus token" is accepted iff all proposal tokens are
...
@@ -111,13 +105,6 @@ class SpecDecodeBaseSampler(nn.Module):
...
@@ -111,13 +105,6 @@ class SpecDecodeBaseSampler(nn.Module):
output_with_bonus_tokens
[:,
-
1
]
=
torch
.
where
(
output
[:,
-
1
]
!=
-
1
,
output_with_bonus_tokens
[:,
-
1
]
=
torch
.
where
(
output
[:,
-
1
]
!=
-
1
,
bonus_token_ids
,
-
1
)
bonus_token_ids
,
-
1
)
# We disable bonus tokens because it causes corrupt KV cache for
# proposal methods that require KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix.
# https://github.com/vllm-project/vllm/issues/4212
if
self
.
_disable_bonus_tokens
:
output_with_bonus_tokens
[:,
-
1
]
=
-
1
# Fill the recovered token ids.
# Fill the recovered token ids.
output
.
mul_
(
~
after_false_mask
).
add_
(
output
.
mul_
(
~
after_false_mask
).
add_
(
substitute_token_ids
.
mul
(
after_false_mask
))
substitute_token_ids
.
mul
(
after_false_mask
))
...
...
vllm/model_executor/layers/typical_acceptance_sampler.py
View file @
539aa992
...
@@ -16,15 +16,11 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -16,15 +16,11 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
self
,
self
,
posterior_threshold
:
float
,
posterior_threshold
:
float
,
posterior_alpha
:
float
,
posterior_alpha
:
float
,
disable_bonus_tokens
:
bool
=
False
,
strict_mode
:
bool
=
False
,
strict_mode
:
bool
=
False
,
):
):
"""Create a Typical Acceptance Sampler.
"""Create a Typical Acceptance Sampler.
Args:
Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
during sampling. This catches correctness issues but adds
nontrivial latency.
nontrivial latency.
...
@@ -36,8 +32,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -36,8 +32,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
"""
"""
self
.
_posterior_threshold
=
posterior_threshold
self
.
_posterior_threshold
=
posterior_threshold
self
.
_posterior_alpha
=
posterior_alpha
self
.
_posterior_alpha
=
posterior_alpha
super
().
__init__
(
disable_bonus_tokens
=
disable_bonus_tokens
,
super
().
__init__
(
strict_mode
=
strict_mode
)
strict_mode
=
strict_mode
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -54,7 +49,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -54,7 +49,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
one token will be emitted.
one token will be emitted.
In the case where all draft tokens are accepted, the bonus token will be
In the case where all draft tokens are accepted, the bonus token will be
accepted
conditioned on self._disable_bonus_tokens being false
.
accepted.
Args:
Args:
target_probs: The probability distribution over token ids given
target_probs: The probability distribution over token ids given
...
@@ -85,7 +80,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -85,7 +80,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
target_probs
=
target_with_bonus_probs
[:,
:
-
1
]
target_probs
=
target_with_bonus_probs
[:,
:
-
1
]
accepted
=
self
.
_evaluate_accepted_tokens
(
target_probs
,
accepted
=
self
.
_evaluate_accepted_tokens
(
target_probs
,
draft_token_ids
)
draft_token_ids
)
recovered_token_ids
=
self
.
_
replacement
_token_ids
(
target_probs
)
recovered_token_ids
=
self
.
_
get_recovered
_token_ids
(
target_probs
)
output_token_ids
=
self
.
_create_output
(
accepted
,
recovered_token_ids
,
output_token_ids
=
self
.
_create_output
(
accepted
,
recovered_token_ids
,
draft_token_ids
,
draft_token_ids
,
bonus_token_ids
)
bonus_token_ids
)
...
@@ -153,16 +148,10 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -153,16 +148,10 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
accepted_mask
=
candidates_prob
>
threshold
accepted_mask
=
candidates_prob
>
threshold
return
accepted_mask
return
accepted_mask
def
_
replacement
_token_ids
(
self
,
target_probs
):
def
_
get_recovered
_token_ids
(
self
,
target_probs
):
"""
"""
Generate one replacement token ID for each sequence based on target
The recovered token ids will fill the first unmatched token
probabilities. The replacement token is used as the fallback option
by the target token.
if typical acceptance sampling does not accept any draft tokens for
that particular sequence.
This method computes the token IDs to be replaced by selecting the
token with the highest probability for each sequence in the first
position. The rest of the output is filled with -1.
Parameters
Parameters
----------
----------
...
@@ -173,13 +162,9 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -173,13 +162,9 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
Returns
Returns
-------
-------
torch.Tensor
torch.Tensor
A tensor of shape (batch_size, k) with the replacement
A tensor of shape (batch_size, k) with the recovered token
token IDs. Only the first column is set, and the rest of the
ids which are selected from target probs.
columns are filled with -1.
"""
"""
max_indices
=
torch
.
argmax
(
target_probs
[:,
0
,
:],
dim
=
1
)
max_indices
=
torch
.
argmax
(
target_probs
,
dim
=-
1
)
output
=
-
torch
.
ones
((
target_probs
.
shape
[
0
],
target_probs
.
shape
[
1
]),
dtype
=
self
.
token_id_dtype
,
return
max_indices
device
=
target_probs
.
device
)
output
[:,
0
]
=
max_indices
return
output
Prev
1
…
11
12
13
14
15
16
17
18
19
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment