Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
af7f4372
Commit
af7f4372
authored
Sep 03, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.5' into v0.5.5-dtk24.04.1
parents
5e19cdef
09c77926
Changes
465
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1152 additions
and
407 deletions
+1152
-407
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+172
-84
vllm/model_executor/layers/quantization/gguf.py
vllm/model_executor/layers/quantization/gguf.py
+161
-0
vllm/model_executor/layers/quantization/gptq.py
vllm/model_executor/layers/quantization/gptq.py
+10
-7
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+60
-65
vllm/model_executor/layers/quantization/marlin.py
vllm/model_executor/layers/quantization/marlin.py
+36
-32
vllm/model_executor/layers/quantization/tpu_int8.py
vllm/model_executor/layers/quantization/tpu_int8.py
+118
-0
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+12
-11
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+9
-2
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+47
-7
vllm/model_executor/layers/rejection_sampler.py
vllm/model_executor/layers/rejection_sampler.py
+2
-2
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+84
-74
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+89
-28
vllm/model_executor/layers/spec_decode_base_sampler.py
vllm/model_executor/layers/spec_decode_base_sampler.py
+6
-3
vllm/model_executor/layers/vocab_parallel_embedding.py
vllm/model_executor/layers/vocab_parallel_embedding.py
+54
-5
vllm/model_executor/model_loader/__init__.py
vllm/model_executor/model_loader/__init__.py
+1
-4
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+146
-51
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+1
-7
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+83
-3
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+54
-20
vllm/model_executor/models/arctic.py
vllm/model_executor/models/arctic.py
+7
-2
No files found.
Too many changes to show.
To preserve performance only
465 of 465+
files are displayed.
Plain diff
Email patch
vllm/model_executor/layers/quantization/fp8.py
View file @
af7f4372
...
...
@@ -4,6 +4,7 @@ import torch
from
torch.nn
import
Module
from
torch.nn.parameter
import
Parameter
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
,
FusedMoEMethodBase
...
...
@@ -18,11 +19,13 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
all_close_1d
,
apply_fp8_linear
,
convert_to_channelwise
,
create_per_tensor_scale_param
,
cutlass_fp8_supported
,
per_tensor_dequantize
,
requantize_with_max_scale
)
cutlass_fp8_supported
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
,
requantize_with_max_scale
)
from
vllm.model_executor.parameter
import
(
ModelWeightParameter
,
PerTensorScaleParameter
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.utils
import
print_warning_once
from
vllm.utils
import
is_hip
,
print_warning_once
ACTIVATION_SCHEMES
=
[
"static"
,
"dynamic"
]
...
...
@@ -118,7 +121,10 @@ class Fp8LinearMethod(LinearMethodBase):
# kernel for fast weight-only FP8 quantization
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
self
.
use_marlin
=
capability
<
89
self
.
use_marlin
=
capability
<
89
or
envs
.
VLLM_TEST_FORCE_FP8_MARLIN
# Disable marlin for rocm
if
is_hip
():
self
.
use_marlin
=
False
def
create_weights
(
self
,
...
...
@@ -132,6 +138,7 @@ class Fp8LinearMethod(LinearMethodBase):
):
del
input_size
,
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
layer
.
logical_widths
=
output_partition_sizes
...
...
@@ -143,37 +150,54 @@ class Fp8LinearMethod(LinearMethodBase):
weight_dtype
=
(
torch
.
float8_e4m3fn
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
else
params_dtype
)
weight
=
Parameter
(
torch
.
empty
(
output_size_per_partition
,
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
weight_dtype
),
requires_grad
=
False
)
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
{
**
extra_weight_attrs
,
"input_dim"
:
1
,
"output_dim"
:
0
,
})
# If checkpoint is serialized fp8, load them.
# Otherwise, wait until process_weights_after_loading.
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# WEIGHT SCALE
scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
**
extra_weight_attrs
)
scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
layer
.
register_parameter
(
"weight_scale"
,
scale
)
# INPUT ACTIVATION SCALE
if
self
.
quant_config
.
activation_scheme
==
"static"
:
scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
**
extra_weight_attrs
)
scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
layer
.
register_parameter
(
"input_scale"
,
scale
)
else
:
layer
.
register_parameter
(
"input_scale"
,
None
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
layer
.
weight
=
torch
.
nn
.
Parameter
(
layer
.
weight
.
data
,
requires_grad
=
False
)
# If checkpoint not serialized fp8, quantize the weights.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
layer
.
weight
,
scale
=
None
)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if
self
.
use_marlin
:
assert
weight_scale
.
numel
()
==
1
weight_scale
=
convert_to_channelwise
(
weight_scale
.
expand
(
len
(
layer
.
logical_widths
)),
layer
.
logical_widths
)
# Update the layer with the new values.
layer
.
weight
=
Parameter
(
qweight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
...
...
@@ -182,6 +206,11 @@ class Fp8LinearMethod(LinearMethodBase):
# If checkpoint is fp8, handle that there are N scales for N
# shards in a fused module
else
:
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
layer
.
weight_scale
.
data
,
requires_grad
=
False
)
if
self
.
quant_config
.
activation_scheme
==
"static"
:
layer
.
input_scale
=
torch
.
nn
.
Parameter
(
layer
.
input_scale
.
data
,
requires_grad
=
False
)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if
self
.
use_marlin
:
...
...
@@ -193,9 +222,23 @@ class Fp8LinearMethod(LinearMethodBase):
# requantize the logical shards as a single weight.
else
:
# Dequant -> Quant with max scale so we can run per tensor.
weight
=
layer
.
weight
weight_scale
=
layer
.
weight_scale
# If rocm, use float8_e4m3fnuz.
if
is_hip
():
weight
,
weight_scale
,
input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight_scale
=
weight_scale
,
input_scale
=
layer
.
input_scale
)
if
input_scale
is
not
None
:
layer
.
input_scale
=
Parameter
(
input_scale
,
requires_grad
=
False
)
weight_scale
,
weight
=
requantize_with_max_scale
(
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
weight
=
weight
,
weight_scale
=
weight_scale
,
logical_widths
=
layer
.
logical_widths
,
)
...
...
@@ -205,8 +248,6 @@ class Fp8LinearMethod(LinearMethodBase):
if
self
.
quant_config
.
activation_scheme
==
"static"
:
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
else
:
layer
.
input_scale
=
None
if
self
.
use_marlin
:
prepare_fp8_layer_for_marlin
(
layer
)
...
...
@@ -281,23 +322,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
w13_
weight_
scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_scale"
,
w13_scale
)
layer
.
register_parameter
(
"w13_
weight_
scale"
,
w13_
weight_
scale
)
w2_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
w2_
weight_
scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_scale"
,
w2_scale
)
layer
.
register_parameter
(
"w2_
weight_
scale"
,
w2_
weight_
scale
)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
set_weight_attrs
(
w13_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w13_weight_scale
,
{
"is_fp8_scale"
:
True
,
**
extra_weight_attrs
})
set_weight_attrs
(
w2_weight_scale
,
{
"is_fp8_scale"
:
True
,
**
extra_weight_attrs
})
# INPUT_SCALES
if
self
.
quant_config
.
activation_scheme
==
"static"
:
...
...
@@ -306,42 +353,50 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"Found static activation scheme for checkpoint that "
"was not serialized fp8."
)
a
13_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
w
13_
input_
scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"a13_scale"
,
a13_scale
)
set_weight_attrs
(
a13_scale
,
extra_weight_attrs
)
layer
.
register_parameter
(
"w13_input_scale"
,
w13_input_scale
)
set_weight_attrs
(
w13_input_scale
,
{
"is_fp8_scale"
:
True
,
**
extra_weight_attrs
})
a2
_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
w2_input
_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"a2_scale"
,
a2_scale
)
set_weight_attrs
(
a2_scale
,
extra_weight_attrs
)
layer
.
register_parameter
(
"w2_input_scale"
,
w2_input_scale
)
set_weight_attrs
(
w2_input_scale
,
{
"is_fp8_scale"
:
True
,
**
extra_weight_attrs
})
else
:
layer
.
a
13_scale
=
None
layer
.
a2
_scale
=
None
layer
.
w
13_
input_
scale
=
None
layer
.
w2_input
_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# If checkpoint is fp16, quantize in place.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# If rocm, use float8_e4m3fnuz as dtype
fp8_dtype
=
torch
.
float8_e4m3fnuz
\
if
is_hip
()
else
torch
.
float8_e4m3fn
w13_weight
=
torch
.
empty_like
(
layer
.
w13_weight
.
data
,
dtype
=
torch
.
float8_e4m3fn
)
w2_weight
=
torch
.
empty_like
(
layer
.
w2_weight
.
data
,
dtype
=
torch
.
float8_e4m3fn
)
dtype
=
fp8_dtype
)
w2_weight
=
torch
.
empty_like
(
layer
.
w2_weight
.
data
,
dtype
=
fp8_dtype
)
# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
layer
.
w13_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
layer
.
w13_
weight_
scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
layer
.
num_experts
,
dtype
=
torch
.
float32
,
device
=
w13_weight
.
device
),
requires_grad
=
False
)
for
expert
in
range
(
layer
.
num_experts
):
w13_weight
[
expert
,
:,
:],
layer
.
w13_scale
[
w13_weight
[
expert
,
:,
:],
layer
.
w13_
weight_
scale
[
expert
]
=
ops
.
scaled_fp8_quant
(
layer
.
w13_weight
.
data
[
expert
,
:,
:])
w2_weight
[
expert
,
:,
:],
layer
.
w2_scale
[
w2_weight
[
expert
,
:,
:],
layer
.
w2_
weight_
scale
[
expert
]
=
ops
.
scaled_fp8_quant
(
layer
.
w2_weight
.
data
[
expert
,
:,
:])
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
w13_weight
,
...
...
@@ -357,39 +412,66 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if
self
.
quant_config
.
activation_scheme
==
"static"
:
if
layer
.
a13_scale
is
None
or
layer
.
a2_scale
is
None
:
if
(
layer
.
w13_input_scale
is
None
or
layer
.
w2_input_scale
is
None
):
raise
ValueError
(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if
(
not
all_close_1d
(
layer
.
a
13_scale
)
or
not
all_close_1d
(
layer
.
a2
_scale
)):
if
(
not
all_close_1d
(
layer
.
w
13_
input_
scale
)
or
not
all_close_1d
(
layer
.
w2_input
_scale
)):
print_warning_once
(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. "
)
layer
.
a13_scale
=
torch
.
nn
.
Parameter
(
layer
.
a13_scale
.
max
(),
layer
.
w13_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w13_input_scale
.
max
(),
requires_grad
=
False
)
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
)
# If rocm, normalize the weights and scales to e4m3fnuz
if
is_hip
():
# Normalize the weights and scales
w13_weight
,
w13_weight_scale
,
w13_input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
layer
.
w13_weight
,
layer
.
w13_weight_scale
,
layer
.
w13_input_scale
)
w2_weight
,
w2_weight_scale
,
w2_input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
layer
.
w2_weight
,
layer
.
w2_weight_scale
,
layer
.
w2_input_scale
)
# Reset the parameter
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
w13_weight
,
requires_grad
=
False
)
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
w13_weight_scale
,
requires_grad
=
False
)
if
w13_input_scale
is
not
None
:
layer
.
w13_input_scale
=
torch
.
nn
.
Parameter
(
w13_input_scale
,
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
w2_weight
,
requires_grad
=
False
)
layer
.
a2
_scale
=
torch
.
nn
.
Parameter
(
layer
.
a2_scale
.
max
()
,
layer
.
w2_weight
_scale
=
torch
.
nn
.
Parameter
(
w2_weight_scale
,
requires_grad
=
False
)
if
w2_input_scale
is
not
None
:
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
w2_input_scale
,
requires_grad
=
False
)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert
layer
.
w13_scale
is
not
None
assert
layer
.
w13_
weight_
scale
is
not
None
shard_size
=
layer
.
intermediate_size_per_partition
max_w13_scales
=
layer
.
w13_scale
.
max
(
dim
=
1
).
values
max_w13_scales
=
layer
.
w13_
weight_
scale
.
max
(
dim
=
1
).
values
for
expert_id
in
range
(
layer
.
num_experts
):
start
=
0
for
shard_id
in
range
(
2
):
dq_weight
=
per_tensor_dequantize
(
layer
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
layer
.
w13_scale
[
expert_id
][
shard_id
])
layer
.
w13_
weight_
scale
[
expert_id
][
shard_id
])
layer
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
_
=
ops
.
scaled_fp8_quant
(
dq_weight
,
max_w13_scales
[
expert_id
])
start
+=
shard_size
layer
.
w13_scale
=
torch
.
nn
.
Parameter
(
max_w13_scales
,
layer
.
w13_
weight_
scale
=
torch
.
nn
.
Parameter
(
max_w13_scales
,
requires_grad
=
False
)
return
...
...
@@ -398,27 +480,33 @@ class Fp8MoEMethod(FusedMoEMethodBase):
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert
_group
:
Optional
[
int
]
=
None
,
topk
_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
renormalize
:
bool
,
use_grouped_topk
:
bool
,
topk
_group
:
Optional
[
int
]
=
None
,
num_expert
_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_moe
return
fused_moe
(
x
,
from
vllm.model_executor.layers.fused_moe
import
fused_experts
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
)
return
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
router_logits
,
top_k
,
renormalize
=
renormalize
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
use_fp8
=
True
,
w1_scale
=
layer
.
w13_scale
,
w2_scale
=
layer
.
w2_scale
,
a1_scale
=
layer
.
a13_scale
,
a2_scale
=
layer
.
a2_scale
,
use_grouped_topk
=
use_grouped_topk
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
)
use_fp8_w8a8
=
True
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
)
class
Fp8KVCacheMethod
(
BaseKVCacheMethod
):
...
...
vllm/model_executor/layers/quantization/gguf.py
0 → 100644
View file @
af7f4372
from
typing
import
Any
,
Dict
,
List
,
Optional
import
gguf
import
torch
from
torch.nn.parameter
import
Parameter
,
UninitializedParameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.utils
import
set_weight_attrs
class GGUFConfig(QuantizationConfig):
    """Quantization config for GGUF checkpoints.

    The config holds no state of its own; it only reports its name and
    supported dtypes, and picks the quantize method for a given layer.
    """

    def __init__(self) -> None:
        # Nothing to configure.
        pass

    def __repr__(self) -> str:
        return "GGUFConfig()"

    def get_name(self) -> str:
        """Return the identifier of this quantization scheme."""
        return "gguf"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        """Return the activation dtypes accepted by this scheme."""
        supported = [torch.half, torch.bfloat16]
        return supported

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required.

        NOTE(review): presumably encoded as major * 10 + minor -- confirm
        against the base class contract.
        """
        return 60

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the extra config files to look for (none for GGUF)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig":
        """Build a config instance; ``config`` is not consulted."""
        return cls()

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Select the quantize method matching the layer type.

        Returns:
            A ``GGUFLinearMethod`` for ``LinearBase`` layers, a
            ``GGUFEmbeddingMethod`` for ``VocabParallelEmbedding`` layers,
            and ``None`` for anything else.
        """
        if isinstance(layer, LinearBase):
            return GGUFLinearMethod(self)
        if isinstance(layer, VocabParallelEmbedding):
            return GGUFEmbeddingMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none for GGUF)."""
        return []
def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
                  qweight_type: int) -> torch.Tensor:
    """Multiply ``x`` by a GGUF-quantized weight.

    Weight types below 16 go through the fused ``ggml_mul_mat_a8`` kernel
    (mmq for k-quants). Types of 16 and above (IQmatrix) are dequantized
    first and multiplied with a regular matmul.

    Args:
        x: Input activations.
        qweight: Quantized weight tensor.
        qweight_type: GGML quantization type id of ``qweight``.

    Returns:
        The product of ``x`` and the (de)quantized weight.
    """
    if qweight_type < 16:
        # mmq path: the kernel consumes the packed weight directly.
        return ops.ggml_mul_mat_a8(qweight, x, qweight_type,
                                   qweight.shape[0])

    # Dequantize path: recover the unpacked shape from the packed width.
    block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
    n_rows = qweight.shape[0]
    n_cols = qweight.shape[1] // type_size * block_size
    dense_weight = ops.ggml_dequantize(qweight, qweight_type, n_rows, n_cols)
    return x @ dense_weight.T
class GGUFLinearMethod(LinearMethodBase):
    """Linear method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def __init__(self, quant_config: GGUFConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register the ``qweight`` and ``qweight_type`` parameters on ``layer``.

        ``qweight`` is created uninitialized; its logical (unquantized) shape
        is recorded in the ``tensor_shape`` attribute. ``qweight_type`` holds
        one uint8 slot per output partition.
        """
        logical_shape = (sum(output_partition_sizes),
                         input_size_per_partition)

        # Quantized weight: materialized later, so no storage is allocated.
        qweight = UninitializedParameter(requires_grad=False)
        qweight_attrs = {
            "input_dim": 1,
            "output_dim": 0,
            "tensor_shape": logical_shape,
            "is_gguf_weight": True,
            "shard_size": {},
            "shard_id": [],
        }
        set_weight_attrs(qweight, qweight_attrs)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("qweight", qweight)

        # Quantization type id(s) of the weight.
        type_slots = torch.empty(len(output_partition_sizes),
                                 dtype=torch.uint8)
        qweight_type = Parameter(type_slots, requires_grad=False)
        type_attrs = {
            "is_gguf_weight_type": True,
            "weight_type": 0,
            "shard_weight_type": {},
            "ignore_warning": True
        }
        set_weight_attrs(qweight_type, type_attrs)
        set_weight_attrs(qweight_type, extra_weight_attrs)
        layer.register_parameter("qweight_type", qweight_type)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the layer output for ``x``, adding ``bias`` if given.

        When ``layer.qweight`` carries non-empty ``shard_id`` and
        ``shard_size`` attributes, each shard is multiplied with its own
        weight type and the partial results are concatenated along dim 1.
        Otherwise the whole weight is multiplied in one call.
        """
        shard_size = getattr(layer.qweight, "shard_size", None)
        shard_id = getattr(layer.qweight, "shard_id", None)

        if shard_id and shard_size:
            # dequantize shard weights respectively
            ordered_ids = ["q", "k", "v"] if "q" in shard_id else shard_id
            partials = []
            row_start = 0
            for shard_name in ordered_ids:
                dims = shard_size[shard_name]
                row_end = row_start + dims[0]
                shard_weight = layer.qweight[row_start:row_end, :dims[1]]
                shard_weight = shard_weight.contiguous()
                shard_type = layer.qweight_type.shard_weight_type[shard_name]
                partials.append(_fuse_mul_mat(x, shard_weight, shard_type))
                row_start = row_end
            out = torch.cat(partials, axis=1)
        else:
            out = _fuse_mul_mat(x, layer.qweight,
                                layer.qweight_type.weight_type)

        if bias is not None:
            out.add_(bias)
        return out
class GGUFEmbeddingMethod(GGUFLinearMethod):
    """Embedding method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def embedding(self, layer: torch.nn.Module,
                  x: torch.Tensor) -> torch.Tensor:
        """Look up embeddings for the indices in ``x``.

        Weight types below 2 are looked up directly with
        ``torch.embedding`` (presumably the unquantized types -- confirm
        against the GGML type table). For other types the selected rows
        are gathered and then dequantized.

        Returns:
            A tensor of shape ``(*x.shape, hidden_size)`` on the quantized
            path.
        """
        table = layer.qweight
        weight_type = layer.qweight_type.weight_type

        # Unpacked row width, derived from the packed width.
        block_size, type_size = gguf.GGML_QUANT_SIZES[weight_type]
        hidden_size = table.shape[1] // type_size * block_size

        if weight_type < 2:
            return torch.embedding(table, x)

        # Gather only the needed rows before dequantizing them.
        flat_ids = x.flatten()
        packed_rows = torch.index_select(table, dim=0, index=flat_ids)
        dense_rows = ops.ggml_dequantize(packed_rows, weight_type,
                                         hidden_size, flat_ids.shape[0])
        return dense_rows.view(*x.shape, hidden_size)
vllm/model_executor/layers/quantization/gptq.py
View file @
af7f4372
...
...
@@ -204,13 +204,7 @@ class GPTQLinearMethod(LinearMethodBase):
layer
.
exllama_state
=
exllama_state
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
qweight
=
layer
.
qweight
out_shape
=
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
-
1
],
)
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass
if
layer
.
exllama_state
==
ExllamaState
.
UNINITIALIZED
:
...
...
@@ -218,10 +212,19 @@ class GPTQLinearMethod(LinearMethodBase):
layer
.
g_idx
.
data
=
torch
.
argsort
(
layer
.
g_idx
).
to
(
torch
.
int
)
else
:
layer
.
g_idx
.
data
=
torch
.
empty
((
0
,
),
dtype
=
torch
.
int
,
device
=
layer
.
g_idx
.
device
)
layer
.
exllama_state
=
ExllamaState
.
READY
ops
.
gptq_shuffle
(
layer
.
qweight
,
layer
.
g_idx
,
self
.
quant_config
.
weight_bits
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
out_shape
=
x
.
shape
[:
-
1
]
+
(
layer
.
qweight
.
shape
[
-
1
],
)
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
output
=
ops
.
gptq_gemm
(
reshaped_x
,
layer
.
qweight
,
layer
.
qzeros
,
layer
.
scales
,
layer
.
g_idx
,
layer
.
exllama_state
==
ExllamaState
.
READY
,
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
af7f4372
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn
.parameter
import
Parameter
from
torch.nn
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
...
...
@@ -15,6 +14,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedColumnParameter
,
PackedvLLMParameter
,
RowvLLMParameter
)
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
...
...
@@ -136,8 +140,7 @@ class GPTQMarlinConfig(QuantizationConfig):
return
False
return
check_marlin_supported
(
quant_type
=
cls
.
TYPE_MAP
[(
num_bits
,
sym
)],
group_size
=
group_size
,
min_capability
=
cls
.
get_min_capability
())
group_size
=
group_size
)
class
GPTQMarlinLinearMethod
(
LinearMethodBase
):
...
...
@@ -160,9 +163,11 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
)
->
None
:
del
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
is_row_parallel
=
input_size
!=
input_size_per_partition
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
# Normalize group_size
if
self
.
quant_config
.
group_size
!=
-
1
:
...
...
@@ -191,80 +196,66 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
scales_and_zp_size
=
input_size_per_partition
//
group_size
# Quantized weights
qweight
=
Parameter
(
torch
.
empty
(
qweight
=
PackedvLLM
Parameter
(
data
=
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
pack_factor
,
output_size_per_partition
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qweight
,
{
**
extra_weight_attrs
,
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
0
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
},
)
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
0
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
weight_loader
=
weight_loader
)
# Activation order
g_idx
=
Parameter
(
torch
.
empty
(
g_idx
=
RowvLLMParameter
(
data
=
torch
.
empty
(
input_size_per_partition
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
# Ignore warning from fused linear layers such as QKVParallelLinear.
set_weight_attrs
(
g_idx
,
{
**
extra_weight_attrs
,
"input_dim"
:
0
,
"ignore_warning"
:
True
},
)
# Scales
scales
=
Parameter
(
torch
.
empty
(
scales_and_zp_size
,
output_size_per_partition
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
scales
,
{
**
extra_weight_attrs
,
"input_dim"
:
scales_and_zp_input_dim
,
"output_dim"
:
1
,
},
)
input_dim
=
0
,
weight_loader
=
weight_loader
)
# Quantized zero-points
qzeros
=
Parameter
(
qzeros_args
=
{
"data"
:
torch
.
empty
(
scales_and_zp_size
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
device
=
"meta"
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qzeros
,
{
**
extra_weight_attrs
,
"input_dim"
:
scales_and_zp_input_dim
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
},
)
"weight_loader"
:
weight_loader
}
weight_scale_args
=
{
"data"
:
torch
.
empty
(
scales_and_zp_size
,
output_size_per_partition
,
dtype
=
params_dtype
,
),
"weight_loader"
:
weight_loader
}
if
scales_and_zp_input_dim
is
None
:
scales
=
ChannelQuantScaleParameter
(
output_dim
=
1
,
**
weight_scale_args
)
qzeros
=
PackedColumnParameter
(
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
**
qzeros_args
)
else
:
scales
=
GroupQuantScaleParameter
(
output_dim
=
1
,
input_dim
=
0
,
**
weight_scale_args
)
qzeros
=
PackedvLLMParameter
(
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
**
qzeros_args
)
layer
.
register_parameter
(
"qweight"
,
qweight
)
layer
.
register_parameter
(
"g_idx"
,
g_idx
)
...
...
@@ -282,6 +273,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
device
=
layer
.
qweight
.
device
# required by torch.compile
layer
.
qweight
=
Parameter
(
layer
.
qweight
.
data
,
requires_grad
=
False
)
layer
.
scales
=
Parameter
(
layer
.
scales
.
data
,
requires_grad
=
False
)
# Allocate marlin workspace
layer
.
workspace
=
marlin_make_workspace
(
layer
.
output_size_per_partition
,
device
)
...
...
vllm/model_executor/layers/quantization/marlin.py
View file @
af7f4372
...
...
@@ -9,7 +9,10 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedvLLMParameter
)
logger
=
init_logger
(
__name__
)
...
...
@@ -132,6 +135,7 @@ class MarlinLinearMethod(LinearMethodBase):
**
extra_weight_attrs
,
):
del
output_size
# Unused.
weight_loader
=
extra_weight_attrs
[
"weight_loader"
]
if
params_dtype
!=
torch
.
float16
:
raise
ValueError
(
...
...
@@ -170,64 +174,64 @@ class MarlinLinearMethod(LinearMethodBase):
"Each permutation group must reside on the same gpu"
)
# Quantized 4Bit weights packed into Int32.
qweight
=
Parameter
(
torch
.
empty
(
qweight
=
PackedvLLM
Parameter
(
data
=
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
tile_size
,
output_size_per_partition
*
self
.
quant_config
.
tile_size
//
self
.
quant_config
.
pack_factor
,
device
=
"cuda"
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qweight
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
"marlin_tile_size"
:
self
.
quant_config
.
tile_size
,
},
)
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
marlin_tile_size
=
self
.
quant_config
.
tile_size
,
weight_loader
=
weight_loader
)
# Determine if channelwise or not
input_groups
=
(
1
if
self
.
quant_config
.
group_size
==
-
1
else
input_size_per_partition
//
self
.
quant_config
.
group_size
)
scales
=
Parameter
(
weight_scale_args
=
{
"data"
:
torch
.
empty
(
input_groups
,
output_size_per_partition
,
device
=
"cuda"
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
scales
,
{
"input_dim"
:
None
if
input_groups
==
1
else
0
,
"output_dim"
:
1
,
},
)
"weight_loader"
:
weight_loader
}
if
input_groups
==
1
:
scales
=
ChannelQuantScaleParameter
(
output_dim
=
1
,
**
weight_scale_args
)
else
:
scales
=
GroupQuantScaleParameter
(
output_dim
=
1
,
input_dim
=
0
,
**
weight_scale_args
)
# Allocate workspace (Used for internal locking mechanism)
max_workspace_size
=
(
output_size_per_partition
//
self
.
quant_config
.
min_n_threads
)
*
self
.
quant_config
.
max_parallel
workspace
=
Parameter
(
torch
.
zeros
(
max_workspace_size
,
workspace
=
BasevLLMParameter
(
data
=
torch
.
zeros
(
max_workspace_size
,
device
=
"cuda"
,
dtype
=
torch
.
int
),
requires_grad
=
False
)
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"B"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"s"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
layer
.
register_parameter
(
"workspace"
,
workspace
)
set_weight_attrs
(
workspace
,
extra_weight_attrs
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Re-wrap the loaded Marlin tensors as plain ``Parameter`` objects.

    Each of ``layer.B``, ``layer.s`` and ``layer.workspace`` is replaced by
    a fresh ``Parameter`` built from the existing ``.data`` with
    ``requires_grad=False``.
    """
    # required by torch.compile
    for attr_name in ("B", "s", "workspace"):
        loaded = getattr(layer, attr_name)
        setattr(layer, attr_name, Parameter(loaded.data,
                                            requires_grad=False))
def
apply
(
self
,
...
...
vllm/model_executor/layers/quantization/tpu_int8.py
0 → 100644
View file @
af7f4372
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
torch
from
torch.nn
import
Module
from
torch.nn.parameter
import
Parameter
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
ACTIVATION_SCHEMES
=
[
"none"
]
class Int8TpuConfig(QuantizationConfig):
    """Quantization config describing int8 quantization on the TPU backend.

    Only the activation schemes listed in ``ACTIVATION_SCHEMES`` are
    accepted; anything else is rejected at construction time.
    """

    def __init__(
        self,
        activation_scheme: str = "none",
    ) -> None:
        supported = activation_scheme in ACTIVATION_SCHEMES
        if not supported:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme

    def get_name(self) -> str:
        """Return the identifier of this quantization method."""
        return "tpu_int8"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        """Return the activation dtypes this config accepts."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Always raise: this query is not meant to be used on TPU."""
        raise NotImplementedError(
            "This function should not be called with TPU Backend")

    @staticmethod
    def get_config_filenames() -> List[str]:
        """Return the config file names to search for (none)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Int8TpuConfig":
        """Build the config from a dict holding ``activation_scheme``."""
        return cls(activation_scheme=cls.get_from_keys(
            config, ["activation_scheme"]))

    def get_quant_method(self, layer: Module,
                         prefix: str) -> Optional["TPUInt8LinearMethod"]:
        """Return the linear method for ``LinearBase`` layers, else None."""
        return (TPUInt8LinearMethod(self)
                if isinstance(layer, LinearBase) else None)

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none)."""
        return []
class TPUInt8LinearMethod(LinearMethodBase):
    """Int8 linear method for the TPU backend.

    The weight is created in ``params_dtype`` and, once loading has
    finished, replaced by an int8 tensor plus one symmetric scale per
    output row.
    """

    def __init__(self, quant_config: Int8TpuConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: Module, input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register the (not yet quantized) ``weight`` parameter on ``layer``.

        The parameter has shape
        ``[sum(output_partition_sizes), input_size_per_partition]`` and
        dtype ``params_dtype``. ``extra_weight_attrs`` plus the
        ``input_dim``/``output_dim`` markers are attached to it.
        """
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {
            **extra_weight_attrs,
            "input_dim": 1,
            "output_dim": 0,
        })

    def _quantize_weight(
            self, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Symmetrically quantize ``weight`` to int8 along its last dim.

        Returns:
            ``(qweight, qscale)`` where ``qweight`` is an int8 tensor with
            the shape of ``weight`` and ``qscale`` holds one scale per row,
            cast back to the dtype of ``weight``. Both live on the CPU.
        """
        weight_dtype = weight.dtype
        # The quantization arithmetic is done on the CPU in float32.
        weight = weight.cpu().to(torch.float32)
        n_bit = 8
        eps = 1e-5
        max_int = 2**(n_bit - 1) - 1
        min_int = -(2**(n_bit - 1))
        max_val = weight.abs().amax(dim=-1, keepdim=True)
        # Guard against all-zero rows producing a zero scale.
        max_val = max_val.clamp(min=eps)
        qscale = max_val / max_int
        qweight = torch.clamp(torch.round(weight * (1.0 / qscale)), min_int,
                              max_int).to(torch.int8)
        # Drop only the reduced (last) dim. A bare squeeze() would also
        # collapse the row dim when there is a single output row, turning
        # the per-row scale vector into a 0-dim scalar.
        qscale = qscale.squeeze(-1).to(weight_dtype)
        return qweight, qscale

    def process_weights_after_loading(self, layer: Module) -> None:
        """Replace ``layer.weight`` by its int8 form and add ``layer.scale``."""
        # Remember the original device; quantization itself runs on the CPU.
        device = layer.weight.device
        qweight, qscale = self._quantize_weight(layer.weight)
        qweight = qweight.to(device)
        qscale = qscale.to(device)
        layer.weight = Parameter(qweight, requires_grad=False)
        layer.scale = Parameter(qscale, requires_grad=False)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the quantized matmul of ``x`` with the layer weight.

        Raises:
            ImportError: if ``torch_xla`` is not installed.
        """
        try:
            # Imported for its side effect; the op is then called through
            # torch.ops.xla below.
            import torch_xla.experimental.xla_quantized_matmul  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "Please install torch_xla by following the instructions at "
                "https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html "  # noqa: E501
                "to run vLLM on TPU.") from err
        weight = layer.weight
        scale = layer.scale
        out = torch.ops.xla.quantized_matmul(x, weight, scale)
        if bias is not None:
            out = out + bias
        return out
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
af7f4372
...
...
@@ -26,12 +26,13 @@ USE_FP32_REDUCE_DEFAULT = True
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so its closer to the actual impl
def
query_marlin_supported_quant_types
(
has_zp
:
bool
,
min_capability
:
Optional
[
int
]
=
None
):
if
min_capability
is
None
:
device_capability
:
Optional
[
int
]
=
None
):
if
device_capability
is
None
:
major
,
minor
=
current_platform
.
get_device_capability
()
min
_capability
=
major
*
10
+
minor
device
_capability
=
major
*
10
+
minor
if
min
_capability
<
80
:
if
device
_capability
<
80
:
return
[]
if
has_zp
:
...
...
@@ -48,20 +49,20 @@ def _check_marlin_supported(
quant_type
:
ScalarType
,
group_size
:
Optional
[
int
],
has_zp
:
bool
,
min
_capability
:
Optional
[
int
]
=
None
)
->
Tuple
[
bool
,
Optional
[
str
]]:
device
_capability
:
Optional
[
int
]
=
None
)
->
Tuple
[
bool
,
Optional
[
str
]]:
if
min
_capability
is
None
:
if
device
_capability
is
None
:
major
,
minor
=
current_platform
.
get_device_capability
()
min
_capability
=
major
*
10
+
minor
device
_capability
=
major
*
10
+
minor
supported_types
=
query_marlin_supported_quant_types
(
has_zp
,
min
_capability
)
has_zp
,
device
_capability
)
if
quant_type
not
in
supported_types
:
return
(
False
,
f
"Marlin does not support weight_bits =
{
quant_type
}
. "
f
"Only types =
{
supported_types
}
"
f
"are supported (for group_size =
{
group_size
}
, "
f
"
min
_capability =
{
min
_capability
}
, zp =
{
has_zp
}
)."
)
f
"
device
_capability =
{
device
_capability
}
, zp =
{
has_zp
}
)."
)
if
(
group_size
is
None
or
group_size
not
in
MARLIN_SUPPORTED_GROUP_SIZES
):
return
(
False
,
f
"Marlin does not support group_size =
{
group_size
}
. "
f
"Only group_sizes =
{
MARLIN_SUPPORTED_GROUP_SIZES
}
"
...
...
@@ -73,9 +74,9 @@ def _check_marlin_supported(
def
check_marlin_supported
(
quant_type
:
ScalarType
,
group_size
:
int
,
has_zp
:
bool
=
False
,
min
_capability
:
Optional
[
int
]
=
None
)
->
bool
:
device
_capability
:
Optional
[
int
]
=
None
)
->
bool
:
cond
,
_
=
_check_marlin_supported
(
quant_type
,
group_size
,
has_zp
,
min
_capability
)
device
_capability
)
return
cond
...
...
vllm/model_executor/layers/quantization/utils/quant_utils.py
View file @
af7f4372
...
...
@@ -81,7 +81,8 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
def
quantize_weights
(
w
:
torch
.
Tensor
,
quant_type
:
ScalarType
,
group_size
:
int
,
zero_points
:
bool
=
False
):
zero_points
:
bool
=
False
,
ref_zero_points_after_scales
:
bool
=
False
):
assert
quant_type
.
is_integer
(),
\
"Floating point quantization may work but has not been tested"
...
...
@@ -126,6 +127,12 @@ def quantize_weights(w: torch.Tensor,
w_q
=
torch
.
clamp
(
w_q
,
min_q_val
,
max_q_val
)
# Compute ref (dequantized)
# For some kernels (namely Machete) the zero-points are applied after the
# scales are applied, for this case computing the reference in similar way
# allows us to use tighter error tolerances in our unit tests.
if
ref_zero_points_after_scales
and
zero_points
:
w_ref
=
w_q
.
to
(
orig_type
)
*
w_s
-
maybe_w_zp
.
to
(
orig_type
)
*
w_s
else
:
w_ref
=
(
w_q
-
(
maybe_w_zp
if
zero_points
else
0
)).
to
(
orig_type
)
*
w_s
if
quant_type
.
has_bias
():
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
af7f4372
...
...
@@ -6,9 +6,19 @@ from torch.nn import Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_hip
# scaled_mm in pytorch on rocm has a bug that requires always
# providing scaling factor for result. This value is created
# as global value to avoid multiple tensor allocations, and
# can be removed once pytorch fixes the bug.
TORCH_SCALED_MM_SCALE_RESULT
=
torch
.
ones
(
1
).
cuda
()
if
is_hip
()
else
None
def
cutlass_fp8_supported
()
->
bool
:
# cutlass is not supported on Rocm
if
is_hip
():
return
False
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
...
...
@@ -147,13 +157,19 @@ def apply_fp8_linear(
if
per_tensor_weights
and
per_tensor_activations
:
# Fused GEMM_DQ
output
,
_
=
torch
.
_scaled_mm
(
qinput
,
output
=
torch
.
_scaled_mm
(
qinput
,
weight
,
out_dtype
=
input
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
scale_result
=
TORCH_SCALED_MM_SCALE_RESULT
,
bias
=
bias
)
# Since in torch 2.5, scaled_mm only returns single value
# This should be removed when vllm-nvidia also moves to 2.5
if
is_hip
():
return
torch
.
narrow
(
output
,
0
,
0
,
input
.
shape
[
0
])
return
torch
.
narrow
(
output
[
0
],
0
,
0
,
input
.
shape
[
0
])
else
:
# Fallback for channelwise case, where we use unfused DQ
...
...
@@ -207,3 +223,27 @@ def apply_int8_linear(
scale_b
=
weight_scale
,
out_dtype
=
input
.
dtype
,
bias
=
bias
)
def
normalize_e4m3fn_to_e4m3fnuz
(
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
input_scale
:
Optional
[
torch
.
Tensor
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
assert
weight
.
dtype
==
torch
.
float8_e4m3fn
# The bits pattern 10000000(-128) represents zero in e4m3fn
# but NaN in e4m3fnuz. So here we set it to 0.
# https://onnx.ai/onnx/technical/float8.html
weight_as_int8
=
weight
.
view
(
torch
.
int8
)
ROCM_FP8_NAN_AS_INT
=
-
128
weight_as_int8
[
weight_as_int8
==
ROCM_FP8_NAN_AS_INT
]
=
0
weight
=
weight_as_int8
.
view
(
torch
.
float8_e4m3fnuz
)
# For the same bits representation, e4m3fnuz value is half of
# the e4m3fn value, so we should double the scaling factor to
# get the same dequantized value.
# https://onnx.ai/onnx/technical/float8.html
weight_scale
=
weight_scale
*
2.0
if
input_scale
is
not
None
:
input_scale
=
input_scale
*
2.0
return
weight
,
weight_scale
,
input_scale
vllm/model_executor/layers/rejection_sampler.py
View file @
af7f4372
...
...
@@ -78,8 +78,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
# Only perform shape/dtype/device checking in strict mode, as it adds
# overhead.
if
self
.
_strict_mode
:
self
.
_raise_if_incorrect_input
(
target_probs
,
bonus
_token_ids
,
draft_prob
s
,
draft_
token_id
s
)
self
.
_raise_if_incorrect_input
(
target_probs
,
draft
_token_ids
,
bonus_token_id
s
,
draft_
prob
s
)
accepted
,
recovered_token_ids
=
(
self
.
_batch_modified_rejection_sampling
(
...
...
vllm/model_executor/layers/rotary_embedding.py
View file @
af7f4372
...
...
@@ -28,7 +28,7 @@ import torch
import
torch.nn
as
nn
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.
util
s
import
is_tpu
from
vllm.
platform
s
import
current_platform
def
_rotate_neox
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -46,15 +46,23 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
def
_apply_rotary_emb
(
x
:
torch
.
Tensor
,
freqs_cis
:
torch
.
Tensor
,
cos
:
torch
.
Tensor
,
sin
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
x_
=
torch
.
view_as_complex
(
torch
.
stack
(
torch
.
chunk
(
x
.
transpose
(
1
,
2
).
float
(),
2
,
dim
=-
1
),
dim
=-
1
))
x_out
=
torch
.
view_as_real
(
x_
*
freqs_cis
).
type_as
(
x
)
x_out
=
torch
.
cat
(
torch
.
chunk
(
x_out
,
2
,
dim
=-
1
),
dim
=-
2
)
x_out
=
x_out
.
reshape
(
x_out
.
shape
[
0
],
x_out
.
shape
[
1
],
x_out
.
shape
[
2
],
-
1
).
transpose
(
1
,
2
)
return
x_out
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
"""
orig_dtype
=
x
.
dtype
x
=
x
.
float
()
x1
,
x2
=
torch
.
chunk
(
x
,
2
,
dim
=-
1
)
cos
=
cos
.
unsqueeze
(
-
2
)
sin
=
sin
.
unsqueeze
(
-
2
)
o1
=
x1
*
cos
-
x2
*
sin
o2
=
x2
*
cos
+
x1
*
sin
return
torch
.
cat
((
o1
,
o2
),
dim
=-
1
).
to
(
orig_dtype
)
class
RotaryEmbedding
(
CustomOp
):
...
...
@@ -78,22 +86,13 @@ class RotaryEmbedding(CustomOp):
self
.
dtype
=
dtype
cache
=
self
.
_compute_cos_sin_cache
()
self
.
use_native2
=
is_tpu
()
and
is_neox_style
if
not
self
.
use_native2
:
cache
=
cache
.
to
(
dtype
)
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
else
:
cos
,
sin
=
cache
.
chunk
(
2
,
dim
=-
1
)
freqs_cis
=
cos
+
1j
*
sin
self
.
register_buffer
(
"freqs_cis"
,
freqs_cis
,
persistent
=
False
)
self
.
use_native2
=
current_platform
.
is_tpu
()
and
is_neox_style
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
"""Compute the inverse frequency."""
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
# However, we use `torch.arange(..., dtype=torch.float)` instead to
# avoid numerical issues with large base values (e.g., 10000000).
# This may cause a slight numerical difference between the HF
# implementation and ours.
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
...
...
@@ -173,28 +172,25 @@ class RotaryEmbedding(CustomOp):
This method might perform better than `forward_native()` when compiled.
"""
if
positions
.
dim
()
==
1
:
batch_size
=
1
seq_len
=
positions
.
shape
[
0
]
else
:
batch_size
,
seq_len
=
positions
.
shape
if
offsets
is
not
None
:
positions
=
positions
+
offsets
freqs_cis
=
self
.
freqs_cis
.
index_select
(
0
,
positions
.
flatten
())
freqs_cis
=
freqs_cis
.
view
(
batch_size
,
1
,
seq_len
,
-
1
)
positions
=
positions
.
flatten
()
num_tokens
=
positions
.
shape
[
0
]
cos_sin
=
self
.
cos_sin_cache
.
index_select
(
0
,
positions
)
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
query_shape
=
query
.
shape
query
=
query
.
view
(
batch_size
,
seq_len
,
-
1
,
self
.
head_size
)
query
=
query
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_pass
=
query
[...,
self
.
rotary_dim
:]
query_rot
=
_apply_rotary_emb
(
query_rot
,
freqs_cis
)
query_rot
=
_apply_rotary_emb
(
query_rot
,
cos
,
sin
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
key_shape
=
key
.
shape
key
=
key
.
view
(
batch_size
,
seq_len
,
-
1
,
self
.
head_size
)
key
=
key
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_rot
=
_apply_rotary_emb
(
key_rot
,
freqs_cis
)
key_rot
=
_apply_rotary_emb
(
key_rot
,
cos
,
sin
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
return
query
,
key
...
...
@@ -723,44 +719,50 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
return
query
,
key
class
Gem
maRotaryEmbedding
(
RotaryEmbedding
):
class
Lla
ma
3
RotaryEmbedding
(
RotaryEmbedding
):
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
inv_freq
=
1.0
/
(
base
**
(
torch
.
arange
(
0
,
self
.
rotary_dim
,
2
,
dtype
=
torch
.
int64
).
float
()
/
self
.
rotary_dim
))
return
inv_freq
class
ExtendedRotaryEmbedding
(
RotaryEmbedding
):
def
__init__
(
self
,
head_size
:
int
,
rotary_dim
:
int
,
max_position_embeddings
:
int
,
base
:
int
,
is_neox_style
:
bool
,
dtype
:
torch
.
dtype
,
scaling_factor
:
float
,
low_freq_factor
:
float
,
high_freq_factor
:
float
,
orig_max_position
:
int
,
)
->
None
:
self
.
scaling_factor
=
scaling_factor
self
.
low_freq_factor
=
low_freq_factor
self
.
high_freq_factor
=
high_freq_factor
self
.
orig_max_position
=
orig_max_position
super
().
__init__
(
head_size
,
rotary_dim
,
max_position_embeddings
,
base
,
is_neox_style
,
dtype
)
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
inv_freqs
=
super
().
_compute_inv_freq
(
base
)
return
self
.
apply_scaling
(
inv_freqs
)
def
apply_scaling
(
self
,
freqs
:
torch
.
Tensor
):
scale_factor
=
8
low_freq_factor
=
1
high_freq_factor
=
4
old_context_len
=
8192
low_freq_wavelen
=
old_context_len
/
low_freq_factor
high_freq_wavelen
=
old_context_len
/
high_freq_factor
new_freqs
=
[]
for
freq
in
freqs
:
wavelen
=
2
*
math
.
pi
/
freq
if
wavelen
<
high_freq_wavelen
:
new_freqs
.
append
(
freq
)
elif
wavelen
>
low_freq_wavelen
:
new_freqs
.
append
(
freq
/
scale_factor
)
low_freq_wavelen
=
self
.
orig_max_position
/
self
.
low_freq_factor
high_freq_wavelen
=
self
.
orig_max_position
/
self
.
high_freq_factor
wave_len
=
2
*
math
.
pi
/
inv_freqs
if
self
.
low_freq_factor
!=
self
.
high_freq_factor
:
smooth
=
(
self
.
orig_max_position
/
wave_len
-
self
.
low_freq_factor
)
/
(
self
.
high_freq_factor
-
self
.
low_freq_factor
)
else
:
assert
low_freq_wavelen
!=
high_freq_wavelen
smooth
=
(
old_context_len
/
wavelen
-
low_freq_factor
)
/
(
high_freq_factor
-
low_freq_factor
)
new_freqs
.
append
((
1
-
smooth
)
*
freq
/
scale_factor
+
smooth
*
freq
)
return
torch
.
tensor
(
new_freqs
,
dtype
=
freqs
.
dtype
,
device
=
freqs
.
device
)
smooth
=
0
new_freqs
=
torch
.
where
(
wave_len
<
high_freq_wavelen
,
inv_freqs
,
torch
.
where
(
wave_len
>
low_freq_wavelen
,
inv_freqs
/
self
.
scaling_factor
,
(
1
-
smooth
)
*
inv_freqs
/
self
.
scaling_factor
+
smooth
*
inv_freqs
,
),
)
return
new_freqs
_ROPE_DICT
:
Dict
[
Tuple
,
RotaryEmbedding
]
=
{}
...
...
@@ -774,7 +776,7 @@ def get_rope(
is_neox_style
:
bool
=
True
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
rotary_percent
:
float
=
1.0
,
partial_rotary_factor
:
float
=
1.0
,
)
->
RotaryEmbedding
:
if
dtype
is
None
:
dtype
=
torch
.
get_default_dtype
()
...
...
@@ -787,12 +789,13 @@ def get_rope(
rope_scaling_args
=
tuple
(
rope_scaling_tuple
.
items
())
else
:
rope_scaling_args
=
None
if
rotary_percent
<
1.0
:
rotary_dim
=
int
(
rotary_dim
*
rotary_percent
)
if
partial_rotary_factor
<
1.0
:
rotary_dim
=
int
(
rotary_dim
*
partial_rotary_factor
)
key
=
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
rope_scaling_args
,
dtype
)
if
key
in
_ROPE_DICT
:
return
_ROPE_DICT
[
key
]
if
rope_scaling
is
None
:
rotary_emb
=
RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
dtype
)
...
...
@@ -801,12 +804,19 @@ def get_rope(
"type"
]
if
"type"
in
rope_scaling
else
rope_scaling
[
"rope_type"
]
# The correct one should be "longrope" but keep "su" here
# for backward compatible
if
scaling_type
not
in
{
"su"
,
"longrope"
,
"llama3"
}:
if
scaling_type
not
in
{
"su"
,
"longrope"
}:
scaling_factor
=
rope_scaling
[
"factor"
]
if
scaling_type
==
"llama3"
:
rotary_emb
=
ExtendedRotaryEmbedding
(
head_size
,
rotary_dim
,
low_freq_factor
=
rope_scaling
[
"low_freq_factor"
]
high_freq_factor
=
rope_scaling
[
"high_freq_factor"
]
original_max_position
=
rope_scaling
[
"original_max_position_embeddings"
]
rotary_emb
=
Llama3RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
dtype
)
is_neox_style
,
dtype
,
scaling_factor
,
low_freq_factor
,
high_freq_factor
,
original_max_position
)
elif
scaling_type
==
"linear"
:
rotary_emb
=
LinearScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
...
...
vllm/model_executor/layers/sampler.py
View file @
af7f4372
"""A layer that samples the next tokens from the model's outputs."""
import
itertools
import
warnings
from
importlib.util
import
find_spec
from
math
import
inf
from
typing
import
Dict
,
List
,
Optional
,
Tuple
...
...
@@ -11,6 +13,7 @@ from vllm.triton_utils import HAS_TRITON
if
HAS_TRITON
:
from
vllm.model_executor.layers.ops.sample
import
sample
as
sample_triton
import
vllm.envs
as
envs
from
vllm.model_executor.sampling_metadata
import
(
SamplingMetadata
,
SamplingTensors
,
SequenceGroupToSample
)
...
...
@@ -19,6 +22,16 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
PromptLogprobs
,
SampleLogprobs
,
SamplerOutput
,
SequenceOutput
)
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
find_spec
(
"flashinfer"
):
import
flashinfer.sampling
# yapf: disable
from
flashinfer.sampling
import
(
top_k_top_p_sampling_from_probs
as
flashinfer_top_k_top_p_sampling
)
# yapf: enable
else
:
flashinfer_top_k_top_p_sampling
=
None
# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType
=
List
[
Tuple
[
List
[
int
],
List
[
int
]]]
...
...
@@ -51,6 +64,7 @@ class Sampler(nn.Module):
# containing the sampled token ids and probabilities. This is used by
# speculative decoding.
self
.
include_gpu_probs_tensor
=
False
self
.
should_modify_greedy_probs_inplace
=
False
def
_init_sampling_tensors
(
self
,
...
...
@@ -117,11 +131,12 @@ class Sampler(nn.Module):
sampling_tensors
.
frequency_penalties
,
sampling_tensors
.
repetition_penalties
)
#
A
pply temperature scaling.
#
Use float32 to a
pply temperature scaling.
# Use in-place division to avoid creating a new tensor.
logits
=
logits
.
to
(
torch
.
float
)
logits
.
div_
(
sampling_tensors
.
temperatures
.
unsqueeze
(
dim
=
1
))
if
do_top_p_top_k
:
if
do_top_p_top_k
and
flashinfer_top_k_top_p_sampling
is
None
:
logits
=
_apply_top_k_top_p
(
logits
,
sampling_tensors
.
top_ps
,
sampling_tensors
.
top_ks
)
...
...
@@ -177,8 +192,7 @@ class Sampler(nn.Module):
This is used by speculative decoding, which requires that the sampling
method be encoded into the probability distribution.
"""
# Modify greedy probs if include_gpu_probs_tensor is set.
return
self
.
include_gpu_probs_tensor
return
self
.
should_modify_greedy_probs_inplace
def
_get_bin_counts_and_mask
(
...
...
@@ -475,14 +489,7 @@ def _multinomial(
seq_groups
:
Optional
[
List
[
SequenceGroupToSample
]]
=
None
,
)
->
torch
.
Tensor
:
if
num_samples
>
1
:
# This is equivalent to torch.repeat_interleaved (which also
# forces a GPU<->CPU sync).
# This allows us to do sampling with replacement by creating
# num_samples copies of each row in the tensor, and then
# batch sampling the resulting tensor.
probs
=
probs
[:,
None
,
:].
expand
(
probs
.
shape
[
0
],
num_samples
,
probs
.
shape
[
1
]).
contiguous
().
view
(
-
1
,
probs
.
shape
[
1
])
probs
=
probs
.
repeat_interleave
(
num_samples
,
dim
=
0
)
q
=
torch
.
empty_like
(
probs
)
if
seq_groups
is
None
:
q
.
exponential_
()
...
...
@@ -490,17 +497,57 @@ def _multinomial(
sample_idx
=
0
for
seq_group
in
seq_groups
:
seq_ids
=
seq_group
.
seq_ids
next_sample_idx
=
sample_idx
+
len
(
seq_ids
)
*
num_samples
q
[
sample_idx
:
next_sample_idx
].
exponential_
(
generator
=
seq_group
.
generator
)
sample_idx
=
next_sample_idx
stride
=
len
(
seq_ids
)
*
num_samples
assert
seq_group
.
generator
is
not
None
q
[
sample_idx
:
sample_idx
+
stride
].
exponential_
(
generator
=
seq_group
.
generator
)
sample_idx
+=
stride
return
probs
.
div_
(
q
).
argmax
(
dim
=
1
).
view
(
-
1
,
num_samples
)
def _top_k_top_p_multinomial_with_flashinfer(
        probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
        num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
    """Sample ``num_samples`` token ids per row of ``probs`` via FlashInfer.

    When ``num_samples > 1`` the rows of ``probs``, ``top_ks`` and
    ``top_ps`` are repeated so every sample gets its own row. Uniform
    random numbers are drawn either globally (``seq_groups`` is None) or
    per sequence group from that group's generator. If FlashInfer reports
    a failure for any row, the probabilities are renormalized with top-k
    and top-p and sampled again.

    Returns:
        The sampled token ids viewed as ``(-1, num_samples)``.
    """
    num_rounds = 32
    if num_samples > 1:
        probs = probs.repeat_interleave(num_samples, dim=0)
        top_ks = top_ks.repeat_interleave(num_samples)
        top_ps = top_ps.repeat_interleave(num_samples)

    num_rows = probs.shape[0]
    uniform_samples = torch.empty((num_rounds, num_rows),
                                  device=probs.device)
    if seq_groups is not None:
        # Fill each group's column slice from that group's own generator.
        col_begin = 0
        for group in seq_groups:
            width = len(group.seq_ids) * num_samples
            assert group.generator is not None
            col_end = col_begin + width
            uniform_samples[:, col_begin:col_end].uniform_(
                generator=group.generator)
            col_begin = col_end
    else:
        uniform_samples.uniform_()

    batch_next_token_ids, success = flashinfer_top_k_top_p_sampling(
        probs,
        uniform_samples,
        top_ks,
        top_ps,
    )
    if success.all():
        return batch_next_token_ids.view(-1, num_samples)

    warnings.warn("FlashInfer rejection sampling failed, fallback.",
                  stacklevel=1)
    probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks)
    probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps)
    batch_next_token_ids = flashinfer.sampling.sampling_from_probs(
        probs, uniform_samples[0])
    return batch_next_token_ids.view(-1, num_samples)
def
_sample_with_torch
(
probs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_tensors
:
SamplingTensors
,
include_gpu_probs_tensor
:
bool
,
modify_greedy_probs
:
bool
,
)
->
Tuple
[
SampleResultType
,
Optional
[
torch
.
Tensor
]]:
...
...
@@ -563,18 +610,28 @@ def _sample_with_torch(
sampling_params
=
seq_group
.
sampling_params
max_best_of_in_batch
=
max
(
max_best_of_in_batch
,
sampling_params
.
best_of
)
seeded_args
=
{}
if
sampling_type
==
SamplingType
.
RANDOM
else
{
"seq_groups"
:
seq_groups
,
}
seq_groups_arg
=
(
None
if
sampling_type
==
SamplingType
.
RANDOM
else
seq_groups
)
if
flashinfer_top_k_top_p_sampling
is
not
None
:
multinomial_samples
[
sampling_type
]
=
_top_k_top_p_multinomial_with_flashinfer
(
probs
[
long_sample_indices
],
sampling_tensors
.
top_ks
[
long_sample_indices
],
sampling_tensors
.
top_ps
[
long_sample_indices
],
max_best_of_in_batch
,
seq_groups_arg
,
)
else
:
multinomial_samples
[
sampling_type
]
=
_multinomial
(
probs
[
long_sample_indices
],
max_best_of_in_batch
,
**
seeded_args
)
probs
[
long_sample_indices
],
max_best_of_in_batch
,
seq_groups
=
seq_groups_arg
)
if
sampled_token_ids_tensor
is
not
None
:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor
[
long_sample_indices
]
=
multinomial_samples
[
sampling_type
]
sampled_token_ids_tensor
[
long_sample_indices
]
=
\
multinomial_samples
[
sampling_type
]
.
to
(
torch
.
long
)
elif
sampling_type
==
SamplingType
.
BEAM
:
beam_search_logprobs
=
logprobs
[
sample_indices
]
...
...
@@ -692,9 +749,12 @@ def _sample_with_triton_kernel(
def
_sample
(
probs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_tensors
:
SamplingTensors
,
include_gpu_probs_tensor
:
bool
,
modify_greedy_probs
:
bool
probs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_tensors
:
SamplingTensors
,
include_gpu_probs_tensor
:
bool
,
modify_greedy_probs
:
bool
,
)
->
Tuple
[
SampleResultType
,
Optional
[
torch
.
Tensor
]]:
"""
Args:
...
...
@@ -712,6 +772,7 @@ def _sample(
probs
,
logprobs
,
sampling_metadata
,
sampling_tensors
,
include_gpu_probs_tensor
=
include_gpu_probs_tensor
,
modify_greedy_probs
=
modify_greedy_probs
,
)
...
...
vllm/model_executor/layers/spec_decode_base_sampler.py
View file @
af7f4372
from
abc
import
abstractmethod
from
typing
import
Dict
,
Optional
from
typing
import
Dict
,
Optional
,
Union
import
torch
import
torch.jit
...
...
@@ -36,9 +36,12 @@ class SpecDecodeBaseSampler(nn.Module):
self
.
num_emitted_tokens
:
Optional
[
torch
.
Tensor
]
=
None
self
.
num_draft_tokens
:
int
=
0
def
init_gpu_tensors
(
self
,
rank
:
int
)
->
None
:
def
init_gpu_tensors
(
self
,
device
:
Union
[
int
,
str
]
)
->
None
:
assert
self
.
num_accepted_tokens
is
None
device
=
f
"cuda:
{
rank
}
"
if
isinstance
(
device
,
int
):
device
=
f
"cuda:
{
device
}
"
elif
not
isinstance
(
device
,
str
):
raise
ValueError
(
f
"Device must be int or str, get
{
type
(
device
)
}
"
)
self
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
device
)
...
...
vllm/model_executor/layers/vocab_parallel_embedding.py
View file @
af7f4372
...
...
@@ -3,19 +3,46 @@ from typing import List, Optional, Sequence, Tuple
import
torch
import
torch.nn.functional
as
F
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
,
UninitializedParameter
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.linear
import
UnquantizedLinearMethod
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
,
method_has_implemented_embedding
)
from
vllm.model_executor.utils
import
set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE
=
64
class UnquantizedEmbeddingMethod(QuantizeMethodBase):
    """Embedding method that stores the weight without quantization."""

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate the embedding weight and register it on ``layer``.

        The weight has shape
        ``[sum(output_partition_sizes), input_size_per_partition]``.
        """
        num_rows = sum(output_partition_sizes)
        storage = torch.empty(num_rows,
                              input_size_per_partition,
                              dtype=params_dtype)
        weight = Parameter(storage, requires_grad=False)
        dim_attrs = {"input_dim": 1, "output_dim": 0}
        set_weight_attrs(weight, dim_attrs)
        layer.register_parameter("weight", weight)
        # Caller-supplied attributes are applied in a separate call, after
        # the dimension markers.
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weight as a linear projection of ``x``."""
        projected = F.linear(x, layer.weight, bias)
        return projected

    def embedding(self, layer: torch.nn.Module,
                  input_: torch.Tensor) -> torch.Tensor:
        """Look up the rows of the weight indexed by ``input_``."""
        looked_up = F.embedding(input_, layer.weight)
        return looked_up
def
pad_vocab_size
(
vocab_size
:
int
,
pad_to
:
int
=
DEFAULT_VOCAB_PADDING_SIZE
)
->
int
:
"""Pad the vocab size to the given value."""
...
...
@@ -199,7 +226,19 @@ class VocabParallelEmbedding(torch.nn.Module):
if
quant_config
is
not
None
:
linear_method
=
quant_config
.
get_quant_method
(
self
,
prefix
=
prefix
)
if
linear_method
is
None
:
linear_method
=
UnquantizedLinearMethod
()
linear_method
=
UnquantizedEmbeddingMethod
()
# If we are making an embedding layer, then our quantization linear
# method must implement the embedding operation. If we are another
# layer type like ParallelLMHead, this is not important.
is_embedding_layer
=
type
(
self
.
__class__
)
is
VocabParallelEmbedding
linear_method_implements_embedding
=
method_has_implemented_embedding
(
type
(
linear_method
))
if
is_embedding_layer
and
not
linear_method_implements_embedding
:
raise
NotImplementedError
(
f
"The class
{
type
(
linear_method
).
__name__
}
must implement "
"the 'embedding' method, see UnquantizedEmbeddingMethod."
)
self
.
linear_method
:
QuantizeMethodBase
=
linear_method
if
params_dtype
is
None
:
...
...
@@ -306,6 +345,14 @@ class VocabParallelEmbedding(torch.nn.Module):
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
# If the parameter is a gguf weight, then load it directly.
if
getattr
(
param
,
"is_gguf_weight_type"
,
None
):
param
.
data
.
copy_
(
loaded_weight
)
param
.
weight_type
=
loaded_weight
.
item
()
return
elif
isinstance
(
param
,
UninitializedParameter
):
param
.
materialize
(
loaded_weight
.
shape
,
dtype
=
loaded_weight
.
dtype
)
# If parameter does not have output dim, then it should
# be copied onto all gpus (e.g. g_idx for act_order gptq).
if
output_dim
is
None
:
...
...
@@ -344,7 +391,8 @@ class VocabParallelEmbedding(torch.nn.Module):
else
:
masked_input
=
input_
# Get the embeddings.
output_parallel
=
F
.
embedding
(
masked_input
.
long
(),
self
.
weight
)
output_parallel
=
self
.
linear_method
.
embedding
(
self
,
masked_input
.
long
())
# Mask the output embedding.
if
self
.
tp_size
>
1
:
output_parallel
.
masked_fill_
(
input_mask
.
unsqueeze
(
-
1
),
0
)
...
...
@@ -389,6 +437,7 @@ class ParallelLMHead(VocabParallelEmbedding):
super
().
__init__
(
num_embeddings
,
embedding_dim
,
params_dtype
,
org_num_embeddings
,
padding_size
,
quant_config
,
prefix
)
if
bias
:
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
num_embeddings_per_partition
,
...
...
vllm/model_executor/model_loader/__init__.py
View file @
af7f4372
...
...
@@ -3,8 +3,7 @@ from typing import Optional
from
torch
import
nn
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
ParallelConfig
,
SchedulerConfig
)
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
from
vllm.model_executor.model_loader.loader
import
(
BaseModelLoader
,
get_model_loader
)
from
vllm.model_executor.model_loader.utils
import
(
...
...
@@ -15,13 +14,11 @@ def get_model(*, model_config: ModelConfig, load_config: LoadConfig,
device_config
:
DeviceConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
lora_config
:
Optional
[
LoRAConfig
],
multimodal_config
:
Optional
[
MultiModalConfig
],
cache_config
:
CacheConfig
)
->
nn
.
Module
:
loader
=
get_model_loader
(
load_config
)
return
loader
.
load_model
(
model_config
=
model_config
,
device_config
=
device_config
,
lora_config
=
lora_config
,
multimodal_config
=
multimodal_config
,
parallel_config
=
parallel_config
,
scheduler_config
=
scheduler_config
,
cache_config
=
cache_config
)
...
...
vllm/model_executor/model_loader/loader.py
View file @
af7f4372
...
...
@@ -10,11 +10,13 @@ from abc import ABC, abstractmethod
from
contextlib
import
contextmanager
from
typing
import
Any
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
,
Type
import
gguf
import
huggingface_hub
import
numpy
as
np
import
torch
from
huggingface_hub
import
HfApi
,
hf_hub_download
from
torch
import
nn
from
transformers
import
AutoModelForCausalLM
,
PretrainedConfig
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
...
...
@@ -31,14 +33,15 @@ from vllm.model_executor.model_loader.utils import (get_model_architecture,
from
vllm.model_executor.model_loader.weight_utils
import
(
download_safetensors_index_file_from_hf
,
download_weights_from_hf
,
filter_duplicate_safetensors_files
,
filter_files_not_needed_for_inference
,
get_quant_config
,
initialize_dummy_weights
,
np_cache_weights_iterator
,
pt_weights_iterator
,
safetensors_weights_iterator
)
get_gguf_extra_tensor_names
,
get_quant_config
,
gguf_quant_weights_iterator
,
initialize_dummy_weights
,
np_cache_weights_iterator
,
pt_weights_iterator
,
safetensors_weights_iterator
)
from
vllm.model_executor.models.interfaces
import
(
has_inner_state
,
supports_lora
,
supports_
vision
)
supports_
multimodal
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_pin_memory_available
,
is_tpu
from
vllm.utils
import
is_pin_memory_available
@
contextmanager
...
...
@@ -91,12 +94,13 @@ def _get_quantization_config(
"""Get the quantization config."""
if
model_config
.
quantization
is
not
None
:
quant_config
=
get_quant_config
(
model_config
,
load_config
)
if
not
current_platform
.
is_tpu
():
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
if
capability
<
quant_config
.
get_min_capability
():
raise
ValueError
(
f
"The quantization method
{
model_config
.
quantization
}
is not
"
"
supported for the current GPU. "
f
"The quantization method
{
model_config
.
quantization
}
"
"is not
supported for the current GPU. "
f
"Minimum capability:
{
quant_config
.
get_min_capability
()
}
. "
f
"Current capability:
{
capability
}
."
)
supported_dtypes
=
quant_config
.
get_supported_act_dtypes
()
...
...
@@ -130,10 +134,8 @@ def _get_model_initialization_kwargs(
"be added in the future. If this is important to you, "
"please open an issue on github."
)
if
supports_vision
(
model_class
):
if
multimodal_config
is
None
:
raise
ValueError
(
"Provide vision related configurations "
"through LLM entrypoint or engine arguments."
)
if
supports_multimodal
(
model_class
):
assert
multimodal_config
is
not
None
extra_kwargs
[
"multimodal_config"
]
=
multimodal_config
...
...
@@ -143,23 +145,40 @@ def _get_model_initialization_kwargs(
return
extra_kwargs
def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig,
                cache_config: Optional[CacheConfig],
                quant_config: Optional[QuantizationConfig], *,
                lora_config: Optional[LoRAConfig],
                multimodal_config: Optional[MultiModalConfig],
                scheduler_config: Optional[SchedulerConfig]) -> nn.Module:
    """Instantiate ``model_class`` from the given configuration objects.

    The LoRA, multimodal and scheduler configs are not passed to the
    constructor directly; they are first translated by
    ``_get_model_initialization_kwargs`` into whatever additional keyword
    arguments that helper returns for ``model_class``.

    Args:
        model_class: The model class to construct.
        hf_config: Passed to the constructor as ``config``.
        cache_config: Passed to the constructor as ``cache_config``.
        quant_config: Passed to the constructor as ``quant_config``.
        lora_config: Forwarded to ``_get_model_initialization_kwargs``.
        multimodal_config: Forwarded to ``_get_model_initialization_kwargs``.
        scheduler_config: Forwarded to ``_get_model_initialization_kwargs``.

    Returns:
        The newly constructed model instance.
    """
    optional_init_kwargs = _get_model_initialization_kwargs(
        model_class,
        lora_config,
        multimodal_config,
        scheduler_config,
    )
    return model_class(
        config=hf_config,
        cache_config=cache_config,
        quant_config=quant_config,
        **optional_init_kwargs,
    )
def
_initialize_model
(
model_config
:
ModelConfig
,
load_config
:
LoadConfig
,
lora_config
:
Optional
[
LoRAConfig
],
multimodal_config
:
Optional
[
MultiModalConfig
],
cache_config
:
CacheConfig
,
scheduler_config
:
Optional
[
SchedulerConfig
]
=
None
)
->
nn
.
Module
:
"""Initialize a model with the given configurations."""
model_class
=
get_model_architecture
(
model_config
)[
0
]
quant_config
=
_get_quantization_config
(
model_config
,
load_config
)
model_class
,
_
=
get_model_architecture
(
model_config
)
return
model_class
(
config
=
model_config
.
hf_config
,
return
build_model
(
model_class
,
model_config
.
hf_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
**
_get_model_initialization_kwargs
(
model_class
,
lora_config
,
multimodal_config
,
scheduler_config
))
quant_config
=
_get_quantization_config
(
model_config
,
load_config
),
lora_config
=
lora_config
,
multimodal_config
=
model_config
.
multimodal_config
,
scheduler_config
=
scheduler_config
,
)
class
BaseModelLoader
(
ABC
):
...
...
@@ -172,7 +191,6 @@ class BaseModelLoader(ABC):
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
multimodal_config
:
Optional
[
MultiModalConfig
],
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
...
...
@@ -301,7 +319,7 @@ class DefaultModelLoader(BaseModelLoader):
else
:
weights_iterator
=
pt_weights_iterator
(
hf_weights_files
)
if
is_tpu
():
if
current_platform
.
is_tpu
():
# In PyTorch XLA, we should call `xm.mark_step` frequently so that
# not too many ops are accumulated in the XLA program.
import
torch_xla.core.xla_model
as
xm
...
...
@@ -317,7 +335,6 @@ class DefaultModelLoader(BaseModelLoader):
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
multimodal_config
:
Optional
[
MultiModalConfig
],
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
...
...
@@ -325,8 +342,8 @@ class DefaultModelLoader(BaseModelLoader):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
lora_config
,
multimodal
_config
,
cache_config
,
scheduler_config
)
lora_config
,
cache
_config
,
scheduler_config
)
model
.
load_weights
(
self
.
_get_weights_iterator
(
model_config
.
model
,
model_config
.
revision
,
...
...
@@ -360,15 +377,14 @@ class DummyModelLoader(BaseModelLoader):
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
multimodal_config
:
Optional
[
MultiModalConfig
],
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
lora_config
,
multimodal
_config
,
cache_config
,
scheduler_config
)
lora_config
,
cache
_config
,
scheduler_config
)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights
(
model
)
...
...
@@ -401,7 +417,6 @@ class TensorizerLoader(BaseModelLoader):
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
multimodal_config
:
Optional
[
MultiModalConfig
],
cache_config
:
CacheConfig
,
)
->
nn
.
Module
:
"""Load a serialized model with tensorizer to the CPU.
...
...
@@ -414,8 +429,7 @@ class TensorizerLoader(BaseModelLoader):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
lora_config
,
multimodal_config
,
cache_config
)
lora_config
,
cache_config
)
model
.
load_weights
(
self
.
_get_weights_iterator
())
return
model
.
eval
()
...
...
@@ -425,7 +439,6 @@ class TensorizerLoader(BaseModelLoader):
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
multimodal_config
:
Optional
[
MultiModalConfig
],
cache_config
:
CacheConfig
,
)
->
nn
.
Module
:
"""Load a serialized model with tensorizer.
...
...
@@ -439,7 +452,7 @@ class TensorizerLoader(BaseModelLoader):
quant_config
=
_get_quantization_config
(
model_config
,
self
.
load_config
)
extra_kwargs
=
_get_model_initialization_kwargs
(
model_class
,
lora_config
,
multimodal_config
)
model_class
,
lora_config
,
model_config
.
multimodal_config
)
extra_kwargs
[
"quant_config"
]
=
quant_config
extra_kwargs
[
"cache_config"
]
=
cache_config
...
...
@@ -454,7 +467,6 @@ class TensorizerLoader(BaseModelLoader):
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
multimodal_config
:
Optional
[
MultiModalConfig
],
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
...
...
@@ -468,11 +480,9 @@ class TensorizerLoader(BaseModelLoader):
if
is_vllm_tensorized
(
self
.
tensorizer_config
):
return
self
.
_load_model_serialized
(
model_config
,
device_config
,
lora_config
,
multimodal_config
,
cache_config
)
lora_config
,
cache_config
)
return
self
.
_load_model_serialized_cpu
(
model_config
,
device_config
,
lora_config
,
multimodal_config
,
cache_config
)
lora_config
,
cache_config
)
@
staticmethod
def
save_model
(
...
...
@@ -558,7 +568,6 @@ class ShardedStateLoader(BaseModelLoader):
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
multimodal_config
:
Optional
[
MultiModalConfig
],
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
...
...
@@ -572,8 +581,11 @@ class ShardedStateLoader(BaseModelLoader):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
lora_config
,
multimodal_config
,
cache_config
)
lora_config
,
cache_config
)
for
_
,
module
in
model
.
named_modules
():
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
quant_method
is
not
None
:
quant_method
.
process_weights_after_loading
(
module
)
rank
=
get_tensor_model_parallel_rank
()
pattern
=
os
.
path
.
join
(
local_model_path
,
...
...
@@ -864,11 +876,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if
not
hasattr
(
model
,
'load_weights'
):
raise
AttributeError
(
"The required method 'load_weights' is not defined in class"
f
"
{
type
(
s
el
f
).
__name__
}
."
)
f
"
{
type
(
mod
el
).
__name__
}
."
)
if
not
hasattr
(
model
,
'bitsandbytes_stacked_params_mapping'
):
raise
AttributeError
(
f
"Model
{
type
(
s
el
f
).
__name__
}
does not support BitsAndBytes "
f
"Model
{
type
(
mod
el
).
__name__
}
does not support BitsAndBytes "
"quantization yet."
)
logger
.
info
(
"Loading weights with BitsAndBytes quantization. "
...
...
@@ -936,21 +948,101 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
multimodal_config
:
Optional
[
MultiModalConfig
],
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
lora_config
,
multimodal_config
,
cache_config
)
lora_config
,
cache_config
)
self
.
_load_weights
(
model_config
,
model
)
return
model
.
eval
()
class GGUFModelLoader(BaseModelLoader):
    """
    Model loader for checkpoints stored in the GGUF format.

    The loader expects ``model_config.model`` to be a path to an existing
    GGUF file. It derives a GGUF -> HF tensor-name mapping from the HF config
    and feeds the renamed tensors to ``model.load_weights``.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        extra_config = load_config.model_loader_extra_config
        if extra_config:
            # This loader accepts no extra options; reject them loudly.
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _prepare_weights(self, model_name_or_path: str):
        """Return ``model_name_or_path`` if it is a file, else raise."""
        if not os.path.isfile(model_name_or_path):
            raise ValueError(f"{model_name_or_path} is not a file.")
        return model_name_or_path

    def _get_gguf_weights_map(self, model_config: ModelConfig):
        """
        Build a dict mapping GGUF tensor names to HF parameter names.

        GGUF names the tensors converted from an HF checkpoint as
        `blk.N.BB.weight` and `blk.N.BB.bias`, where N is the block (layer)
        number and BB identifies the attention/mlp component.
        See "Standardized tensor names" in
        https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.

        Raises:
            RuntimeError: If the HF ``model_type`` has no entry in
                ``gguf.MODEL_ARCH_NAMES``.
        """
        hf_config = model_config.hf_config
        model_type = hf_config.model_type
        # hack: ggufs have a different name than transformers
        if model_type == "cohere":
            model_type = "command-r"

        # First architecture key whose registered name equals the model type.
        arch = next(
            (arch_key
             for arch_key, arch_name in gguf.MODEL_ARCH_NAMES.items()
             if arch_name == model_type),
            None,
        )
        if arch is None:
            raise RuntimeError(f"Unknown gguf model_type: {model_type}")

        name_map = gguf.get_tensor_name_map(arch, hf_config.num_hidden_layers)

        # Instantiate on the "meta" device: only the parameter names of the
        # HF model are needed here.
        with torch.device("meta"):
            skeleton = AutoModelForCausalLM.from_config(hf_config)
        hf_param_names = skeleton.state_dict()

        gguf_to_hf_name_map = {}
        for hf_name in hf_param_names:
            # Split off the trailing component (e.g. "weight" / "bias") and
            # translate only the module path.
            module_path, suffix = hf_name.rsplit(".", 1)
            gguf_module = name_map.get_name(module_path)
            gguf_to_hf_name_map[f"{gguf_module}.{suffix}"] = hf_name
        return gguf_to_hf_name_map

    def _get_weights_iterator(
        self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str]
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Return the iterator of ``(hf_name, tensor)`` pairs for the file."""
        return gguf_quant_weights_iterator(model_name_or_path,
                                           gguf_to_hf_name_map)

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Construct the model and load its weights from the GGUF file."""
        gguf_path = self._prepare_weights(model_config.model)
        name_map = self._get_gguf_weights_map(model_config)

        # we can only know if tie word embeddings after mapping weights:
        # "lm_head.weight" being reported by get_gguf_extra_tensor_names
        # switches the HF config to tied embeddings.
        reported_names = get_gguf_extra_tensor_names(gguf_path, name_map)
        if "lm_head.weight" in reported_names:
            model_config.hf_config.update({"tie_word_embeddings": True})

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, cache_config)
            weights = self._get_weights_iterator(gguf_path, name_map)
            model.load_weights(weights)
        return model
def
get_model_loader
(
load_config
:
LoadConfig
)
->
BaseModelLoader
:
"""Get a model loader based on the load format."""
...
...
@@ -969,4 +1061,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if
load_config
.
load_format
==
LoadFormat
.
BITSANDBYTES
:
return
BitsAndBytesModelLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
GGUF
:
return
GGUFModelLoader
(
load_config
)
return
DefaultModelLoader
(
load_config
)
vllm/model_executor/model_loader/utils.py
View file @
af7f4372
...
...
@@ -47,13 +47,7 @@ def get_model_architecture(
and
"MixtralForCausalLM"
in
architectures
):
architectures
=
[
"QuantMixtralForCausalLM"
]
for
arch
in
architectures
:
model_cls
=
ModelRegistry
.
load_model_cls
(
arch
)
if
model_cls
is
not
None
:
return
(
model_cls
,
arch
)
raise
ValueError
(
f
"Model architectures
{
architectures
}
are not supported for now. "
f
"Supported architectures:
{
ModelRegistry
.
get_supported_archs
()
}
"
)
return
ModelRegistry
.
resolve_model_cls
(
architectures
)
def
get_architecture_class_name
(
model_config
:
ModelConfig
)
->
str
:
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
af7f4372
...
...
@@ -6,9 +6,10 @@ import json
import
os
import
tempfile
from
collections
import
defaultdict
from
typing
import
Any
,
Generator
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Generator
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
filelock
import
gguf
import
huggingface_hub.constants
import
numpy
as
np
import
torch
...
...
@@ -18,6 +19,7 @@ from tqdm.auto import tqdm
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
vllm.config
import
LoadConfig
,
ModelConfig
from
vllm.distributed
import
get_tensor_model_parallel_rank
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
(
QuantizationConfig
,
get_quantization_config
)
...
...
@@ -121,9 +123,18 @@ def get_quant_config(model_config: ModelConfig,
load_config
:
LoadConfig
)
->
QuantizationConfig
:
quant_cls
=
get_quantization_config
(
model_config
.
quantization
)
# GGUF doesn't have config file
if
model_config
.
quantization
==
"gguf"
:
return
quant_cls
.
from_config
({})
# Read the quantization config from the HF model config, if available.
hf_quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
None
)
# some vision model may keep quantization_config in their text_config
hf_text_config
=
getattr
(
model_config
.
hf_config
,
"text_config"
,
None
)
if
hf_quant_config
is
None
and
hf_text_config
is
not
None
:
hf_quant_config
=
getattr
(
hf_text_config
,
"quantization_config"
,
None
)
if
hf_quant_config
is
None
:
# compressed-tensors uses a compression_config
hf_quant_config
=
getattr
(
model_config
.
hf_config
,
"compression_config"
,
...
...
@@ -409,6 +420,47 @@ def pt_weights_iterator(
torch
.
cuda
.
empty_cache
()
def get_gguf_extra_tensor_names(
        gguf_file: str, gguf_to_hf_name_map: Dict[str, str]) -> List[str]:
    """Return HF names whose mapped GGUF tensor is absent from the file.

    Args:
        gguf_file: Path handed to ``gguf.GGUFReader``.
        gguf_to_hf_name_map: Mapping from GGUF tensor name to HF name.

    Returns:
        The HF names for every key of ``gguf_to_hf_name_map`` that does not
        appear among the tensor names in ``gguf_file``. The order follows set
        iteration and is therefore unspecified.
    """
    reader = gguf.GGUFReader(gguf_file)
    names_in_map = set(gguf_to_hf_name_map.keys())
    names_in_file = {entry.name for entry in reader.tensors}
    absent_from_file = names_in_map - names_in_file
    return [gguf_to_hf_name_map[gguf_key] for gguf_key in absent_from_file]
def gguf_quant_weights_iterator(
    gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """
    Yield ``(name, tensor)`` pairs for the mapped tensors of a GGUF file.

    Tensors whose GGUF name is not a key of ``gguf_to_hf_name_map`` are
    ignored. The file's tensor list is walked twice:

    1. For every mapped tensor whose type is not "F32", yield its HF name
       with "weight" replaced by "qweight_type", paired with a tensor built
       from the GGUF tensor type.
    2. For every mapped tensor, yield its data as a torch tensor; the HF name
       has "weight" replaced by "qweight" when the type is not "F32".
    """
    reader = gguf.GGUFReader(gguf_file)

    # Pass 1: quantization types, emitted before any weight data.
    for entry in reader.tensors:
        if entry.name not in gguf_to_hf_name_map:
            continue
        quant_type = entry.tensor_type
        hf_name = gguf_to_hf_name_map[entry.name]
        if quant_type.name == "F32":
            continue
        type_param_name = hf_name.replace("weight", "qweight_type")
        yield type_param_name, torch.tensor(quant_type)

    # Pass 2: the tensor payloads themselves.
    for entry in reader.tensors:
        if entry.name not in gguf_to_hf_name_map:
            continue
        raw_data = entry.data
        hf_name = gguf_to_hf_name_map[entry.name]
        if entry.tensor_type.name != "F32":
            hf_name = hf_name.replace("weight", "qweight")
        yield hf_name, torch.tensor(raw_data)
def
kv_cache_scales_loader
(
filename
:
str
,
tp_rank
:
int
,
tp_size
:
int
,
num_hidden_layers
:
int
,
model_type
:
Optional
[
str
])
->
Iterable
[
Tuple
[
int
,
float
]]:
...
...
@@ -467,8 +519,36 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
def default_weight_loader(param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    """Default weight loader: copy ``loaded_weight`` into ``param`` in place.

    Args:
        param: Destination tensor; its ``data`` is overwritten.
        loaded_weight: Source tensor read from the checkpoint.

    Raises:
        AssertionError: If the tensors are not both single-element and their
            sizes differ.
    """
    # The size check lives only in the non-scalar branch below. An
    # unconditional check up front would reject single-element tensors of
    # different shape (e.g. ``()`` vs ``(1,)``) before the scalar branch
    # could handle them.
    try:
        if param.numel() == 1 and loaded_weight.numel() == 1:
            # Sometimes scalar values aren't considered tensors with shapes
            # so if both param and loaded_weight are a scalar,
            # "broadcast" instead of copy
            param.data.fill_(loaded_weight.item())
        else:
            assert param.size() == loaded_weight.size(), (
                f"Attempted to load weight ({loaded_weight.size()}) "
                f"into parameter ({param.size()})")

            param.data.copy_(loaded_weight)
    except Exception:
        # NOTE: This exception is added for the purpose of setting breakpoint to
        # debug weight loading issues.
        raise
def row_parallel_weight_loader(param: torch.Tensor,
                               loaded_weight: torch.Tensor) -> None:
    """Load weights that are row-parallelized.

    Unless ``param`` is one-dimensional, ``loaded_weight`` is first narrowed
    along dim 0 to the slice belonging to the current tensor-parallel rank
    (slice length = ``param.data.shape[0]``). The result is then handed to
    ``default_weight_loader``.
    """
    tp_rank = get_tensor_model_parallel_rank()

    if param.dim() != 1:
        # Each rank owns a contiguous run of rows of the full weight.
        rows_per_rank = param.data.shape[0]
        first_row = tp_rank * rows_per_rank
        loaded_weight = loaded_weight.narrow(0, first_row, rows_per_rank)

    return default_weight_loader(param, loaded_weight)
def
initialize_dummy_weights
(
...
...
vllm/model_executor/models/__init__.py
View file @
af7f4372
import
functools
import
importlib
from
typing
import
Dict
,
List
,
Optional
,
Type
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch.nn
as
nn
...
...
@@ -9,17 +9,12 @@ from vllm.utils import is_hip
logger
=
init_logger
(
__name__
)
# Architecture -> (module, class).
_GENERATION_MODELS
=
{
"AquilaModel"
:
(
"llama"
,
"LlamaForCausalLM"
),
"AquilaForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
# AquilaChat2
"BaiChuanForCausalLM"
:
(
"baichuan"
,
"BaiChuanForCausalLM"
),
# baichuan-7b
"BaichuanForCausalLM"
:
(
"baichuan"
,
"BaichuanForCausalLM"
),
# baichuan-13b
"BloomForCausalLM"
:
(
"bloom"
,
"BloomForCausalLM"
),
"Blip2ForConditionalGeneration"
:
(
"blip2"
,
"Blip2ForConditionalGeneration"
),
"ChameleonForConditionalGeneration"
:
(
"chameleon"
,
"ChameleonForConditionalGeneration"
),
"ChatGLMModel"
:
(
"chatglm"
,
"ChatGLMForCausalLM"
),
"ChatGLMForConditionalGeneration"
:
(
"chatglm"
,
"ChatGLMForCausalLM"
),
"CohereForCausalLM"
:
(
"commandr"
,
"CohereForCausalLM"
),
...
...
@@ -28,7 +23,6 @@ _GENERATION_MODELS = {
"DeepseekForCausalLM"
:
(
"deepseek"
,
"DeepseekForCausalLM"
),
"DeepseekV2ForCausalLM"
:
(
"deepseek_v2"
,
"DeepseekV2ForCausalLM"
),
"FalconForCausalLM"
:
(
"falcon"
,
"FalconForCausalLM"
),
"FuyuForCausalLM"
:
(
"fuyu"
,
"FuyuForCausalLM"
),
"GemmaForCausalLM"
:
(
"gemma"
,
"GemmaForCausalLM"
),
"Gemma2ForCausalLM"
:
(
"gemma2"
,
"Gemma2ForCausalLM"
),
"GPT2LMHeadModel"
:
(
"gpt2"
,
"GPT2LMHeadModel"
),
...
...
@@ -37,13 +31,8 @@ _GENERATION_MODELS = {
"GPTNeoXForCausalLM"
:
(
"gpt_neox"
,
"GPTNeoXForCausalLM"
),
"InternLMForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"InternLM2ForCausalLM"
:
(
"internlm2"
,
"InternLM2ForCausalLM"
),
"InternVLChatModel"
:
(
"internvl"
,
"InternVLChatModel"
),
"JAISLMHeadModel"
:
(
"jais"
,
"JAISLMHeadModel"
),
"LlamaForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"LlavaForConditionalGeneration"
:
(
"llava"
,
"LlavaForConditionalGeneration"
),
"LlavaNextForConditionalGeneration"
:
(
"llava_next"
,
"LlavaNextForConditionalGeneration"
),
# For decapoda-research/llama-*
"LLaMAForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"MistralForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
...
...
@@ -53,17 +42,13 @@ _GENERATION_MODELS = {
"MptForCausalLM"
:
(
"mpt"
,
"MPTForCausalLM"
),
"MPTForCausalLM"
:
(
"mpt"
,
"MPTForCausalLM"
),
"MiniCPMForCausalLM"
:
(
"minicpm"
,
"MiniCPMForCausalLM"
),
"MiniCPMV"
:
(
"minicpmv"
,
"MiniCPMV"
),
"NemotronForCausalLM"
:
(
"nemotron"
,
"NemotronForCausalLM"
),
"OlmoForCausalLM"
:
(
"olmo"
,
"OlmoForCausalLM"
),
"OPTForCausalLM"
:
(
"opt"
,
"OPTForCausalLM"
),
"OrionForCausalLM"
:
(
"orion"
,
"OrionForCausalLM"
),
"PersimmonForCausalLM"
:
(
"persimmon"
,
"PersimmonForCausalLM"
),
"PaliGemmaForConditionalGeneration"
:
(
"paligemma"
,
"PaliGemmaForConditionalGeneration"
),
"PhiForCausalLM"
:
(
"phi"
,
"PhiForCausalLM"
),
"Phi3ForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"Phi3VForCausalLM"
:
(
"phi3v"
,
"Phi3VForCausalLM"
),
"QWenLMHeadModel"
:
(
"qwen"
,
"QWenLMHeadModel"
),
"Qwen2ForCausalLM"
:
(
"qwen2"
,
"Qwen2ForCausalLM"
),
"Qwen2MoeForCausalLM"
:
(
"qwen2_moe"
,
"Qwen2MoeForCausalLM"
),
...
...
@@ -75,15 +60,43 @@ _GENERATION_MODELS = {
"XverseForCausalLM"
:
(
"xverse"
,
"XverseForCausalLM"
),
"Phi3SmallForCausalLM"
:
(
"phi3_small"
,
"Phi3SmallForCausalLM"
),
"MedusaModel"
:
(
"medusa"
,
"Medusa"
),
"EAGLEModel"
:
(
"eagle"
,
"EAGLE"
),
"MLPSpeculatorPreTrainedModel"
:
(
"mlp_speculator"
,
"MLPSpeculator"
),
"JambaForCausalLM"
:
(
"jamba"
,
"JambaForCausalLM"
)
"JambaForCausalLM"
:
(
"jamba"
,
"JambaForCausalLM"
)
,
}
_EMBEDDING_MODELS
=
{
"MistralModel"
:
(
"llama_embedding"
,
"LlamaEmbeddingModel"
),
}
_MODELS
=
{
**
_GENERATION_MODELS
,
**
_EMBEDDING_MODELS
}
_MULTIMODAL_MODELS
=
{
"Blip2ForConditionalGeneration"
:
(
"blip2"
,
"Blip2ForConditionalGeneration"
),
"ChameleonForConditionalGeneration"
:
(
"chameleon"
,
"ChameleonForConditionalGeneration"
),
"FuyuForCausalLM"
:
(
"fuyu"
,
"FuyuForCausalLM"
),
"InternVLChatModel"
:
(
"internvl"
,
"InternVLChatModel"
),
"LlavaForConditionalGeneration"
:
(
"llava"
,
"LlavaForConditionalGeneration"
),
"LlavaNextForConditionalGeneration"
:
(
"llava_next"
,
"LlavaNextForConditionalGeneration"
),
"MiniCPMV"
:
(
"minicpmv"
,
"MiniCPMV"
),
"PaliGemmaForConditionalGeneration"
:
(
"paligemma"
,
"PaliGemmaForConditionalGeneration"
),
"Phi3VForCausalLM"
:
(
"phi3v"
,
"Phi3VForCausalLM"
),
"UltravoxModel"
:
(
"ultravox"
,
"UltravoxModel"
),
}
_CONDITIONAL_GENERATION_MODELS
=
{
"BartModel"
:
(
"bart"
,
"BartForConditionalGeneration"
),
"BartForConditionalGeneration"
:
(
"bart"
,
"BartForConditionalGeneration"
),
}
_MODELS
=
{
**
_GENERATION_MODELS
,
**
_EMBEDDING_MODELS
,
**
_MULTIMODAL_MODELS
,
**
_CONDITIONAL_GENERATION_MODELS
,
}
# Architecture -> type.
# out of tree models
...
...
@@ -126,7 +139,7 @@ class ModelRegistry:
return
getattr
(
module
,
model_cls_name
,
None
)
@
staticmethod
def
load_model_cls
(
model_arch
:
str
)
->
Optional
[
Type
[
nn
.
Module
]]:
def
_try_
load_model_cls
(
model_arch
:
str
)
->
Optional
[
Type
[
nn
.
Module
]]:
if
model_arch
in
_OOT_MODELS
:
return
_OOT_MODELS
[
model_arch
]
if
model_arch
not
in
_MODELS
:
...
...
@@ -143,9 +156,21 @@ class ModelRegistry:
return
ModelRegistry
.
_get_model
(
model_arch
)
@staticmethod
def resolve_model_cls(
        architectures: List[str]) -> Tuple[Type[nn.Module], str]:
    """Return the first loadable model class among ``architectures``.

    Args:
        architectures: Candidate architecture names, tried in order.

    Returns:
        A ``(model_cls, arch)`` tuple for the first architecture for which
        ``ModelRegistry._try_load_model_cls`` returns a class.

    Raises:
        ValueError: If no candidate resolves to a class.
    """
    for arch in architectures:
        candidate_cls = ModelRegistry._try_load_model_cls(arch)
        if candidate_cls is None:
            continue
        return (candidate_cls, arch)

    supported_archs = ModelRegistry.get_supported_archs()
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {supported_archs}")
@
staticmethod
def
get_supported_archs
()
->
List
[
str
]:
return
list
(
_MODELS
.
keys
())
return
list
(
_MODELS
.
keys
())
+
list
(
_OOT_MODELS
.
keys
())
@
staticmethod
def
register_model
(
model_arch
:
str
,
model_cls
:
Type
[
nn
.
Module
]):
...
...
@@ -161,6 +186,15 @@ class ModelRegistry:
def
is_embedding_model
(
model_arch
:
str
)
->
bool
:
return
model_arch
in
_EMBEDDING_MODELS
@staticmethod
def is_multimodal_model(model_arch: str) -> bool:
    """Report whether ``model_arch`` is a key of ``_MULTIMODAL_MODELS``."""
    # TODO: find a way to avoid initializing CUDA prematurely to
    # use `supports_multimodal` to determine if a model is multimodal
    # model_cls = ModelRegistry._try_load_model_cls(model_arch)
    # from vllm.model_executor.models.interfaces import supports_multimodal
    listed_as_multimodal = model_arch in _MULTIMODAL_MODELS
    return listed_as_multimodal
__all__
=
[
"ModelRegistry"
,
...
...
vllm/model_executor/models/arctic.py
View file @
af7f4372
...
...
@@ -414,6 +414,8 @@ class ArcticForCausalLM(nn.Module):
config
.
hidden_size
,
quant_config
=
quant_config
,
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
num_experts
=
config
.
num_local_experts
self
.
num_experts_per_tok
=
config
.
num_experts_per_tok
self
.
unpadded_vocab_size
=
config
.
vocab_size
...
...
@@ -433,8 +435,11 @@ class ArcticForCausalLM(nn.Module):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
Prev
1
…
17
18
19
20
21
22
23
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment