Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4eabe123
Commit
4eabe123
authored
May 28, 2025
by
zhuwenwen
Browse files
Merge remote-tracking branch 'mirror/releases/v0.9.0' into v0.9.0-ori
parents
45840cd2
58738772
Changes
670
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
352 additions
and
214 deletions
+352
-214
vllm/model_executor/layers/quantization/auto_round.py
vllm/model_executor/layers/quantization/auto_round.py
+29
-27
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+2
-3
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
..._executor/layers/quantization/compressed_tensors/utils.py
+1
-1
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+13
-51
vllm/model_executor/layers/quantization/gguf.py
vllm/model_executor/layers/quantization/gguf.py
+181
-42
vllm/model_executor/layers/quantization/ipex_quant.py
vllm/model_executor/layers/quantization/ipex_quant.py
+1
-3
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+2
-2
vllm/model_executor/layers/quantization/quark/utils.py
vllm/model_executor/layers/quantization/quark/utils.py
+2
-1
vllm/model_executor/layers/quantization/utils/gptq_utils.py
vllm/model_executor/layers/quantization/utils/gptq_utils.py
+1
-1
vllm/model_executor/layers/rejection_sampler.py
vllm/model_executor/layers/rejection_sampler.py
+18
-17
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+11
-9
vllm/model_executor/layers/typical_acceptance_sampler.py
vllm/model_executor/layers/typical_acceptance_sampler.py
+24
-32
vllm/model_executor/model_loader/__init__.py
vllm/model_executor/model_loader/__init__.py
+10
-3
vllm/model_executor/model_loader/base_loader.py
vllm/model_executor/model_loader/base_loader.py
+2
-1
vllm/model_executor/model_loader/bitsandbytes_loader.py
vllm/model_executor/model_loader/bitsandbytes_loader.py
+2
-3
vllm/model_executor/model_loader/default_loader.py
vllm/model_executor/model_loader/default_loader.py
+6
-5
vllm/model_executor/model_loader/dummy_loader.py
vllm/model_executor/model_loader/dummy_loader.py
+2
-2
vllm/model_executor/model_loader/gguf_loader.py
vllm/model_executor/model_loader/gguf_loader.py
+2
-2
vllm/model_executor/model_loader/neuronx_distributed.py
vllm/model_executor/model_loader/neuronx_distributed.py
+41
-6
vllm/model_executor/model_loader/runai_streamer_loader.py
vllm/model_executor/model_loader/runai_streamer_loader.py
+2
-3
No files found.
vllm/model_executor/layers/quantization/auto_round.py
View file @
4eabe123
...
...
@@ -8,6 +8,7 @@ import torch
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
...
...
@@ -74,7 +75,7 @@ class AutoRoundConfig(QuantizationConfig):
f
"group_size=
{
self
.
group_size
}
, sym=
{
self
.
sym
}
)"
)
@
classmethod
def
get_name
(
cls
)
:
## use str will trigger preci issue
def
get_name
(
cls
)
->
QuantizationMethods
:
return
"auto-round"
@
classmethod
...
...
@@ -142,18 +143,18 @@ class AutoRoundConfig(QuantizationConfig):
prefix
,
layer
.
__class__
.
__name__
,
weight_bits
,
group_size
,
sym
)
if
backend
==
"auto"
or
"marlin"
in
backend
:
AWQ_TYPE_MAP
=
{
4
:
scalar_types
.
uint4
,
8
:
scalar_types
.
uint8
,
}
use_marlin
=
(
weight_bits
in
AWQ_TYPE_MAP
)
and
check_marlin_supported
(
AWQ_TYPE_MAP
[
weight_bits
],
group_size
,
not
sym
)
if
isinstance
(
layer
,
FusedMoE
):
use_marlin
=
check_moe_marlin_supports_layer
(
layer
,
group_size
)
else
:
use_marlin
=
use_marlin
and
check_moe_marlin_supports_layer
(
layer
,
group_size
)
AWQ_TYPE_MAP
=
{
4
:
scalar_types
.
uint4
,
8
:
scalar_types
.
uint8
,
}
use_marlin
=
((
weight_bits
,
sym
)
in
AWQ_TYPE_MAP
and
check_marlin_supported
(
AWQ_TYPE_MAP
[(
weight_bits
)],
group_size
,
not
sym
))
else
:
use_marlin
=
False
if
use_marlin
:
...
...
@@ -180,10 +181,11 @@ class AutoRoundConfig(QuantizationConfig):
from
vllm.model_executor.layers.quantization.moe_wna16
import
(
MoeWNA16Config
)
config
=
{
"
linear_
quant_method"
:
"awq"
,
"
weight_
bits"
:
weight_bits
,
"quant_method"
:
"awq"
,
"bits"
:
weight_bits
,
"group_size"
:
group_size
,
"zero_point"
:
not
sym
,
"lm_head"
:
False
,
}
return
MoeWNA16Config
.
from_config
(
config
).
get_quant_method
(
layer
,
prefix
)
...
...
@@ -213,18 +215,18 @@ class AutoRoundConfig(QuantizationConfig):
prefix
,
layer
.
__class__
.
__name__
,
weight_bits
,
group_size
,
sym
)
if
backend
==
"auto"
or
"marlin"
in
backend
:
GPTQ_TYPE_MAP
=
{
(
4
,
True
):
scalar_types
.
uint4b8
,
(
8
,
True
):
scalar_types
.
uint8b128
,
}
use_marlin
=
((
weight_bits
,
sym
)
in
GPTQ_TYPE_MAP
and
check_marlin_supported
(
GPTQ_TYPE_MAP
[(
weight_bits
,
sym
)],
group_size
,
has_zp
=
not
sym
))
if
isinstance
(
layer
,
FusedMoE
):
use_marlin
=
check_moe_marlin_supports_layer
(
layer
,
group_size
)
else
:
GPTQ_TYPE_MAP
=
{
(
4
,
True
):
scalar_types
.
uint4b8
,
(
8
,
True
):
scalar_types
.
uint8b128
,
}
use_marlin
=
((
weight_bits
,
sym
)
in
GPTQ_TYPE_MAP
and
check_marlin_supported
(
GPTQ_TYPE_MAP
[(
weight_bits
,
sym
)],
group_size
,
has_zp
=
not
sym
))
use_marlin
=
use_marlin
and
check_moe_marlin_supports_layer
(
layer
,
group_size
)
else
:
use_marlin
=
False
if
use_marlin
:
...
...
@@ -251,11 +253,11 @@ class AutoRoundConfig(QuantizationConfig):
from
vllm.model_executor.layers.quantization.moe_wna16
import
(
MoeWNA16Config
)
config
=
{
"
linear_
quant_method"
:
"gptq"
,
"
weight_
bits"
:
weight_bits
,
"quant_method"
:
"gptq"
,
"bits"
:
weight_bits
,
"group_size"
:
group_size
,
"sym"
:
sym
,
"lm_head
_quantized
"
:
False
,
"lm_head"
:
False
,
}
return
MoeWNA16Config
.
from_config
(
config
).
get_quant_method
(
layer
,
prefix
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
4eabe123
...
...
@@ -286,9 +286,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
rocm_aiter_fused_experts
,
shuffle_weights
)
# reshaping weights is required for aiter moe kernel.
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
.
data
,
layer
.
w2_weight
.
data
,
layout
=
(
16
,
16
))
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
.
data
,
layer
.
w2_weight
.
data
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
shuffled_w13
,
requires_grad
=
False
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
View file @
4eabe123
# SPDX-License-Identifier: Apache-2.0
import
re
from
collections.abc
import
Iterable
,
Mapping
from
types
import
MappingProxyType
from
typing
import
Optional
import
regex
as
re
from
compressed_tensors
import
CompressionFormat
from
torch.nn
import
Module
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
4eabe123
...
...
@@ -10,7 +10,6 @@ from torch.nn import Module
from
torch.nn.parameter
import
Parameter
import
vllm.envs
as
envs
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
_custom_ops
as
ops
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
...
...
@@ -63,10 +62,9 @@ class Fp8Config(QuantizationConfig):
weight_block_size
:
Optional
[
list
[
int
]]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
if
is_checkpoint_fp8_serialized
:
logger
.
warning
(
"Detected fp8 checkpoint. Please note that the "
"format is experimental and subject to change."
)
if
activation_scheme
not
in
ACTIVATION_SCHEMES
:
raise
ValueError
(
f
"Unsupported activation scheme
{
activation_scheme
}
"
)
...
...
@@ -461,7 +459,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
logger
.
warning_once
(
"DeepGemm not supported on the current platform."
)
self
.
fused_experts
=
functools
.
partial
(
self
.
fused_experts
=
functools
.
partial
(
# type: ignore
fused_experts
,
block_shape
=
self
.
quant_config
.
weight_block_size
,
allow_deep_gemm
=
self
.
allow_deep_gemm
)
...
...
@@ -597,7 +595,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# Lazy import to avoid importing triton too early.
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
expand_weights
,
is_rocm_aiter_moe_enabled
,
shuffle_weights
)
is_rocm_aiter_moe_enabled
,
shuffle_weights
)
self
.
rocm_aiter_moe_enabled
=
is_rocm_aiter_moe_enabled
()
...
...
@@ -629,9 +627,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if
self
.
rocm_aiter_moe_enabled
:
# reshaping weights is required for aiter moe kernel.
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
.
data
,
layer
.
w2_weight
.
data
,
layout
=
(
16
,
16
))
layer
.
w13_weight
.
data
,
layer
.
w2_weight
.
data
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
shuffled_w13
,
requires_grad
=
False
)
...
...
@@ -677,20 +673,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
requires_grad
=
False
)
if
self
.
rocm_aiter_moe_enabled
:
# reshaping weights is required for aiter moe kernel.
w13_scales
,
w2_scales
=
expand_weights
(
layer
.
w13_weight_scale
.
data
,
layer
.
w2_weight_scale
.
data
,
expansion_dims
=
[
layer
.
w13_weight
.
shape
[
1
],
layer
.
w2_weight
.
shape
[
1
]
])
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
w13_scales
.
contiguous
(),
requires_grad
=
False
)
layer
.
w2_weight_scale
=
torch
.
nn
.
Parameter
(
w2_scales
.
contiguous
(),
requires_grad
=
False
)
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
,
layer
.
w2_weight
,
layout
=
(
16
,
16
))
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
,
layer
.
w2_weight
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
shuffled_w13
,
requires_grad
=
False
)
...
...
@@ -762,20 +746,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
start
+=
shard_size
if
self
.
rocm_aiter_moe_enabled
:
# reshaping weights is required for aiter moe kernel.
expansion_dims
=
[
layer
.
w13_weight
.
shape
[
1
],
layer
.
w2_weight
.
shape
[
1
]
]
max_w13_scales
,
w2_scales
=
expand_weights
(
max_w13_scales
,
layer
.
w2_weight_scale
.
data
,
expansion_dims
=
expansion_dims
)
layer
.
w2_weight_scale
=
torch
.
nn
.
Parameter
(
w2_scales
.
contiguous
(),
requires_grad
=
False
)
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
,
layer
.
w2_weight
,
layout
=
(
32
,
32
))
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
,
layer
.
w2_weight
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
shuffled_w13
,
requires_grad
=
False
)
...
...
@@ -791,17 +763,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del
layer
.
w13_input_scale
del
layer
.
w2_input_scale
def
set_prepare_finalize
(
self
,
dp_size
:
int
,
world_size
:
int
,
prepare_finalize
:
mk
.
FusedMoEPrepareAndFinalize
,
)
->
bool
:
def
select_gemm_impl
(
self
,
prepare_finalize
):
from
vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe
import
(
TritonOrDeepGemmExperts
)
if
self
.
use_marlin
or
self
.
rocm_aiter_moe_enabled
:
return
False
assert
not
self
.
use_marlin
and
not
self
.
rocm_aiter_moe_enabled
,
(
"Marlin and ROCm AITER are not supported with all2all yet."
)
experts
=
TritonOrDeepGemmExperts
(
use_fp8_w8a8
=
True
,
...
...
@@ -809,12 +776,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
allow_deep_gemm
=
self
.
allow_deep_gemm
,
)
self
.
fused_experts
=
mk
.
FusedMoEModularKernel
(
prepare_finalize
,
experts
,
)
return
True
return
experts
def
apply
(
self
,
...
...
vllm/model_executor/layers/quantization/gguf.py
View file @
4eabe123
...
...
@@ -9,7 +9,6 @@ from torch.nn.parameter import Parameter, UninitializedParameter
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoEMethodBase
)
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
...
...
@@ -19,6 +18,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
direct_register_custom_op
logger
=
init_logger
(
__name__
)
...
...
@@ -96,8 +96,8 @@ MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMQ_QUANT_TYPES
=
STANDARD_QUANT_TYPES
|
KQUANT_TYPES
def
_fuse_mul_mat
(
x
:
torch
.
Tensor
,
qweight
:
torch
.
Tensor
,
qweight_type
:
int
)
->
torch
.
Tensor
:
def
_fuse
d
_mul_mat
_gguf
(
x
:
torch
.
Tensor
,
qweight
:
torch
.
Tensor
,
qweight_type
:
int
)
->
torch
.
Tensor
:
# HACK: when doing chunked prefill we don't generate output tokens
# so input to logits generator is empty which causes invalid parameter
if
x
.
shape
[
0
]
==
0
:
...
...
@@ -130,6 +130,30 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
return
y
def
_fused_mul_mat_gguf_fake
(
x
:
torch
.
Tensor
,
qweight
:
torch
.
Tensor
,
qweight_type
:
int
,
)
->
torch
.
Tensor
:
return
torch
.
empty
(
x
.
shape
[
0
],
qweight
.
shape
[
0
],
dtype
=
x
.
dtype
,
device
=
x
.
device
)
try
:
direct_register_custom_op
(
op_name
=
"_fused_mul_mat_gguf"
,
op_func
=
_fused_mul_mat_gguf
,
mutates_args
=
[],
fake_impl
=
_fused_mul_mat_gguf_fake
,
)
fused_mul_mat_gguf
=
torch
.
ops
.
vllm
.
_fused_mul_mat_gguf
except
AttributeError
as
error
:
raise
error
def
_fused_moe_gguf
(
x
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
...
...
@@ -138,8 +162,21 @@ def _fused_moe_gguf(
topk_ids
:
torch
.
Tensor
,
qweight_type
:
int
,
qweight_type2
:
int
,
act
,
act
ivation
:
str
,
)
->
torch
.
Tensor
:
def
act
(
x
:
torch
.
Tensor
):
d
=
x
.
shape
[
-
1
]
//
2
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
if
activation
==
"silu"
:
torch
.
ops
.
_C
.
silu_and_mul
(
out
,
x
)
elif
activation
==
"gelu"
:
torch
.
ops
.
_C
.
gelu_and_mul
(
out
,
x
)
else
:
raise
ValueError
(
f
"Unsupported activation:
{
activation
}
"
)
return
out
# lazy import to avoid triggering triton import in CPU backend
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
moe_align_block_size
)
...
...
@@ -189,12 +226,12 @@ def _fused_moe_gguf(
for
ww
,
ii
in
zip
(
w
,
idx
):
expert_up
=
w1
[
ii
]
out
=
_
fuse_mul_mat
(
inp
,
expert_up
,
qweight_type
)
out
=
fuse
d
_mul_mat
_gguf
(
inp
,
expert_up
,
qweight_type
)
out
=
act
(
out
)
expert_down
=
w2
[
ii
]
current_state
=
_
fuse_mul_mat
(
out
,
expert_down
,
qweight_type2
).
mul_
(
ww
)
current_state
=
fuse
d
_mul_mat
_gguf
(
out
,
expert_down
,
qweight_type2
).
mul_
(
ww
)
if
current_hidden_state
is
None
:
current_hidden_state
=
current_state
else
:
...
...
@@ -203,6 +240,78 @@ def _fused_moe_gguf(
return
out_hidden_states
def
_fused_moe_gguf_fake
(
x
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
qweight_type
:
int
,
qweight_type2
:
int
,
activation
:
str
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
x
)
try
:
direct_register_custom_op
(
op_name
=
"_fused_moe_gguf"
,
op_func
=
_fused_moe_gguf
,
mutates_args
=
[],
fake_impl
=
_fused_moe_gguf_fake
,
)
fused_moe_gguf
=
torch
.
ops
.
vllm
.
_fused_moe_gguf
except
AttributeError
as
error
:
raise
error
def
_apply_gguf_embedding
(
x
:
torch
.
Tensor
,
qweight
:
torch
.
Tensor
,
qweight_type
:
int
,
hidden_size
:
int
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
)
->
torch
.
Tensor
:
if
qweight_type
in
UNQUANTIZED_TYPES
:
return
torch
.
embedding
(
qweight
,
x
)
elif
qweight_type
in
DEQUANT_TYPES
:
block_size
,
type_size
=
gguf
.
GGML_QUANT_SIZES
[
qweight_type
]
x_flat
=
x
.
flatten
()
assert
(
hidden_size
==
qweight
.
shape
[
1
]
//
type_size
*
block_size
)
quant
=
torch
.
index_select
(
qweight
,
dim
=
0
,
index
=
x_flat
)
dequant
=
ops
.
ggml_dequantize
(
quant
,
qweight_type
,
hidden_size
,
x_flat
.
shape
[
0
],
dtype
)
return
dequant
.
view
(
*
x
.
shape
,
hidden_size
)
else
:
qweight_type
=
WeightType
(
qweight_type
)
raise
NotImplementedError
(
f
"Unsupported GGUF quantization type:
{
qweight_type
}
"
)
def
_apply_gguf_embedding_fake
(
x
:
torch
.
Tensor
,
qweight
:
torch
.
Tensor
,
qweight_type
:
int
,
hidden_size
:
int
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
)
->
torch
.
Tensor
:
return
torch
.
empty
(
x
.
shape
[
0
],
hidden_size
,
dtype
=
dtype
,
device
=
x
.
device
)
try
:
direct_register_custom_op
(
op_name
=
"_apply_gguf_embedding"
,
op_func
=
_apply_gguf_embedding
,
mutates_args
=
[],
fake_impl
=
_apply_gguf_embedding_fake
,
)
apply_gguf_embedding
=
torch
.
ops
.
vllm
.
_apply_gguf_embedding
except
AttributeError
as
error
:
raise
error
class
GGUFLinearMethod
(
LinearMethodBase
):
"""Linear method for GGUF.
...
...
@@ -249,26 +358,76 @@ class GGUFLinearMethod(LinearMethodBase):
set_weight_attrs
(
qweight_type
,
extra_weight_attrs
)
layer
.
register_parameter
(
"qweight_type"
,
qweight_type
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
):
qweight_type
=
layer
.
qweight_type
.
weight_type
if
not
(
qweight_type
in
UNQUANTIZED_TYPES
or
qweight_type
in
DEQUANT_TYPES
):
qweight_type
=
WeightType
(
qweight_type
)
raise
ValueError
(
f
"Unsupported GGUF quantization type
{
qweight_type
}
in "
f
"layer
{
layer
}
."
)
# For MergedColumnParallelLinear and QKVParallelLinear, we need to
# materialize the padded weight parameter for CUDA Graph compatibility.
self
.
_create_padded_weight_param
(
layer
)
def
_create_padded_weight_param
(
self
,
layer
:
torch
.
nn
.
Module
):
"""Create padded weight parameter for GGUF MergedLinear layer."""
qweight
=
layer
.
qweight
shard_id_map
=
qweight
.
shard_id_map
shard_id
=
qweight
.
shard_id
if
len
(
data_container
:
=
qweight
.
data_container
)
>
1
:
dtype
=
{
data
.
dtype
for
data
in
data_container
}
assert
len
(
dtype
)
==
1
,
ValueError
(
f
"Data container has mixed dtypes:
{
dtype
}
"
)
dtype
=
next
(
iter
(
dtype
))
# concat dim0 and pad dim1
padded_side
=
max
(
x
.
size
(
1
)
for
x
in
data_container
)
concat_side
=
sum
(
x
.
size
(
0
)
for
x
in
data_container
)
# Pad the quantized weights to dense tensor, and create a map
# with the location of each shard in the padded tensor.
padded_data
=
torch
.
zeros
((
concat_side
,
padded_side
),
dtype
=
dtype
,
device
=
qweight
.
device
)
# (dim0_start, dim0_end, dim1_size)
shard_offset_map
=
dict
[
str
,
tuple
[
int
,
int
,
int
]]()
for
idx
in
shard_id
:
id_in_container
=
shard_id_map
[
idx
]
start
=
sum
(
x
.
size
(
0
)
for
x
in
data_container
[:
id_in_container
])
end
=
start
+
data_container
[
id_in_container
].
size
(
0
)
size
=
data_container
[
id_in_container
].
size
(
1
)
padded_data
[
start
:
end
,
:
size
]
=
data_container
[
id_in_container
]
shard_offset_map
[
idx
]
=
(
start
,
end
,
size
)
qweight
.
data_container
.
clear
()
padded_param
=
Parameter
(
padded_data
,
requires_grad
=
False
)
set_weight_attrs
(
padded_param
,
vars
(
qweight
))
set_weight_attrs
(
padded_param
,
{
"shard_offset_map"
:
shard_offset_map
})
layer
.
register_parameter
(
"qweight"
,
padded_param
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
shard_id
=
getattr
(
layer
.
qweight
,
"
shard_id
"
,
None
)
shard_id
=
layer
.
qweight
.
shard_id
if
shard_id
:
# dequantize shard weights respectively
shard_id
=
[
"q"
,
"k"
,
"v"
]
if
"q"
in
shard_id
else
shard_id
qweight
=
layer
.
qweight
.
unbind
(
0
)
qweight
=
layer
.
qweight
result
=
[]
for
idx
in
shard_id
:
q_idx
=
layer
.
qweight
.
shard_
id
_map
[
idx
]
start
,
end
,
offset
=
layer
.
qweight
.
shard_
offset
_map
[
idx
]
qweight_type
=
layer
.
qweight_type
.
shard_weight_type
[
idx
]
result
.
append
(
_fuse_mul_mat
(
x
,
qweight
[
q_idx
],
qweight_type
))
result
.
append
(
fused_mul_mat_gguf
(
x
,
qweight
[
start
:
end
,
:
offset
].
contiguous
(),
qweight_type
))
out
=
torch
.
cat
(
result
,
axis
=
1
)
else
:
qweight
=
layer
.
qweight
qweight_type
=
layer
.
qweight_type
.
weight_type
out
=
_
fuse_mul_mat
(
x
,
qweight
,
qweight_type
)
out
=
fuse
d
_mul_mat
_gguf
(
x
,
qweight
,
qweight_type
)
if
bias
is
not
None
:
out
.
add_
(
bias
)
return
out
...
...
@@ -338,7 +497,6 @@ class GGUFMoEMethod(FusedMoEMethodBase):
set_weight_attrs
(
w2_qweight_type
,
extra_weight_attrs
)
layer
.
register_parameter
(
"w2_qweight_type"
,
w2_qweight_type
)
self
.
act
=
SiluAndMul
()
def
apply
(
self
,
...
...
@@ -375,10 +533,10 @@ class GGUFMoEMethod(FusedMoEMethodBase):
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
_
fused_moe_gguf
(
x
,
layer
.
w13_qweight
,
layer
.
w2_qweight
,
topk_weights
,
topk_ids
,
layer
.
w13_qweight_type
.
weight_type
,
layer
.
w2_qweight_type
.
weight_type
,
self
.
act
)
return
fused_moe_gguf
(
x
,
layer
.
w13_qweight
,
layer
.
w2_qweight
,
topk_weights
,
topk_ids
,
layer
.
w13_qweight_type
.
weight_type
,
layer
.
w2_qweight_type
.
weight_type
,
activation
)
class
GGUFEmbeddingMethod
(
GGUFLinearMethod
):
...
...
@@ -392,34 +550,15 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
qweight
=
layer
.
qweight
qweight_type
=
layer
.
qweight_type
.
weight_type
hidden_size
=
qweight
.
tensor_shape
[
1
]
block_size
,
type_size
=
gguf
.
GGML_QUANT_SIZES
[
qweight_type
]
hidden_size
=
qweight
.
shape
[
1
]
//
type_size
*
block_size
if
qweight_type
<
2
:
return
torch
.
embedding
(
qweight
,
x
)
x_flat
=
x
.
flatten
()
quant
=
torch
.
index_select
(
qweight
,
dim
=
0
,
index
=
x_flat
)
dequant
=
ops
.
ggml_dequantize
(
quant
,
qweight_type
,
hidden_size
,
x_flat
.
shape
[
0
],
self
.
params_dtype
)
return
dequant
.
view
(
*
x
.
shape
,
hidden_size
)
return
apply_gguf_embedding
(
x
,
qweight
,
qweight_type
,
hidden_size
,
dtype
=
self
.
params_dtype
)
class
GGUFUninitializedParameter
(
UninitializedParameter
):
cls_to_become
=
Parameter
data_container
:
list
[
torch
.
Tensor
]
def
materialize_nested
(
self
)
->
Parameter
:
dtype
=
{
data
.
dtype
for
data
in
self
.
data_container
}
assert
len
(
dtype
)
==
1
,
ValueError
(
f
"Data container has mixed dtypes:
{
dtype
}
"
)
dtype
=
next
(
iter
(
dtype
))
nested_data
=
torch
.
nested
.
nested_tensor
(
self
.
data_container
,
device
=
self
.
device
,
dtype
=
dtype
)
self
.
data_container
.
clear
()
param
=
torch
.
Tensor
.
_make_subclass
(
self
.
cls_to_become
,
nested_data
,
require_grad
=
False
)
for
k
,
v
in
self
.
__dict__
.
items
():
setattr
(
param
,
k
,
v
)
return
param
vllm/model_executor/layers/quantization/ipex_quant.py
View file @
4eabe123
...
...
@@ -14,7 +14,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from
vllm.model_executor.layers.quantization.gptq
import
GPTQLinearMethod
from
vllm.platforms
import
current_platform
MIN_IPEX_VERSION
=
"2.
5
.0"
MIN_IPEX_VERSION
=
"2.
7
.0"
class
IPEXConfig
(
QuantizationConfig
):
...
...
@@ -181,8 +181,6 @@ class IPEXGPTQLinearMethod(GPTQLinearMethod):
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
out
=
layer
.
ipex_qlinear
(
reshaped_x
)
if
bias
is
not
None
:
out
.
add_
(
bias
)
return
out
.
reshape
(
x
.
shape
[:
-
1
]
+
(
layer
.
ipex_output_size
,
))
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
4eabe123
...
...
@@ -192,7 +192,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
@
classmethod
def
get_name
(
cls
)
->
QuantizationMethods
:
return
"
nv
fp4"
return
"
modelopt_
fp4"
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
list
[
torch
.
dtype
]:
...
...
@@ -228,7 +228,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
exclude_modules
,
group_size
)
def
is_layer_excluded
(
self
,
prefix
:
str
,
exclude_modules
:
list
):
import
re
import
re
gex
as
re
for
pattern
in
exclude_modules
:
regex_str
=
pattern
.
replace
(
'.'
,
r
'\.'
).
replace
(
'*'
,
r
'.*'
)
if
re
.
fullmatch
(
regex_str
,
prefix
):
...
...
vllm/model_executor/layers/quantization/quark/utils.py
View file @
4eabe123
# SPDX-License-Identifier: Apache-2.0
import
re
from
collections.abc
import
Iterable
,
Mapping
from
types
import
MappingProxyType
from
typing
import
Any
,
Optional
import
regex
as
re
def
deep_compare
(
dict1
:
Any
,
dict2
:
Any
)
->
bool
:
if
type
(
dict1
)
is
not
type
(
dict2
):
...
...
vllm/model_executor/layers/quantization/utils/gptq_utils.py
View file @
4eabe123
# SPDX-License-Identifier: Apache-2.0
import
re
from
copy
import
deepcopy
from
typing
import
Optional
,
Union
import
regex
as
re
import
torch
from
vllm.config
import
QuantizationConfig
...
...
vllm/model_executor/layers/rejection_sampler.py
View file @
4eabe123
...
...
@@ -262,16 +262,16 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
True, then a token can be accepted, else it should be
rejected.
Given
{math}`
q(\hat{x}_{n+1}|x_1, \dots, x_n)
`
, the probability of
{math}`
\hat{x}_{n+1}
`
given context
{math}`
x_1, \dots, x_n
`
according
to the target model, and
{math}`
p(\hat{x}_{n+1}|x_1, \dots, x_n)
`
, the
Given
$
q(\hat{x}_{n+1}|x_1, \dots, x_n)
$
, the probability of
$
\hat{x}_{n+1}
$
given context
$
x_1, \dots, x_n
$
according
to the target model, and
$
p(\hat{x}_{n+1}|x_1, \dots, x_n)
$
, the
same conditional probability according to the draft model, the token
is accepted with probability:
:::{math}
$$
\min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
{p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)
:::
$$
This implementation does not apply causality. When using the output,
if a token is rejected, subsequent tokens should not be used.
...
...
@@ -314,30 +314,31 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target model is recovered (within hardware numerics).
The probability distribution used in this rejection case is constructed
as follows. Given
{math}`
q(x|x_1, \dots, x_n)
`
, the probability of
{math}`x`
given context
{math}`
x_1, \dots, x_n
`
according to the target
model and
{math}`
p(x|x_1, \dots, x_n)
`
, the same conditional probability
as follows. Given
$
q(x|x_1, \dots, x_n)
$
, the probability of
$x$
given context
$
x_1, \dots, x_n
$
according to the target
model and
$
p(x|x_1, \dots, x_n)
$
, the same conditional probability
according to the draft model:
:::{math}
$$
x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+
:::
$$
where
{math}`
(f(x))_+
`
is defined as:
where
$
(f(x))_+
$
is defined as:
:::{math}
$$
(f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}
:::
$$
See https://github.com/vllm-project/vllm/pull/2336 for a visualization
of the draft, target, and recovered probability distributions.
Returns a tensor of shape [batch_size, k, vocab_size].
Note: This batches operations on GPU and thus constructs the recovered
distribution for all tokens, even if they are accepted. This causes
division-by-zero errors, so we use self._smallest_positive_value to
avoid that. This introduces some drift to the distribution.
Note:
This batches operations on GPU and thus constructs the recovered
distribution for all tokens, even if they are accepted. This causes
division-by-zero errors, so we use self._smallest_positive_value to
avoid that. This introduces some drift to the distribution.
"""
_
,
k
,
_
=
draft_probs
.
shape
...
...
vllm/model_executor/layers/sampler.py
View file @
4eabe123
...
...
@@ -228,17 +228,19 @@ class Sampler(nn.Module):
)
->
Optional
[
SamplerOutput
]:
"""
Single-step scheduling:
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Pythonize sampling result & logprobs tensor
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Pythonize sampling result & logprobs tensor
Multi-step scheduling:
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Defer Pythonization of sampling result & logprobs
tensor
* Encapsulate arguments required for deferred Pythonization
in the {class}`SamplerOutput` structure
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Defer Pythonization of sampling result & logprobs
tensor
* Encapsulate arguments required for deferred Pythonization
in the
[`SamplerOutput`][vllm.model_executor.layers.sampler.SamplerOutput]
structure
Args:
logits: (num_tokens, vocab_size).
...
...
vllm/model_executor/layers/typical_acceptance_sampler.py
View file @
4eabe123
...
...
@@ -93,29 +93,27 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
Evaluates and returns a mask of accepted tokens based on the
posterior probabilities.
Parameters:
----------
target_probs : torch.Tensor
A tensor of shape (batch_size, k, vocab_size) representing
the probabilities of each token in the vocabulary for each
position in the proposed sequence. This is the distribution
generated by the target model.
draft_token_ids : torch.Tensor
A tensor of shape (batch_size, k) representing the proposed
token ids.
Args:
target_probs (torch.Tensor): A tensor of shape
(batch_size, k, vocab_size) representing the probabilities of
each token in the vocabulary for each position in the proposed
sequence. This is the distribution generated by the target
model.
draft_token_ids (torch.Tensor): A tensor of shape (batch_size, k)
representing the proposed token ids.
A draft token_id x_{n+k} is accepted if it satisfies the
following condition
:::{math}
$$
p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) >
\min \left( \epsilon, \delta * \exp \left(
-H(p_{\text{original}}(
\cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right)
:::
$$
where
{math}`
p_{\text{original}}
`
corresponds to target_probs
and
{math}`
\epsilon
`
and
{math}`
\delta
`
correspond to hyperparameters
where
$
p_{\text{original}}
$
corresponds to target_probs
and
$
\epsilon
$
and
$
\delta
$
correspond to hyperparameters
specified using self._posterior_threshold and self._posterior_alpha
This method computes the posterior probabilities for the given
...
...
@@ -126,13 +124,10 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
returns a boolean mask indicating which tokens can be accepted.
Returns:
-------
torch.Tensor
A boolean tensor of shape (batch_size, k) where each element
indicates whether the corresponding draft token has been accepted
or rejected. True indicates acceptance and false indicates
rejection.
torch.Tensor: A boolean tensor of shape (batch_size, k) where each
element indicates whether the corresponding draft token has
been accepted or rejected. True indicates acceptance and false
indicates rejection.
"""
device
=
target_probs
.
device
candidates_prob
=
torch
.
gather
(
...
...
@@ -156,17 +151,14 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
The recovered token ids will fill the first unmatched token
by the target token.
Parameters
----------
target_probs : torch.Tensor
A tensor of shape (batch_size, k, vocab_size) containing
the target probability distribution
Returns
-------
torch.Tensor
A tensor of shape (batch_size, k) with the recovered token
ids which are selected from target probs.
Args:
target_probs (torch.Tensor): A tensor of shape
(batch_size, k, vocab_size) containing the target probability
distribution.
Returns:
torch.Tensor: A tensor of shape (batch_size, k) with the recovered
token ids which are selected from target probs.
"""
max_indices
=
torch
.
argmax
(
target_probs
,
dim
=-
1
)
...
...
vllm/model_executor/model_loader/__init__.py
View file @
4eabe123
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
from
torch
import
nn
from
vllm.config
import
LoadConfig
,
LoadFormat
,
VllmConfig
from
vllm.config
import
LoadConfig
,
LoadFormat
,
ModelConfig
,
VllmConfig
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
from
vllm.model_executor.model_loader.bitsandbytes_loader
import
(
BitsAndBytesModelLoader
)
...
...
@@ -47,9 +49,14 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
return
DefaultModelLoader
(
load_config
)
def
get_model
(
*
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
def
get_model
(
*
,
vllm_config
:
VllmConfig
,
model_config
:
Optional
[
ModelConfig
]
=
None
)
->
nn
.
Module
:
loader
=
get_model_loader
(
vllm_config
.
load_config
)
return
loader
.
load_model
(
vllm_config
=
vllm_config
)
if
model_config
is
None
:
model_config
=
vllm_config
.
model_config
return
loader
.
load_model
(
vllm_config
=
vllm_config
,
model_config
=
model_config
)
__all__
=
[
...
...
vllm/model_executor/model_loader/base_loader.py
View file @
4eabe123
...
...
@@ -18,6 +18,7 @@ class BaseModelLoader(ABC):
raise
NotImplementedError
@
abstractmethod
def
load_model
(
self
,
*
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
def
load_model
(
self
,
*
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
)
->
nn
.
Module
:
"""Load a model with the given configurations."""
raise
NotImplementedError
vllm/model_executor/model_loader/bitsandbytes_loader.py
View file @
4eabe123
...
...
@@ -569,10 +569,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
def
load_model
(
self
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
)
->
nn
.
Module
:
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
...
...
vllm/model_executor/model_loader/default_loader.py
View file @
4eabe123
...
...
@@ -11,8 +11,8 @@ import torch
from
torch
import
nn
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
vllm
import
envs
from
vllm.config
import
LoadConfig
,
LoadFormat
,
ModelConfig
,
VllmConfig
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
from
vllm.model_executor.model_loader.utils
import
(
...
...
@@ -64,7 +64,7 @@ class DefaultModelLoader(BaseModelLoader):
Returns the path to the downloaded model, or None if the model is not
downloaded from ModelScope."""
if
VLLM_USE_MODELSCOPE
:
if
envs
.
VLLM_USE_MODELSCOPE
:
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
# pylint: disable=C.
...
...
@@ -264,13 +264,14 @@ class DefaultModelLoader(BaseModelLoader):
fall_back_to_pt
=
True
,
allow_patterns_overrides
=
None
)
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
def
load_model
(
self
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
)
->
nn
.
Module
:
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
target_device
=
torch
.
device
(
device_config
.
device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
model
=
initialize_model
(
vllm_config
=
vllm_config
)
model
=
initialize_model
(
vllm_config
=
vllm_config
,
model_config
=
model_config
)
weights_to_load
=
{
name
for
name
,
_
in
model
.
named_parameters
()}
loaded_weights
=
model
.
load_weights
(
...
...
vllm/model_executor/model_loader/dummy_loader.py
View file @
4eabe123
...
...
@@ -22,9 +22,9 @@ class DummyModelLoader(BaseModelLoader):
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
pass
# Nothing to download
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
def
load_model
(
self
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
)
->
nn
.
Module
:
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
target_device
=
torch
.
device
(
device_config
.
device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
...
...
vllm/model_executor/model_loader/gguf_loader.py
View file @
4eabe123
...
...
@@ -92,9 +92,9 @@ class GGUFModelLoader(BaseModelLoader):
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
)
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
def
load_model
(
self
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
)
->
nn
.
Module
:
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
local_model_path
=
self
.
_prepare_weights
(
model_config
.
model
)
gguf_weights_map
=
self
.
_get_gguf_weights_map
(
model_config
)
# we can only know if tie word embeddings after mapping weights
...
...
vllm/model_executor/model_loader/neuronx_distributed.py
View file @
4eabe123
...
...
@@ -87,16 +87,29 @@ class NeuronCausalLM(nn.Module):
input_block_ids
:
torch
.
Tensor
,
sampling_params
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# sort block ids sequentially for perf/neuron support reasons
sorted_input_block_ids
,
sorted_indices
=
torch
.
sort
(
input_block_ids
)
input_ids
=
torch
.
index_select
(
input_ids
,
0
,
sorted_indices
)
positions
=
torch
.
index_select
(
positions
,
0
,
sorted_indices
)
sampling_params
=
torch
.
index_select
(
sampling_params
,
0
,
sorted_indices
)
output
=
self
.
model
(
input_ids
,
attention_mask
=
None
,
position_ids
=
positions
,
seq_ids
=
input_block_ids
,
seq_ids
=
sorted_
input_block_ids
,
sampling_params
=
sampling_params
)
# on-device sampling
if
self
.
config
.
neuron_config
.
on_device_sampling_config
:
return
output
.
hidden_states
output
=
output
.
hidden_states
else
:
return
output
.
logits
[:,
-
1
,
:]
output
=
output
.
logits
[:,
-
1
,
:]
restored_indices
=
torch
.
argsort
(
sorted_indices
)
if
input_block_ids
.
shape
[
0
]
!=
1
:
output
=
torch
.
index_select
(
output
,
0
,
restored_indices
)
return
output
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
...
...
@@ -340,14 +353,26 @@ class NeuronSpeculationCausalLM(nn.Module):
input_block_ids
:
torch
.
Tensor
,
sampling_params
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# sort block ids sequentially for perf/neuron support reasons
sorted_input_block_ids
,
sorted_indices
=
torch
.
sort
(
input_block_ids
)
input_ids
=
torch
.
index_select
(
input_ids
,
0
,
sorted_indices
)
positions
=
torch
.
index_select
(
positions
,
0
,
sorted_indices
)
sampling_params
=
torch
.
index_select
(
sampling_params
,
0
,
sorted_indices
)
output
=
self
.
model
(
input_ids
,
attention_mask
=
None
,
position_ids
=
positions
,
seq_ids
=
input_block_ids
,
seq_ids
=
sorted_
input_block_ids
,
sampling_params
=
sampling_params
)
restored_indices
=
torch
.
argsort
(
sorted_indices
)
# CTX encoding
if
(
positions
[:,
0
]).
sum
().
item
()
==
0
:
return
output
.
fused_outputs
[
0
][:,
0
:
1
]
output
=
output
.
fused_outputs
[
0
][:,
0
:
1
]
if
input_block_ids
.
shape
[
0
]
!=
1
:
output
=
torch
.
index_select
(
output
,
0
,
restored_indices
)
return
output
# Fused Spec (Generation)
accepted_tokens_with_padding
=
output
.
fused_outputs
[
0
]
...
...
@@ -362,6 +387,10 @@ class NeuronSpeculationCausalLM(nn.Module):
-
1
)
>=
generated_token_counts
accepted_tokens_with_padding
[
mask
]
=
-
1
if
input_block_ids
.
shape
[
0
]
!=
1
:
accepted_tokens_with_padding
=
torch
.
index_select
(
accepted_tokens_with_padding
,
0
,
restored_indices
)
return
accepted_tokens_with_padding
def
sample
(
...
...
@@ -416,6 +445,10 @@ class NeuronSpeculationCausalLM(nn.Module):
draft_neuron_config
.
speculation_length
=
0
draft_neuron_config
.
trace_tokengen_model
=
True
draft_neuron_config
.
enable_fused_speculation
=
False
if
getattr
(
config
.
neuron_config
,
"draft_model_modules_to_not_convert"
,
None
):
draft_neuron_config
.
modules_to_not_convert
=
(
draft_neuron_config
.
draft_model_modules_to_not_convert
)
if
config
.
neuron_config
.
enable_eagle_speculation
:
draft_neuron_config
.
is_eagle_draft
=
True
draft_neuron_config
.
sequence_parallel_enabled
=
False
...
...
@@ -502,7 +535,7 @@ def _get_default_neuron_config(model_config: ModelConfig,
max_context_length
=
scheduler_config
.
max_model_len
,
seq_len
=
scheduler_config
.
max_model_len
,
enable_bucketing
=
True
,
is_continuous_batching
=
(
batch_size
>
1
)
,
is_continuous_batching
=
True
,
quantized
=
False
,
torch_dtype
=
TORCH_DTYPE_TO_NEURON_AMP
[
model_config
.
dtype
],
padding_side
=
"right"
,
...
...
@@ -520,6 +553,7 @@ def _get_default_speculation_config(model_config: ModelConfig,
args."""
neuron_config
=
dict
(
tp_degree
=
parallel_config
.
tensor_parallel_size
,
ctx_batch_size
=
1
,
batch_size
=
scheduler_config
.
max_num_seqs
,
max_context_length
=
scheduler_config
.
max_model_len
,
seq_len
=
scheduler_config
.
max_model_len
,
...
...
@@ -527,6 +561,7 @@ def _get_default_speculation_config(model_config: ModelConfig,
trace_tokengen_model
=
False
,
enable_fused_speculation
=
True
,
enable_bucketing
=
True
,
is_continuous_batching
=
True
,
quantized
=
False
,
torch_dtype
=
TORCH_DTYPE_TO_NEURON_AMP
[
model_config
.
dtype
],
on_device_sampling_config
=
dict
(
...
...
vllm/model_executor/model_loader/runai_streamer_loader.py
View file @
4eabe123
...
...
@@ -100,11 +100,10 @@ class RunaiModelStreamerLoader(BaseModelLoader):
"""Download model if necessary"""
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
def
load_model
(
self
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
)
->
nn
.
Module
:
"""Perform streaming of the model to destination"""
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
target_device
=
torch
.
device
(
device_config
.
device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
...
...
Prev
1
…
23
24
25
26
27
28
29
30
31
…
34
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment