Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1591c68f
Commit
1591c68f
authored
May 25, 2024
by
zhuwenwen
Browse files
merge v0.4.2
parents
09bcf00b
c7f2cf2b
Changes
265
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1447 additions
and
509 deletions
+1447
-509
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+186
-60
vllm/model_executor/layers/quantization/gptq.py
vllm/model_executor/layers/quantization/gptq.py
+11
-8
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+438
-0
vllm/model_executor/layers/quantization/marlin.py
vllm/model_executor/layers/quantization/marlin.py
+8
-5
vllm/model_executor/layers/quantization/squeezellm.py
vllm/model_executor/layers/quantization/squeezellm.py
+13
-10
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+139
-3
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+404
-209
vllm/model_executor/layers/vocab_parallel_embedding.py
vllm/model_executor/layers/vocab_parallel_embedding.py
+8
-0
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+29
-21
vllm/model_executor/model_loader/tensorizer.py
vllm/model_executor/model_loader/tensorizer.py
+28
-24
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+42
-33
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+7
-6
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+22
-21
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+16
-16
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+19
-18
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+17
-16
vllm/model_executor/models/dbrx.py
vllm/model_executor/models/dbrx.py
+18
-17
vllm/model_executor/models/decilm.py
vllm/model_executor/models/decilm.py
+4
-3
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+22
-23
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+16
-16
No files found.
vllm/model_executor/layers/quantization/fp8.py
View file @
1591c68f
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
torch.nn
import
Module
from
torch.nn
import
Module
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm
import
_custom_ops
as
ops
set_weight_attrs
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
ACTIVATION_SCHEMES
=
[
"static"
,
"dynamic"
]
class
FP8Config
(
QuantizationConfig
):
logger
=
init_logger
(
__name__
)
class
Fp8Config
(
QuantizationConfig
):
"""Config class for FP8."""
"""Config class for FP8."""
def
__init__
(
self
,
is_checkpoint_fp8_serialized
:
bool
=
False
,
activation_scheme
:
str
=
"dynamic"
,
)
->
None
:
self
.
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
if
is_checkpoint_fp8_serialized
:
logger
.
warning
(
"Detected fp8 checkpoint. Please note that the "
"format is experimental and subject to change."
)
if
activation_scheme
not
in
ACTIVATION_SCHEMES
:
raise
ValueError
(
f
"Unsupported activation scheme
{
activation_scheme
}
"
)
self
.
activation_scheme
=
activation_scheme
@
classmethod
@
classmethod
def
get_name
(
cls
)
->
str
:
def
get_name
(
cls
)
->
str
:
return
"fp8"
return
"fp8"
...
@@ -23,21 +43,25 @@ class FP8Config(QuantizationConfig):
...
@@ -23,21 +43,25 @@ class FP8Config(QuantizationConfig):
@
classmethod
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
def
get_min_capability
(
cls
)
->
int
:
# TODO: PyTorch 2.3.0+ is required to run FP8 on
return
89
# SM 89 (e.g. Ada) GPUs. Specifically, this PR has to
# be included: https://github.com/pytorch/pytorch/pull/118881
return
90
@
classmethod
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[]
return
[]
@
classmethod
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"FP8Config"
:
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"Fp8Config"
:
return
cls
()
quant_method
=
cls
.
get_from_keys
(
config
,
[
"quant_method"
])
is_checkpoint_fp8_serialized
=
(
"fp8"
in
quant_method
)
activation_scheme
=
cls
.
get_from_keys
(
config
,
[
"activation_scheme"
])
return
cls
(
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
,
activation_scheme
=
activation_scheme
)
def
get_linear_method
(
self
)
->
"Fp8LinearMethod"
:
def
get_quant_method
(
return
Fp8LinearMethod
(
self
)
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"Fp8LinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
return
Fp8LinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
return
[]
...
@@ -45,8 +69,12 @@ class FP8Config(QuantizationConfig):
...
@@ -45,8 +69,12 @@ class FP8Config(QuantizationConfig):
class
Fp8LinearMethod
(
LinearMethodBase
):
class
Fp8LinearMethod
(
LinearMethodBase
):
"""Linear method for FP8.
"""Linear method for FP8.
We now support common FP16/BF16 model checkpoints ONLY. The weight
Supports loading FP8 checkpoints with static weight scale and
scaling factor will be initialized after the model weights are loaded.
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Limitations:
Limitations:
1. Only support per-tensor quantization due to torch._scaled_mm support.
1. Only support per-tensor quantization due to torch._scaled_mm support.
...
@@ -57,9 +85,27 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -57,9 +85,27 @@ class Fp8LinearMethod(LinearMethodBase):
quant_config: The quantization config.
quant_config: The quantization config.
"""
"""
def
__init__
(
self
,
quant_config
:
F
P
8Config
):
def
__init__
(
self
,
quant_config
:
F
p
8Config
):
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
def
_create_scale_param
(
self
,
scale_name
:
str
,
layer
:
torch
.
nn
.
Module
,
output_partition_sizes
:
List
[
int
],
**
extra_weight_attrs
,
)
->
None
:
scale
=
Parameter
(
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
scale_name
,
scale
)
set_weight_attrs
(
scale
,
{
**
extra_weight_attrs
,
"fp8_scales_shard_indexer"
:
self
.
scales_shard_indexer
,
})
def
create_weights
(
def
create_weights
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
...
@@ -70,70 +116,150 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -70,70 +116,150 @@ class Fp8LinearMethod(LinearMethodBase):
params_dtype
:
torch
.
dtype
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
**
extra_weight_attrs
,
):
):
del
input_size
,
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
layer
.
process_after_load
=
True
layer
.
logical_widths
=
output_partition_sizes
# WEIGHT
weight_dtype
=
(
torch
.
float8_e4m3fn
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
else
params_dtype
)
weight
=
Parameter
(
torch
.
empty
(
output_size_per_partition
,
weight
=
Parameter
(
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
input_size_per_partition
,
dtype
=
params
_dtype
),
dtype
=
weight
_dtype
),
requires_grad
=
False
)
requires_grad
=
False
)
layer
.
register_parameter
(
"weight"
,
weight
)
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
})
set_weight_attrs
(
weight
,
{
set_weight_attrs
(
weight
,
extra_weight_attrs
)
**
extra_weight_attrs
,
"input_dim"
:
1
,
"output_dim"
:
0
,
})
w_scale
=
Parameter
(
# If checkpoint is serialized fp8, load them.
torch
.
empty
(
1
,
dtype
=
torch
.
float32
),
# Otherwise, wait until process_weights_after_loading.
requires_grad
=
False
,
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
)
# WEIGHT SCALE
layer
.
register_parameter
(
"weight_scaling_factor"
,
w_scale
)
self
.
_create_scale_param
(
scale_name
=
"weight_scale"
,
layer
=
layer
,
output_partition_sizes
=
output_partition_sizes
,
**
extra_weight_attrs
)
# ACTIVATION SCALE
if
self
.
quant_config
.
activation_scheme
==
"static"
:
self
.
_create_scale_param
(
scale_name
=
"act_scale"
,
layer
=
layer
,
output_partition_sizes
=
output_partition_sizes
,
**
extra_weight_attrs
)
def
scales_shard_indexer
(
self
,
param
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
,
shard_id
:
Union
[
str
,
int
])
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
qkv_idxs
=
{
"q"
:
0
,
"k"
:
1
,
"v"
:
2
}
if
isinstance
(
shard_id
,
int
):
pass
elif
isinstance
(
shard_id
,
str
):
if
shard_id
not
in
qkv_idxs
:
raise
ValueError
(
f
"Unknown shard_id:
{
shard_id
}
"
)
shard_id
=
qkv_idxs
[
shard_id
]
else
:
ValueError
(
f
"Shard id must be int or str but got
{
type
(
shard_id
)
}
"
)
return
param
[
shard_id
],
loaded_weight
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# Although the linear_method is propagated to all layers,
if
(
not
hasattr
(
layer
,
"process_after_load"
)
# only linear layers invoke "create_weights". So we check
or
not
layer
.
process_after_load
):
# whether "weight_scaling_facor" is registered to determine
return
# whether the layer is a linear layer that requires quantization.
if
not
hasattr
(
layer
,
"weight_scaling_factor"
):
# If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
layer
.
weight
,
scale
=
None
)
layer
.
weight
=
Parameter
(
qweight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
logical_widths
=
None
layer
.
act_scale
=
None
return
return
qweight
,
weight_scale
=
per_tensor_quantize
(
layer
.
weight
)
# If checkpoint is fp8, requantize the separately quantized logical
# torch._scaled_mm requires column-major in the second
# weights into a single fp8 weight with a single weight scale.
# input (weight), so we transpose the quantized weight.
else
:
layer
.
weight
=
Parameter
(
qweight
.
t
(),
requires_grad
=
False
)
# WEIGHT_SCALE / WEIGHT
layer
.
weight_scaling_factor
.
data
.
copy_
(
weight_scale
)
# Loop over logical weights, requantizing with single scale.
max_w_scale
=
layer
.
weight_scale
.
max
()
def
apply_weights
(
self
,
start
=
0
layer
:
torch
.
nn
.
Module
,
for
idx
,
logical_width
in
enumerate
(
layer
.
logical_widths
):
x
:
torch
.
Tensor
,
end
=
start
+
logical_width
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
weight_dq
=
per_tensor_dequantize
(
layer
.
weight
[
start
:
end
,
:],
qinput
,
x_scale
=
per_tensor_quantize
(
x
)
layer
.
weight_scale
[
idx
])
layer
.
weight
[
start
:
end
,
:]
=
per_tensor_quantize
(
weight_dq
,
layer
.
weight_scale
.
max
())
start
=
end
layer
.
weight_scale
=
Parameter
(
max_w_scale
,
requires_grad
=
False
)
# WEIGHT
# Transpose weight for passing to torch._scaled_mm
weight
=
layer
.
weight
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
# ACT_SCALE
# Dynamic: set to None (required input to ops.scaled_fp8_quant).
# Static: set to max of the act_scales (since they are equal).
if
self
.
quant_config
.
activation_scheme
==
"dynamic"
:
layer
.
act_scale
=
None
elif
self
.
quant_config
.
activation_scheme
==
"static"
:
if
not
all_close_1d
(
layer
.
act_scale
):
raise
ValueError
(
"All the act_scales for the logical weights of a layer "
f
"must be equal. But got
{
layer
.
act_scale
}
"
)
layer
.
act_scale
=
Parameter
(
layer
.
act_scale
.
max
(),
requires_grad
=
False
)
else
:
raise
ValueError
(
f
"Unknown scheme
{
self
.
quant_config
.
activation_scheme
}
"
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.act_scale is None and x_scale computed from x.
# If static, layer.act_scale is scalar and x_scale set to act_scale.
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
x
,
layer
.
act_scale
)
# Fused GEMM_DQ
output
,
_
=
torch
.
_scaled_mm
(
output
,
_
=
torch
.
_scaled_mm
(
qinput
,
qinput
,
layer
.
weight
,
layer
.
weight
,
out_dtype
=
x
.
dtype
,
out_dtype
=
x
.
dtype
,
scale_a
=
x_scale
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scal
ing_factor
,
scale_b
=
layer
.
weight_scal
e
,
bias
=
bias
,
bias
=
bias
,
)
)
return
output
return
output
def
per_tensor_quantize
(
tensor
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
float
]:
def
all_close_1d
(
x
:
torch
.
Tensor
)
->
bool
:
"""Quantize a tensor using per-tensor static scaling factor.
assert
len
(
x
.
shape
)
==
1
return
all
(
torch
.
allclose
(
x
[
0
],
x
[
i
])
for
i
in
range
(
x
.
shape
[
0
]))
Args:
tensor: The input t
ensor
.
def
per_tensor_quantize
(
tensor
:
torch
.
T
ensor
,
"""
inv_scale
:
float
)
->
torch
.
Tensor
:
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
# Calculate the scale as dtype max divided by absmax.
qweight
=
(
tensor
/
inv_scale
).
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)
# Since .abs() creates a new tensor, we use aminmax to get
return
qweight
.
to
(
torch
.
float8_e4m3fn
)
# the min and max first and then calculate the absmax.
min_val
,
max_val
=
tensor
.
aminmax
()
amax
=
min_val
.
abs
().
max
(
max_val
.
abs
())
def
per_tensor_dequantize
(
tensor
:
torch
.
Tensor
,
scale
=
finfo
.
max
/
amax
.
clamp
(
min
=
1e-12
)
inv_scale
:
float
)
->
torch
.
Tensor
:
# scale and clamp the tensor to bring it to
fake_qweight
=
tensor
.
to
(
torch
.
float16
)
# the representative range of float8 data type
dq_weight
=
fake_qweight
*
inv_scale
# (as default cast is unsaturated)
return
dq_weight
qweight
=
(
tensor
*
scale
).
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)
# Return both float8 data and the inverse scale (as float),
# as both required as inputs to torch._scaled_mm
qweight
=
qweight
.
to
(
torch
.
float8_e4m3fn
)
scale
=
scale
.
float
().
reciprocal
()
return
qweight
,
scale
vllm/model_executor/layers/quantization/gptq.py
View file @
1591c68f
...
@@ -7,10 +7,10 @@ import torch
...
@@ -7,10 +7,10 @@ import torch
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
class
GPTQConfig
(
QuantizationConfig
):
class
GPTQConfig
(
QuantizationConfig
):
...
@@ -63,8 +63,11 @@ class GPTQConfig(QuantizationConfig):
...
@@ -63,8 +63,11 @@ class GPTQConfig(QuantizationConfig):
desc_act
=
cls
.
get_from_keys
(
config
,
[
"desc_act"
])
desc_act
=
cls
.
get_from_keys
(
config
,
[
"desc_act"
])
return
cls
(
weight_bits
,
group_size
,
desc_act
)
return
cls
(
weight_bits
,
group_size
,
desc_act
)
def
get_linear_method
(
self
)
->
"GPTQLinearMethod"
:
def
get_quant_method
(
return
GPTQLinearMethod
(
self
)
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"GPTQLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
return
GPTQLinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
return
[]
...
@@ -194,10 +197,10 @@ class GPTQLinearMethod(LinearMethodBase):
...
@@ -194,10 +197,10 @@ class GPTQLinearMethod(LinearMethodBase):
layer
.
exllama_state
=
exllama_state
layer
.
exllama_state
=
exllama_state
def
apply
_weights
(
self
,
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
qweight
=
layer
.
qweight
qweight
=
layer
.
qweight
out_shape
=
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
-
1
],
)
out_shape
=
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
-
1
],
)
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
0 → 100644
View file @
1591c68f
import
enum
from
enum
import
Enum
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
GPTQ_MARLIN_TILE
=
16
GPTQ_MARLIN_MIN_THREAD_N
=
64
GPTQ_MARLIN_MIN_THREAD_K
=
128
GPTQ_MARLIN_MAX_PARALLEL
=
16
GPTQ_MARLIN_SUPPORTED_NUM_BITS
=
[
4
,
8
]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
GPTQ_MARLIN_SUPPORTED_SYM
=
[
True
]
# Permutations for Marlin scale shuffling
def
get_scale_perms
(
num_bits
):
scale_perm
=
[]
for
i
in
range
(
8
):
scale_perm
.
extend
([
i
+
8
*
j
for
j
in
range
(
8
)])
scale_perm_single
=
[]
for
i
in
range
(
4
):
scale_perm_single
.
extend
(
[
2
*
i
+
j
for
j
in
[
0
,
1
,
8
,
9
,
16
,
17
,
24
,
25
]])
return
scale_perm
,
scale_perm_single
def
get_pack_factor
(
num_bits
):
assert
(
num_bits
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
),
f
"Unsupported num_bits =
{
num_bits
}
"
return
32
//
num_bits
def
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
,
num_bits
):
scale_perm
,
scale_perm_single
=
get_scale_perms
(
num_bits
)
if
group_size
<
size_k
and
group_size
!=
-
1
:
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm
)))[:,
scale_perm
]
else
:
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm_single
)))[:,
scale_perm_single
]
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
return
s
class
GPTQMarlinConfig
(
QuantizationConfig
):
"""Config class for GPTQ Marlin"""
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
desc_act
:
bool
,
is_sym
:
bool
)
->
None
:
if
desc_act
and
group_size
==
-
1
:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
desc_act
=
False
self
.
weight_bits
=
weight_bits
self
.
group_size
=
group_size
self
.
desc_act
=
desc_act
self
.
is_sym
=
is_sym
# Verify
if
self
.
weight_bits
not
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
:
raise
ValueError
(
f
"Marlin does not support weight_bits =
{
self
.
weight_bits
}
. "
f
"Only weight_bits =
{
GPTQ_MARLIN_SUPPORTED_NUM_BITS
}
"
"are supported."
)
if
self
.
group_size
not
in
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
:
raise
ValueError
(
f
"Marlin does not support group_size =
{
self
.
group_size
}
. "
f
"Only group_sizes =
{
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
}
"
"are supported."
)
if
self
.
is_sym
not
in
GPTQ_MARLIN_SUPPORTED_SYM
:
raise
ValueError
(
f
"Marlin does not support is_sym =
{
self
.
is_sym
}
. "
f
"Only sym =
{
GPTQ_MARLIN_SUPPORTED_SYM
}
are supported."
)
# Init
self
.
pack_factor
=
get_pack_factor
(
weight_bits
)
self
.
tile_size
=
GPTQ_MARLIN_TILE
self
.
min_thread_n
=
GPTQ_MARLIN_MIN_THREAD_N
self
.
min_thread_k
=
GPTQ_MARLIN_MIN_THREAD_K
self
.
max_parallel
=
GPTQ_MARLIN_MAX_PARALLEL
def
__repr__
(
self
)
->
str
:
return
(
f
"GPTQMarlinConfig(weight_bits=
{
self
.
weight_bits
}
, "
f
"group_size=
{
self
.
group_size
}
, "
f
"desc_act=
{
self
.
desc_act
}
)"
)
@
classmethod
def
get_name
(
cls
)
->
str
:
return
"gptq_marlin"
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
half
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
80
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[
"quantize_config.json"
]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"GPTQMarlinConfig"
:
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"bits"
])
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
desc_act
=
cls
.
get_from_keys
(
config
,
[
"desc_act"
])
is_sym
=
cls
.
get_from_keys
(
config
,
[
"sym"
])
return
cls
(
weight_bits
,
group_size
,
desc_act
,
is_sym
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"GPTQMarlinLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
return
GPTQMarlinLinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
@
classmethod
def
is_marlin_compatible
(
cls
,
quant_config
:
Dict
[
str
,
Any
]):
# Extract data from quant config.
num_bits
=
quant_config
.
get
(
"bits"
,
None
)
group_size
=
quant_config
.
get
(
"group_size"
,
None
)
sym
=
quant_config
.
get
(
"sym"
,
None
)
desc_act
=
quant_config
.
get
(
"desc_act"
,
None
)
# If we cannot find the info needed in the config, cannot convert.
if
(
num_bits
is
None
or
group_size
is
None
or
sym
is
None
or
desc_act
is
None
):
return
False
# If the capability of the device is too low, cannot convert.
major
,
minor
=
torch
.
cuda
.
get_device_capability
()
device_capability
=
major
*
10
+
minor
if
device_capability
<
cls
.
get_min_capability
():
return
False
# Otherwise, can convert if model satisfies marlin constraints.
return
(
num_bits
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
and
group_size
in
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
and
sym
in
GPTQ_MARLIN_SUPPORTED_SYM
)
class
GPTQMarlinState
(
Enum
):
REPACK
=
enum
.
auto
()
READY
=
enum
.
auto
()
class
GPTQMarlinLinearMethod
(
LinearMethodBase
):
"""Linear method for GPTQ Marlin.
Args:
quant_config: The GPTQ Marlin quantization config.
"""
def
__init__
(
self
,
quant_config
:
GPTQMarlinConfig
)
->
None
:
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
)
->
None
:
del
output_size
# Normalize group_size
if
self
.
quant_config
.
group_size
!=
-
1
:
group_size
=
self
.
quant_config
.
group_size
else
:
group_size
=
input_size
# Validate dtype
if
params_dtype
!=
torch
.
float16
:
raise
ValueError
(
f
"The params dtype must be float16, but got
{
params_dtype
}
"
)
# Validate output_size_per_partition
output_size_per_partition
=
sum
(
output_partition_sizes
)
if
output_size_per_partition
%
self
.
quant_config
.
min_thread_n
!=
0
:
raise
ValueError
(
f
"Weight output_size_per_partition = "
f
"
{
output_size_per_partition
}
is not divisible by "
f
" min_thread_n =
{
self
.
quant_config
.
min_thread_n
}
."
)
# Validate input_size_per_partition
if
input_size_per_partition
%
self
.
quant_config
.
min_thread_k
!=
0
:
raise
ValueError
(
f
"Weight input_size_per_partition = "
f
"
{
input_size_per_partition
}
is not divisible "
f
"by min_thread_k =
{
self
.
quant_config
.
min_thread_k
}
."
)
if
(
group_size
<
input_size
and
input_size_per_partition
%
group_size
!=
0
):
raise
ValueError
(
f
"Weight input_size_per_partition =
{
input_size_per_partition
}
"
f
" is not divisible by group_size =
{
group_size
}
."
)
# Detect sharding of scales/zp
# By default, no sharding over "input dim"
scales_and_zp_size
=
input_size
//
group_size
scales_and_zp_input_dim
=
None
if
self
.
quant_config
.
desc_act
:
# Act-order case
assert
self
.
quant_config
.
group_size
!=
-
1
is_k_full
=
input_size_per_partition
==
input_size
else
:
# No act-order case
# K is always full due to full alignment with
# group-size and shard of scales/zp
is_k_full
=
True
# If this is a row-parallel case, then shard scales/zp
if
(
input_size
!=
input_size_per_partition
and
self
.
quant_config
.
group_size
!=
-
1
):
scales_and_zp_size
=
input_size_per_partition
//
group_size
scales_and_zp_input_dim
=
0
# Init buffers
# Quantized weights
qweight
=
Parameter
(
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
pack_factor
,
output_size_per_partition
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qweight
,
{
**
extra_weight_attrs
,
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
0
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
},
)
# Activation order
g_idx
=
Parameter
(
torch
.
empty
(
input_size_per_partition
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
# Ignore warning from fused linear layers such as QKVParallelLinear.
set_weight_attrs
(
g_idx
,
{
**
extra_weight_attrs
,
"input_dim"
:
0
,
"ignore_warning"
:
True
},
)
g_idx_sort_indices
=
Parameter
(
torch
.
empty
(
g_idx
.
shape
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
g_idx_sort_indices
,
extra_weight_attrs
)
# Scales
scales
=
Parameter
(
torch
.
empty
(
scales_and_zp_size
,
output_size_per_partition
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
scales
,
{
**
extra_weight_attrs
,
"input_dim"
:
scales_and_zp_input_dim
,
"output_dim"
:
1
,
},
)
# Quantized zero-points
qzeros
=
Parameter
(
torch
.
empty
(
scales_and_zp_size
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
device
=
"meta"
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qzeros
,
{
**
extra_weight_attrs
,
"input_dim"
:
scales_and_zp_input_dim
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
},
)
# Allocate marlin workspace
max_workspace_size
=
(
output_size_per_partition
//
self
.
quant_config
.
min_thread_n
)
*
self
.
quant_config
.
max_parallel
workspace
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
requires_grad
=
False
)
layer
.
register_parameter
(
"qweight"
,
qweight
)
layer
.
register_parameter
(
"g_idx"
,
g_idx
)
layer
.
register_parameter
(
"g_idx_sort_indices"
,
g_idx_sort_indices
)
layer
.
register_parameter
(
"scales"
,
scales
)
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
layer
.
workspace
=
workspace
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
input_size
=
input_size
layer
.
is_k_full
=
is_k_full
layer
.
marlin_state
=
GPTQMarlinState
.
REPACK
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
size_m
=
reshaped_x
.
shape
[
0
]
part_size_n
=
layer
.
output_size_per_partition
part_size_k
=
layer
.
input_size_per_partition
full_size_k
=
layer
.
input_size
out_shape
=
x
.
shape
[:
-
1
]
+
(
part_size_n
,
)
if
layer
.
marlin_state
==
GPTQMarlinState
.
REPACK
:
layer
.
marlin_state
=
GPTQMarlinState
.
READY
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def
replace_tensor
(
name
,
new_t
):
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr
(
layer
,
name
).
resize_
(
new_t
.
shape
)
getattr
(
layer
,
name
).
copy_
(
new_t
)
del
new_t
cur_device
=
layer
.
qweight
.
device
# Process act_order
if
self
.
quant_config
.
desc_act
:
# Get sorting based on g_idx
g_idx_sort_indices
=
torch
.
argsort
(
layer
.
g_idx
).
to
(
torch
.
int
)
sorted_g_idx
=
layer
.
g_idx
[
g_idx_sort_indices
]
replace_tensor
(
"g_idx"
,
sorted_g_idx
)
replace_tensor
(
"g_idx_sort_indices"
,
g_idx_sort_indices
)
else
:
# Reset g_idx related tensors
layer
.
g_idx
=
Parameter
(
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
cur_device
),
requires_grad
=
False
,
)
layer
.
g_idx_sort_indices
=
Parameter
(
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
cur_device
),
requires_grad
=
False
,
)
# Repack weights
marlin_qweight
=
ops
.
gptq_marlin_repack
(
layer
.
qweight
,
layer
.
g_idx_sort_indices
,
part_size_k
,
part_size_n
,
self
.
quant_config
.
weight_bits
,
)
replace_tensor
(
"qweight"
,
marlin_qweight
)
# Permute scales
scales_size_k
=
part_size_k
scales_size_n
=
part_size_n
if
self
.
quant_config
.
desc_act
:
scales_size_k
=
full_size_k
marlin_scales
=
marlin_permute_scales
(
layer
.
scales
,
scales_size_k
,
scales_size_n
,
self
.
quant_config
.
group_size
,
self
.
quant_config
.
weight_bits
,
)
replace_tensor
(
"scales"
,
marlin_scales
)
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
layer
.
qweight
,
layer
.
scales
,
layer
.
g_idx
,
layer
.
g_idx_sort_indices
,
layer
.
workspace
,
self
.
quant_config
.
weight_bits
,
size_m
,
part_size_n
,
part_size_k
,
layer
.
is_k_full
,
)
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
return
output
.
reshape
(
out_shape
)
vllm/model_executor/layers/quantization/marlin.py
View file @
1591c68f
...
@@ -4,10 +4,10 @@ import torch
...
@@ -4,10 +4,10 @@ import torch
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
class
MarlinConfig
(
QuantizationConfig
):
class
MarlinConfig
(
QuantizationConfig
):
...
@@ -72,8 +72,11 @@ class MarlinConfig(QuantizationConfig):
...
@@ -72,8 +72,11 @@ class MarlinConfig(QuantizationConfig):
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
return
cls
(
group_size
)
return
cls
(
group_size
)
def
get_linear_method
(
self
)
->
"MarlinLinearMethod"
:
def
get_quant_method
(
return
MarlinLinearMethod
(
self
)
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"MarlinLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
return
MarlinLinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
return
[]
...
@@ -197,7 +200,7 @@ class MarlinLinearMethod(LinearMethodBase):
...
@@ -197,7 +200,7 @@ class MarlinLinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"workspace"
,
workspace
)
layer
.
register_parameter
(
"workspace"
,
workspace
)
set_weight_attrs
(
workspace
,
extra_weight_attrs
)
set_weight_attrs
(
workspace
,
extra_weight_attrs
)
def
apply
_weights
(
def
apply
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/squeezellm.py
View file @
1591c68f
...
@@ -4,10 +4,10 @@ import torch
...
@@ -4,10 +4,10 @@ import torch
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
LinearBase
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
is_hip
from
vllm.utils
import
is_hip
...
@@ -51,14 +51,17 @@ class SqueezeLLMConfig(QuantizationConfig):
...
@@ -51,14 +51,17 @@ class SqueezeLLMConfig(QuantizationConfig):
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"wbits"
])
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"wbits"
])
return
cls
(
weight_bits
)
return
cls
(
weight_bits
)
def
get_linear_method
(
self
)
->
"SqueezeLLMLinearMethod"
:
def
get_quant_method
(
return
SqueezeLLMLinearMethod
(
self
)
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
QuantizeMethodBase
]:
if
isinstance
(
layer
,
LinearBase
):
return
SqueezeLLMLinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
return
[]
class
SqueezeLLMLinearMethod
(
Linear
MethodBase
):
class
SqueezeLLMLinearMethod
(
Quantize
MethodBase
):
"""Linear method for SqueezeLLM.
"""Linear method for SqueezeLLM.
Args:
Args:
...
@@ -112,10 +115,10 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
...
@@ -112,10 +115,10 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"lookup_table"
,
lookup_table
)
layer
.
register_parameter
(
"lookup_table"
,
lookup_table
)
set_weight_attrs
(
lookup_table
,
extra_weight_attrs
)
set_weight_attrs
(
lookup_table
,
extra_weight_attrs
)
def
apply
_weights
(
self
,
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
qweight
=
layer
.
qweight
qweight
=
layer
.
qweight
lookup_table
=
layer
.
lookup_table
lookup_table
=
layer
.
lookup_table
out_shape
=
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
-
1
],
)
out_shape
=
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
-
1
],
)
...
...
vllm/model_executor/layers/rotary_embedding.py
View file @
1591c68f
...
@@ -156,6 +156,12 @@ class RotaryEmbedding(nn.Module):
...
@@ -156,6 +156,12 @@ class RotaryEmbedding(nn.Module):
self
.
cos_sin_cache
,
self
.
is_neox_style
)
self
.
cos_sin_cache
,
self
.
is_neox_style
)
return
query
,
key
return
query
,
key
def
extra_repr
(
self
)
->
str
:
s
=
f
"head_size=
{
self
.
head_size
}
, rotary_dim=
{
self
.
rotary_dim
}
"
s
+=
f
", max_position_embeddings=
{
self
.
max_position_embeddings
}
"
s
+=
f
", base=
{
self
.
base
}
, is_neox_style=
{
self
.
is_neox_style
}
"
return
s
class
LinearScalingRotaryEmbedding
(
RotaryEmbedding
):
class
LinearScalingRotaryEmbedding
(
RotaryEmbedding
):
"""RotaryEmbedding extended with linear scaling.
"""RotaryEmbedding extended with linear scaling.
...
@@ -338,6 +344,114 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
...
@@ -338,6 +344,114 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
return
cache
return
cache
class
Phi3SuScaledRotaryEmbedding
(
nn
.
Module
):
"""Phi3 family of models scaled rotary embedding.
Based on the original RotaryEmbedding implementation.
"""
def
__init__
(
self
,
head_size
:
int
,
rotary_dim
:
int
,
max_position_embeddings
:
int
,
original_max_position_embeddings
:
int
,
base
:
int
,
is_neox_style
:
bool
,
short_factor
:
List
[
float
],
long_factor
:
List
[
float
],
short_mscale
:
float
=
1.1
,
long_mscale
:
float
=
1.225
,
):
super
().
__init__
()
if
rotary_dim
!=
head_size
:
raise
ValueError
(
f
"`Phi3SuScaledRotaryEmbedding` does not support rotary_dim !=
\
head_size (
{
rotary_dim
}
!=
{
head_size
}
)."
)
if
is_neox_style
is
False
:
raise
ValueError
(
"`Phi3SuScaledRotaryEmbedding` only supports neox_style."
)
self
.
head_size
=
head_size
self
.
max_position_embeddings
=
max_position_embeddings
self
.
original_max_position_embeddings
=
original_max_position_embeddings
self
.
base
=
base
self
.
short_factor
=
short_factor
self
.
long_factor
=
long_factor
self
.
short_mscale
=
short_mscale
self
.
long_mscale
=
long_mscale
short_cache
=
self
.
_compute_cos_sin_cache
(
original_max_position_embeddings
,
short_factor
,
short_mscale
)
short_cache
=
short_cache
.
to
(
torch
.
get_default_dtype
())
self
.
register_buffer
(
"short_cos_sin_cache"
,
short_cache
,
persistent
=
False
)
long_cache
=
self
.
_compute_cos_sin_cache
(
max_position_embeddings
,
long_factor
,
long_mscale
)
long_cache
=
long_cache
.
to
(
torch
.
get_default_dtype
())
self
.
register_buffer
(
"long_cos_sin_cache"
,
long_cache
,
persistent
=
False
)
long_short_cache
=
torch
.
cat
(
[
self
.
short_cos_sin_cache
,
self
.
long_cos_sin_cache
],
dim
=
0
)
self
.
register_buffer
(
"long_short_cos_sin_cache"
,
long_short_cache
,
persistent
=
False
)
def
_compute_inv_freq
(
self
,
rescale_factors
:
List
[
float
])
->
torch
.
Tensor
:
rescale_factors
=
torch
.
tensor
(
rescale_factors
,
dtype
=
torch
.
float32
)
inv_freq
=
1.0
/
(
rescale_factors
*
(
self
.
base
**
(
torch
.
arange
(
0
,
self
.
head_size
,
2
,
dtype
=
torch
.
float
)
/
self
.
head_size
)))
return
inv_freq
def
_compute_cos_sin_cache
(
self
,
max_position_embeddings
:
int
,
rescale_factors
:
List
[
float
],
mscale
:
float
,
)
->
torch
.
Tensor
:
inv_freq
=
self
.
_compute_inv_freq
(
rescale_factors
)
t
=
torch
.
arange
(
max_position_embeddings
,
dtype
=
torch
.
float
)
freqs
=
torch
.
einsum
(
"i,j -> ij"
,
t
,
inv_freq
)
cos
=
freqs
.
cos
()
*
mscale
sin
=
freqs
.
sin
()
*
mscale
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
return
cache
def
forward
(
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
query
=
query
.
view
(
*
query
.
shape
[:
-
1
],
-
1
,
self
.
head_size
)
key
=
key
.
view
(
*
key
.
shape
[:
-
1
],
-
1
,
self
.
head_size
)
k
=
self
.
original_max_position_embeddings
long_prompt_offset
=
(
torch
.
any
(
positions
>
k
).
float
()
*
torch
.
full_like
(
positions
,
k
)).
long
()
idx
=
(
torch
.
add
(
positions
,
long_prompt_offset
)
if
long_prompt_offset
is
not
None
else
positions
)
self
.
long_short_cos_sin_cache
:
torch
.
Tensor
=
(
self
.
long_short_cos_sin_cache
.
to
(
idx
.
device
))
idx
=
torch
.
add
(
idx
,
offsets
)
if
offsets
is
not
None
else
idx
cos_sin
=
torch
.
index_select
(
self
.
long_short_cos_sin_cache
,
0
,
idx
)
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
cos
=
cos
.
repeat
(
1
,
2
).
unsqueeze
(
-
2
)
sin
=
sin
.
repeat
(
1
,
2
).
unsqueeze
(
-
2
)
query
=
query
*
cos
+
_rotate_neox
(
query
)
*
sin
key
=
key
*
cos
+
_rotate_neox
(
key
)
*
sin
return
query
.
flatten
(
-
2
),
key
.
flatten
(
-
2
)
_ROPE_DICT
:
Dict
[
Tuple
,
RotaryEmbedding
]
=
{}
_ROPE_DICT
:
Dict
[
Tuple
,
RotaryEmbedding
]
=
{}
...
@@ -349,17 +463,26 @@ def get_rope(
...
@@ -349,17 +463,26 @@ def get_rope(
is_neox_style
:
bool
=
True
,
is_neox_style
:
bool
=
True
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
RotaryEmbedding
:
)
->
RotaryEmbedding
:
if
rope_scaling
is
not
None
:
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple
=
{
k
:
tuple
(
v
)
if
isinstance
(
v
,
list
)
else
v
for
k
,
v
in
rope_scaling
.
items
()
}
rope_scaling_args
=
tuple
(
rope_scaling_tuple
.
items
())
else
:
rope_scaling_args
=
None
key
=
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
key
=
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
tuple
(
rope_scaling
.
items
())
if
rope_scaling
is
not
None
else
None
)
rope_scaling
_args
)
if
key
in
_ROPE_DICT
:
if
key
in
_ROPE_DICT
:
return
_ROPE_DICT
[
key
]
return
_ROPE_DICT
[
key
]
if
rope_scaling
is
None
:
if
rope_scaling
is
None
:
rotary_emb
=
RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
rotary_emb
=
RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
)
is_neox_style
)
else
:
else
:
scaling_type
=
rope_scaling
[
"type"
]
scaling_type
=
rope_scaling
[
"type"
]
scaling_factor
=
rope_scaling
[
"factor"
]
if
scaling_type
!=
"su"
:
scaling_factor
=
rope_scaling
[
"factor"
]
if
scaling_type
==
"linear"
:
if
scaling_type
==
"linear"
:
rotary_emb
=
LinearScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
rotary_emb
=
LinearScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
max_position
,
base
,
...
@@ -383,6 +506,19 @@ def get_rope(
...
@@ -383,6 +506,19 @@ def get_rope(
base
,
is_neox_style
,
base
,
is_neox_style
,
scaling_factor
,
scaling_factor
,
**
extra_kwargs
)
**
extra_kwargs
)
elif
scaling_type
==
"su"
:
short_factor
=
rope_scaling
[
"short_factor"
]
long_factor
=
rope_scaling
[
"long_factor"
]
original_max_position
=
rope_scaling
[
"original_max_position_embeddings"
]
extra_kwargs
=
{
k
:
v
for
k
,
v
in
rope_scaling
.
items
()
if
k
in
(
"short_mscale"
,
"long_mscale"
)
}
rotary_emb
=
Phi3SuScaledRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
original_max_position
,
base
,
is_neox_style
,
short_factor
,
long_factor
,
**
extra_kwargs
)
else
:
else
:
raise
ValueError
(
f
"Unknown RoPE scaling type
{
scaling_type
}
"
)
raise
ValueError
(
f
"Unknown RoPE scaling type
{
scaling_type
}
"
)
_ROPE_DICT
[
key
]
=
rotary_emb
_ROPE_DICT
[
key
]
=
rotary_emb
...
...
vllm/model_executor/layers/sampler.py
View file @
1591c68f
...
@@ -7,11 +7,14 @@ import torch.nn as nn
...
@@ -7,11 +7,14 @@ import torch.nn as nn
from
vllm.model_executor.layers.ops.sample
import
sample
as
sample_triton
from
vllm.model_executor.layers.ops.sample
import
sample
as
sample_triton
from
vllm.model_executor.sampling_metadata
import
(
SamplingMetadata
,
from
vllm.model_executor.sampling_metadata
import
(
SamplingMetadata
,
SamplingTensors
)
SamplingTensors
,
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
SequenceGroupToSample
)
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
(
Logprob
,
PromptLogprobs
,
SampleLogprobs
,
from
vllm.sequence
import
(
Logprob
,
PromptLogprobs
,
SampleLogprobs
,
SamplerOutput
,
SequenceData
,
SequenceGroupOutput
,
SamplerOutput
,
SequenceGroupOutput
,
SequenceOutput
)
SequenceOutput
)
# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType
=
List
[
Tuple
[
List
[
int
],
List
[
int
]]]
class
Sampler
(
nn
.
Module
):
class
Sampler
(
nn
.
Module
):
...
@@ -48,11 +51,14 @@ class Sampler(nn.Module):
...
@@ -48,11 +51,14 @@ class Sampler(nn.Module):
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
"""
Args:
logits: (num_tokens, vocab_size).
sampling_metadata: Metadata for sampling.
"""
assert
logits
is
not
None
assert
logits
is
not
None
_
,
vocab_size
=
logits
.
shape
_
,
vocab_size
=
logits
.
shape
# Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
# have not been generated yet
logits
=
_apply_min_tokens_penalty
(
logits
,
sampling_metadata
)
logits
=
_apply_min_tokens_penalty
(
logits
,
sampling_metadata
)
# Prepare sampling tensors with pinned memory to avoid blocking.
# Prepare sampling tensors with pinned memory to avoid blocking.
...
@@ -83,7 +89,6 @@ class Sampler(nn.Module):
...
@@ -83,7 +89,6 @@ class Sampler(nn.Module):
# Compute the probabilities.
# Compute the probabilities.
probs
=
torch
.
softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float
)
probs
=
torch
.
softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float
)
# Compute the log probabilities.
# Compute the log probabilities.
# Use log_softmax to ensure numerical stability.
logprobs
=
torch
.
log_softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float
)
logprobs
=
torch
.
log_softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float
)
# Sample the next tokens.
# Sample the next tokens.
...
@@ -98,8 +103,7 @@ class Sampler(nn.Module):
...
@@ -98,8 +103,7 @@ class Sampler(nn.Module):
if
self
.
include_gpu_probs_tensor
:
if
self
.
include_gpu_probs_tensor
:
assert
maybe_sampled_tokens_tensor
is
not
None
assert
maybe_sampled_tokens_tensor
is
not
None
sampled_tokens_tensor
=
maybe_sampled_tokens_tensor
on_device_tensors
=
(
probs
,
logprobs
,
maybe_sampled_tokens_tensor
)
on_device_tensors
=
(
probs
,
sampled_tokens_tensor
)
else
:
else
:
on_device_tensors
=
None
on_device_tensors
=
None
...
@@ -149,46 +153,46 @@ def _apply_min_tokens_penalty(
...
@@ -149,46 +153,46 @@ def _apply_min_tokens_penalty(
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
have not been generated yet
"""
# list of indices in logits that will be set to -inf
# list of indices in logits that will be set to -inf
logits_to_penalize
=
[]
logits_to_penalize
:
List
[
Tuple
[
int
,
int
]]
=
[]
start_idx
=
0
logits_applied
=
0
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
)
:
for
seq_group
in
sampling_metadata
.
seq_groups
:
seq_ids
,
sampling_params
=
seq_group
seq_ids
=
seq_group
.
seq_ids
sampling_params
=
seq_group
.
sampling_params
# handle prompt_logprobs by skipping rows in logits added for the prompt
# tokens (prompt logprobs are not penalized)
sample_indices
=
seq_group
.
sample_indices
if
(
i
<
sampling_metadata
.
num_prompts
logits_applied
+=
len
(
sample_indices
)
+
len
(
and
sampling_params
.
prompt_logprobs
is
not
None
):
seq_group
.
prompt_logprob_indices
)
assert
len
(
seq_ids
)
==
1
if
not
seq_group
.
do_sample
:
start_idx
+=
sampling_metadata
.
prompt_lens
[
i
]
-
1
continue
start_idx
=
sample_indices
[
0
]
min_tokens
=
sampling_params
.
min_tokens
min_tokens
=
sampling_params
.
min_tokens
if
min_tokens
>
0
:
token_ids_to_penalize
=
sampling_params
.
all_stop_token_ids
if
min_tokens
>
0
and
token_ids_to_penalize
:
seqs_to_penalize
=
[]
seqs_to_penalize
=
[]
for
i
,
seq_id
in
enumerate
(
seq_ids
):
for
j
,
seq_id
in
enumerate
(
seq_ids
):
seq_data
=
s
ampling_metadata
.
seq_data
[
seq_id
]
seq_data
=
s
eq_group
.
seq_data
[
seq_id
]
if
len
(
seq_data
.
output_token_ids
)
<
min_tokens
:
if
len
(
seq_data
.
output_token_ids
)
<
min_tokens
:
seqs_to_penalize
.
append
(
i
)
seqs_to_penalize
.
append
(
j
)
if
seqs_to_penalize
:
if
seqs_to_penalize
:
# convert to the index into logits
# convert to the index into logits
seqs_to_penalize
=
[
start_idx
+
i
for
i
in
seqs_to_penalize
]
seqs_to_penalize
=
[
start_idx
+
j
for
j
in
seqs_to_penalize
]
# use set() to remove any duplicates
token_ids_to_penalize
=
set
(
sampling_params
.
stop_token_ids
+
[
sampling_params
.
eos_token_id
])
# itertools.product pairs each seq index with every token id
# itertools.product pairs each seq index with every token id
logits_to_penalize
.
extend
(
logits_to_penalize
.
extend
(
itertools
.
product
(
seqs_to_penalize
,
token_ids_to_penalize
))
itertools
.
product
(
seqs_to_penalize
,
token_ids_to_penalize
))
start_idx
+=
len
(
seq_ids
)
if
logits_to_penalize
:
if
logits_to_penalize
:
# use zip and * to group indices along each dimension
# use zip and * to group indices along each dimension
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
logits
[
tuple
(
zip
(
*
logits_to_penalize
))]
=
-
float
(
"inf"
)
logits
[
tuple
(
zip
(
*
logits_to_penalize
))]
=
-
float
(
"inf"
)
# verifies that no rows in logits were missed unexpectedly
# verifies that no rows in logits were missed unexpectedly
assert
start_idx
==
logits
.
shape
[
0
]
assert
logits_applied
==
logits
.
shape
[
0
]
return
logits
return
logits
...
@@ -265,14 +269,30 @@ def _apply_min_p(
...
@@ -265,14 +269,30 @@ def _apply_min_p(
def
_greedy_sample
(
def
_greedy_sample
(
selected_seq_groups
:
List
[
Tuple
[
List
[
int
],
SamplingParams
]
],
selected_seq_groups
:
List
[
SequenceGroupToSample
],
samples
:
torch
.
Tensor
,
samples
:
torch
.
Tensor
,
)
->
List
[
Tuple
[
List
[
int
],
List
[
int
]]]:
)
->
SampleResultType
:
"""Run greedy sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
samples: (num_selected_samples,) A tensor of samples. The length of
samples could be smaller than selected_seq_groups if
seq_group.do_sample is False.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
samples
=
samples
.
tolist
()
samples
=
samples
.
tolist
()
sample_idx
=
0
sample_idx
=
0
results
=
[]
results
:
SampleResultType
=
[]
for
seq_group
in
selected_seq_groups
:
for
seq_group
in
selected_seq_groups
:
seq_ids
,
_
=
seq_group
if
not
seq_group
.
do_sample
:
results
.
append
(([],
[]))
continue
seq_ids
=
seq_group
.
seq_ids
num_parent_seqs
=
len
(
seq_ids
)
num_parent_seqs
=
len
(
seq_ids
)
assert
num_parent_seqs
==
1
,
(
assert
num_parent_seqs
==
1
,
(
"Greedy sampling should have only one seq."
)
"Greedy sampling should have only one seq."
)
...
@@ -284,16 +304,33 @@ def _greedy_sample(
...
@@ -284,16 +304,33 @@ def _greedy_sample(
def
_random_sample
(
def
_random_sample
(
selected_seq_groups
:
List
[
Tuple
[
List
[
int
],
SamplingParams
]],
selected_seq_groups
:
List
[
SequenceGroupToSample
],
is_prompts
:
List
[
bool
],
random_samples
:
torch
.
Tensor
,
random_samples
:
torch
.
Tensor
,
)
->
List
[
Tuple
[
List
[
int
],
List
[
int
]]]:
)
->
SampleResultType
:
"""Run random sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
random_samples: (num_selected_samples,) A tensor of samples. The
length of samples could be smaller than selected_seq_groups if
seq_group.do_sample is False.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
# Find the maximum best_of value of the prompt phase requests.
# Find the maximum best_of value of the prompt phase requests.
random_samples
=
random_samples
.
cpu
()
random_samples
=
random_samples
.
cpu
()
sample_idx
=
0
sample_idx
=
0
results
=
[]
results
:
SampleResultType
=
[]
for
seq_group
,
is_prompt
in
zip
(
selected_seq_groups
,
is_prompts
):
for
seq_group
in
selected_seq_groups
:
seq_ids
,
sampling_params
=
seq_group
if
not
seq_group
.
do_sample
:
results
.
append
(([],
[]))
continue
seq_ids
=
seq_group
.
seq_ids
sampling_params
=
seq_group
.
sampling_params
is_prompt
=
seq_group
.
is_prompt
num_parent_seqs
=
len
(
seq_ids
)
num_parent_seqs
=
len
(
seq_ids
)
if
is_prompt
:
if
is_prompt
:
# Prompt phase.
# Prompt phase.
...
@@ -311,11 +348,20 @@ def _random_sample(
...
@@ -311,11 +348,20 @@ def _random_sample(
def
_beam_search_sample
(
def
_beam_search_sample
(
selected_seq_groups
:
List
[
Tuple
[
List
[
int
],
SamplingParams
]],
selected_seq_groups
:
List
[
SequenceGroupToSample
],
is_prompts
:
List
[
bool
],
seq_data
:
Dict
[
int
,
SequenceData
],
logprobs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
)
->
List
[
Tuple
[
List
[
int
],
List
[
int
]]]:
)
->
SampleResultType
:
"""Run beam sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
on selected sample indices.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
# We sample 2 * beam_width candidates to make sure that with high
# We sample 2 * beam_width candidates to make sure that with high
# probability we can get `beam_width` candidates in addition to
# probability we can get `beam_width` candidates in addition to
# the finished sequences for the next iteration. See
# the finished sequences for the next iteration. See
...
@@ -326,9 +372,14 @@ def _beam_search_sample(
...
@@ -326,9 +372,14 @@ def _beam_search_sample(
# NOTE: Beam search is not vectorized, so its speed can be slower than
# NOTE: Beam search is not vectorized, so its speed can be slower than
# other sampling methods.
# other sampling methods.
sample_idx
=
0
sample_idx
=
0
results
=
[]
results
:
SampleResultType
=
[]
for
seq_group
,
is_prompt
in
zip
(
selected_seq_groups
,
is_prompts
):
for
seq_group
in
selected_seq_groups
:
seq_ids
,
sampling_params
=
seq_group
if
not
seq_group
.
do_sample
:
results
.
append
(([],
[]))
continue
is_prompt
=
seq_group
.
is_prompt
seq_ids
,
sampling_params
=
seq_group
.
seq_ids
,
seq_group
.
sampling_params
num_parent_seqs
=
len
(
seq_ids
)
num_parent_seqs
=
len
(
seq_ids
)
beam_width
=
sampling_params
.
best_of
beam_width
=
sampling_params
.
best_of
seq_group_logprobs
=
logprobs
[
sample_idx
:
sample_idx
+
num_parent_seqs
]
seq_group_logprobs
=
logprobs
[
sample_idx
:
sample_idx
+
num_parent_seqs
]
...
@@ -342,15 +393,16 @@ def _beam_search_sample(
...
@@ -342,15 +393,16 @@ def _beam_search_sample(
next_token_ids
=
next_token_ids
.
tolist
()
next_token_ids
=
next_token_ids
.
tolist
()
else
:
else
:
# Generation phase.
# Generation phase.
cumulative_logprobs
=
[
cumulative_logprobs
:
List
[
int
]
=
[
seq_data
[
seq_id
].
cumulative_logprob
for
seq_id
in
seq_ids
seq_group
.
seq_data
[
seq_id
].
cumulative_logprob
for
seq_id
in
seq_ids
]
]
cumulative_logprobs
=
torch
.
tensor
(
cumulative_logprobs
_tensor
=
torch
.
tensor
(
cumulative_logprobs
,
cumulative_logprobs
,
dtype
=
torch
.
float
,
dtype
=
torch
.
float
,
device
=
seq_group_logprobs
.
device
)
device
=
seq_group_logprobs
.
device
)
seq_group_logprobs
=
(
seq_group_logprobs
+
seq_group_logprobs
=
(
seq_group_logprobs
+
cumulative_logprobs
.
unsqueeze
(
dim
=
1
))
cumulative_logprobs
_tensor
.
unsqueeze
(
dim
=
1
))
_
,
topk_ids
=
torch
.
topk
(
seq_group_logprobs
.
flatten
(),
_
,
topk_ids
=
torch
.
topk
(
seq_group_logprobs
.
flatten
(),
2
*
beam_width
)
2
*
beam_width
)
topk_ids
=
topk_ids
.
tolist
()
topk_ids
=
topk_ids
.
tolist
()
...
@@ -371,8 +423,7 @@ def _beam_search_sample(
...
@@ -371,8 +423,7 @@ def _beam_search_sample(
def
_multinomial
(
def
_multinomial
(
probs
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
num_samples
:
int
,
num_samples
:
int
,
seq_groups
:
Optional
[
List
[
Tuple
[
List
[
int
],
SamplingParams
]]]
=
None
,
seq_groups
:
Optional
[
List
[
SequenceGroupToSample
]]
=
None
,
generators
:
Optional
[
List
[
torch
.
Generator
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
num_samples
>
1
:
if
num_samples
>
1
:
# This is equivalent to torch.repeat_interleaved (which also
# This is equivalent to torch.repeat_interleaved (which also
...
@@ -388,9 +439,11 @@ def _multinomial(
...
@@ -388,9 +439,11 @@ def _multinomial(
q
.
exponential_
()
q
.
exponential_
()
else
:
else
:
sample_idx
=
0
sample_idx
=
0
for
(
seq_ids
,
_
),
generator
in
zip
(
seq_groups
,
generators
):
for
seq_group
in
seq_groups
:
seq_ids
=
seq_group
.
seq_ids
next_sample_idx
=
sample_idx
+
len
(
seq_ids
)
*
num_samples
next_sample_idx
=
sample_idx
+
len
(
seq_ids
)
*
num_samples
q
[
sample_idx
:
next_sample_idx
].
exponential_
(
generator
=
generator
)
q
[
sample_idx
:
next_sample_idx
].
exponential_
(
generator
=
seq_group
.
generator
)
sample_idx
=
next_sample_idx
sample_idx
=
next_sample_idx
return
probs
.
div_
(
q
).
argmax
(
dim
=
1
).
view
(
-
1
,
num_samples
)
return
probs
.
div_
(
q
).
argmax
(
dim
=
1
).
view
(
-
1
,
num_samples
)
...
@@ -401,11 +454,13 @@ def _sample_with_torch(
...
@@ -401,11 +454,13 @@ def _sample_with_torch(
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
include_gpu_probs_tensor
:
bool
,
include_gpu_probs_tensor
:
bool
,
modify_greedy_probs
:
bool
,
modify_greedy_probs
:
bool
,
)
->
Tuple
[
List
[
Tuple
[
List
[
int
],
List
[
int
]]],
Optional
[
torch
.
Tensor
]]:
)
->
Tuple
[
SampleResultType
,
Optional
[
torch
.
Tensor
]]:
categorized_seq_group_ids
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_seq_group_ids
:
Dict
[
SamplingType
,
List
[
int
]]
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
_
,
sampling_params
=
seq_group
sampling_params
=
seq_group
.
sampling_params
sampling_type
=
sampling_params
.
sampling_type
sampling_type
=
sampling_params
.
sampling_type
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
...
@@ -429,13 +484,11 @@ def _sample_with_torch(
...
@@ -429,13 +484,11 @@ def _sample_with_torch(
num_tokens
=
len
(
sample_indices
)
num_tokens
=
len
(
sample_indices
)
if
num_tokens
==
0
:
if
num_tokens
==
0
:
continue
continue
seq_group_ids
=
categorized_seq_group_ids
[
sampling_type
]
seq_groups
=
[
sampling_metadata
.
seq_groups
[
i
]
for
i
in
seq_group_ids
]
is_prompts
=
[
i
<
sampling_metadata
.
num_prompts
for
i
in
seq_group_ids
]
sample_metadata
[
sampling_type
]
=
(
seq_group_ids
,
seq_groups
,
is_prompts
,
sample_indices
)
long_sample_indices
=
sample_indices
.
long
()
seq_group_id
=
categorized_seq_group_ids
[
sampling_type
]
seq_groups
=
[
sampling_metadata
.
seq_groups
[
i
]
for
i
in
seq_group_id
]
sample_metadata
[
sampling_type
]
=
(
seq_group_id
,
seq_groups
)
long_sample_indices
=
sample_indices
.
long
()
if
sampling_type
==
SamplingType
.
GREEDY
:
if
sampling_type
==
SamplingType
.
GREEDY
:
greedy_samples
=
torch
.
argmax
(
logprobs
[
long_sample_indices
],
greedy_samples
=
torch
.
argmax
(
logprobs
[
long_sample_indices
],
dim
=-
1
)
dim
=-
1
)
...
@@ -455,14 +508,13 @@ def _sample_with_torch(
...
@@ -455,14 +508,13 @@ def _sample_with_torch(
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
max_best_of_in_batch
=
1
max_best_of_in_batch
=
1
for
seq_group
,
is_prompt
in
zip
(
seq_groups
,
is_prompts
)
:
for
seq_group
in
seq_groups
:
if
is_prompt
:
if
seq_group
.
is_prompt
:
_
,
sampling_params
=
seq_group
sampling_params
=
seq_group
.
sampling_params
max_best_of_in_batch
=
max
(
max_best_of_in_batch
,
max_best_of_in_batch
=
max
(
max_best_of_in_batch
,
sampling_params
.
best_of
)
sampling_params
.
best_of
)
seeded_args
=
{}
if
sampling_type
==
SamplingType
.
RANDOM
else
{
seeded_args
=
{}
if
sampling_type
==
SamplingType
.
RANDOM
else
{
"seq_groups"
:
seq_groups
,
"seq_groups"
:
seq_groups
,
"generators"
:
sampling_metadata
.
generators
,
}
}
multinomial_samples
[
sampling_type
]
=
_multinomial
(
multinomial_samples
[
sampling_type
]
=
_multinomial
(
...
@@ -481,25 +533,22 @@ def _sample_with_torch(
...
@@ -481,25 +533,22 @@ def _sample_with_torch(
# GPU<->CPU sync happens in the loop below.
# GPU<->CPU sync happens in the loop below.
# This also converts the sample output to Python objects.
# This also converts the sample output to Python objects.
for
sampling_type
in
SamplingType
:
for
sampling_type
in
SamplingType
:
if
sampling_type
not
in
sample_metadata
:
if
sampling_type
not
in
sample_metadata
:
continue
continue
seq_group_ids
,
seq_groups
,
is_prompts
,
sample_indices
=
sample_metadata
[
(
seq_group_id
,
seq_groups
)
=
sample_metadata
[
sampling_type
]
sampling_type
]
if
sampling_type
==
SamplingType
.
GREEDY
:
if
sampling_type
==
SamplingType
.
GREEDY
:
sample_results
=
_greedy_sample
(
seq_groups
,
greedy_samples
)
sample_results
=
_greedy_sample
(
seq_groups
,
greedy_samples
)
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
sample_results
=
_random_sample
(
seq_groups
,
is_prompts
,
sample_results
=
_random_sample
(
seq_groups
,
multinomial_samples
[
sampling_type
])
multinomial_samples
[
sampling_type
])
elif
sampling_type
==
SamplingType
.
BEAM
:
elif
sampling_type
==
SamplingType
.
BEAM
:
sample_results
=
_beam_search_sample
(
seq_groups
,
is_prompts
,
sample_results
=
_beam_search_sample
(
seq_groups
,
sampling_metadata
.
seq_data
,
beam_search_logprobs
)
beam_search_logprobs
)
sample_results_dict
.
update
(
zip
(
seq_group_id
s
,
sample_results
))
sample_results_dict
.
update
(
zip
(
seq_group_id
,
sample_results
))
sample_results
=
[
sample_results
=
[
sample_results_dict
[
i
]
sample_results_dict
.
get
(
i
,
([],
[]))
for
i
in
range
(
len
(
sampling_metadata
.
seq_groups
))
for
i
in
range
(
len
(
sampling_metadata
.
seq_groups
))
]
]
return
sample_results
,
sampled_token_ids_tensor
return
sample_results
,
sampled_token_ids_tensor
...
@@ -510,11 +559,13 @@ def _sample_with_triton_kernel(
...
@@ -510,11 +559,13 @@ def _sample_with_triton_kernel(
logprobs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
sampling_tensors
:
SamplingTensors
,
sampling_tensors
:
SamplingTensors
,
)
->
List
[
Tuple
[
List
[
int
],
List
[
int
]]]:
)
->
SampleResultType
:
categorized_seq_group_ids
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_seq_group_ids
:
Dict
[
SamplingType
,
List
[
int
]]
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
_
,
sampling_params
=
seq_group
sampling_params
=
seq_group
.
sampling_params
sampling_type
=
sampling_params
.
sampling_type
sampling_type
=
sampling_params
.
sampling_type
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
...
@@ -530,17 +581,16 @@ def _sample_with_triton_kernel(
...
@@ -530,17 +581,16 @@ def _sample_with_triton_kernel(
num_tokens
=
len
(
sample_indices
)
num_tokens
=
len
(
sample_indices
)
if
num_tokens
==
0
:
if
num_tokens
==
0
:
continue
continue
seq_group_ids
=
categorized_seq_group_ids
[
sampling_type
]
seq_group_id
=
categorized_seq_group_ids
[
sampling_type
]
seq_groups
=
[
sampling_metadata
.
seq_groups
[
i
]
for
i
in
seq_group_ids
]
seq_groups
=
[
sampling_metadata
.
seq_groups
[
i
]
for
i
in
seq_group_id
]
is_prompts
=
[
i
<
sampling_metadata
.
num_prompts
for
i
in
seq_group_ids
]
sample_metadata
[
sampling_type
]
=
(
seq_group_id
,
seq_groups
,
sample_metadata
[
sampling_type
]
=
(
seq_group_ids
,
seq_groups
,
sample_indices
,
is_prompts
,
sample_indices
,
sampled_token_indices
)
sampled_token_indices
)
if
sampling_type
in
(
SamplingType
.
GREEDY
,
SamplingType
.
RANDOM
,
if
sampling_type
in
(
SamplingType
.
GREEDY
,
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
SamplingType
.
RANDOM_SEED
):
for
seq_group
,
is_prompt
in
zip
(
seq_groups
,
is_prompts
)
:
for
seq_group
in
seq_groups
:
if
is_prompt
:
if
seq_group
.
is_prompt
:
_
,
sampling_params
=
seq_group
sampling_params
=
seq_group
.
sampling_params
max_best_of_in_batch
=
max
(
max_best_of_in_batch
,
max_best_of_in_batch
=
max
(
max_best_of_in_batch
,
sampling_params
.
best_of
)
sampling_params
.
best_of
)
elif
sampling_type
==
SamplingType
.
BEAM
:
elif
sampling_type
==
SamplingType
.
BEAM
:
...
@@ -564,22 +614,21 @@ def _sample_with_triton_kernel(
...
@@ -564,22 +614,21 @@ def _sample_with_triton_kernel(
for
sampling_type
in
SamplingType
:
for
sampling_type
in
SamplingType
:
if
sampling_type
not
in
sample_metadata
:
if
sampling_type
not
in
sample_metadata
:
continue
continue
(
seq_group_id
s
,
seq_groups
,
is_prompts
,
sample_indices
,
(
seq_group_id
,
seq_groups
,
sample_indices
,
sampled_token_indices
)
=
sample_metadata
[
sampling_type
]
sampled_token_indices
)
=
sample_metadata
[
sampling_type
]
if
sampling_type
==
SamplingType
.
GREEDY
:
if
sampling_type
==
SamplingType
.
GREEDY
:
sample_results
=
_greedy_sample
(
sample_results
=
_greedy_sample
(
seq_groups
,
sampled_tokens
[
sampled_token_indices
][:,
0
])
seq_groups
,
sampled_tokens
[
sampled_token_indices
][:,
0
])
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
sample_results
=
_random_sample
(
sample_results
=
_random_sample
(
seq_groups
,
is_prompts
,
sampled_tokens
[
sampled_token_indices
])
seq_groups
,
sampled_tokens
[
sampled_token_indices
])
elif
sampling_type
==
SamplingType
.
BEAM
:
elif
sampling_type
==
SamplingType
.
BEAM
:
sample_results
=
_beam_search_sample
(
seq_groups
,
is_prompts
,
sample_results
=
_beam_search_sample
(
seq_groups
,
sampling_metadata
.
seq_data
,
beam_search_logprobs
)
beam_search_logprobs
)
sample_results_dict
.
update
(
zip
(
seq_group_id
s
,
sample_results
))
sample_results_dict
.
update
(
zip
(
seq_group_id
,
sample_results
))
sample_results
=
[
sample_results
=
[
sample_results_dict
[
i
]
sample_results_dict
.
get
(
i
,
([],
[]))
for
i
in
range
(
len
(
sampling_metadata
.
seq_groups
))
for
i
in
range
(
len
(
sampling_metadata
.
seq_groups
))
]
]
return
sample_results
return
sample_results
...
@@ -589,7 +638,19 @@ def _sample(
...
@@ -589,7 +638,19 @@ def _sample(
probs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_tensors
:
SamplingTensors
,
sampling_metadata
:
SamplingMetadata
,
sampling_tensors
:
SamplingTensors
,
include_gpu_probs_tensor
:
bool
,
modify_greedy_probs
:
bool
include_gpu_probs_tensor
:
bool
,
modify_greedy_probs
:
bool
)
->
Tuple
[
List
[
Tuple
[
List
[
int
],
List
[
int
]]],
Optional
[
torch
.
Tensor
]]:
)
->
Tuple
[
SampleResultType
,
Optional
[
torch
.
Tensor
]]:
"""
Args:
probs: (num_query_tokens_in_batch, num_vocab)
logprobs: (num_query_tokens_in_batch, num_vocab)
sampling_metadata: The metadata for a batch for sampling.
sampling_tensors: Tensors that include sampling related metadata.
Returns:
(next_token_ids, parent_seq_ids) for each seq group in a batch.
If sampling is skipped, it returns ([], [])
sampled_token_ids_tensor: A tensor of sampled token ids.
"""
return
_sample_with_torch
(
return
_sample_with_torch
(
probs
,
probs
,
logprobs
,
logprobs
,
...
@@ -625,57 +686,98 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
...
@@ -625,57 +686,98 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
def
_get_logprobs
(
def
_get_logprobs
(
logprobs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
sample_results
:
List
[
Tuple
[
List
[
int
],
List
[
int
]]],
sample_results
:
SampleResultType
,
)
->
Tuple
[
List
[
Optional
[
List
[
Optional
[
Dict
[
int
,
float
]]]]],
List
[
List
[
Dict
[
)
->
Tuple
[
List
[
Optional
[
PromptLogprobs
]],
List
[
SampleLogprobs
]]:
int
,
float
]]]]:
"""Return sample lobprobs and prompt logprobs.
# Prepare query indices
batched_logprobs_query_seq_indices
:
List
[
int
]
=
[]
The logic consists of 3 parts.
batched_logprobs_query_token_indices
:
List
[
int
]
=
[]
- Select indices to compute logprob from, ranks of token ids, and
# at least get one logprob for each token
the top k token ids from logprobs.
- Compute prompt logprobs if required.
- Compute sample logprobs if required.
Args:
logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
logprob per vocab. Sequence groups' query tokens are batched in a
single flattened tensor. For example, assuming there are N
seq groups, it is sorted by prefill tokens for seq_group_1 (if
prompt logprob is enabled), decode tokens for seq_group_1 (if
sampling is required), prefill tokens for seq_group_2, ...
sampling_metadata: The sampling metadata.
sample_results: (num_seq_groups) The tuple of (next_token_ids,
parent_ids) for each sequence group. When beam search is enabled,
sample_results can contain different number of seq_ids from
sampling_metadata.seq_groups. It is because beam search creates
2 * BEAM_WIDTH number of samples (whereas there are only up to
BEAM_WIDTH number of seq_ids).
Returns:
A tuple of prompt and sample logprobs per sequence group in a batch.
"""
# The index of query token to calculate logprobs. It includes both
# prompt and sample logprob indices.
query_indices
:
List
[
int
]
=
[]
# The next token ids to get the logprob value from.
next_token_ids
:
List
[
int
]
=
[]
# The largest requested number of logprobs. We find logprobs as many as the
# largest num logprobs in this API.
largest_num_logprobs
=
1
largest_num_logprobs
=
1
sample_idx
=
0
for
i
,
(
seq_group
,
sample_result
)
in
enumerate
(
# Select indices to compute logprob from, ranks of token ids, and the top
zip
(
sampling_metadata
.
seq_groups
,
sample_results
)):
# k token ids from logprobs.
seq_ids
,
sampling_params
=
seq_group
for
(
seq_group
,
sample_result
)
in
zip
(
sampling_metadata
.
seq_groups
,
next_token_ids
,
parent_ids
=
sample_result
sample_results
):
num_parent_seqs
=
len
(
seq_ids
)
sampling_params
=
seq_group
.
sampling_params
if
(
i
<
sampling_metadata
.
num_prompts
# Update indices and tokens for prompt logprobs.
if
(
seq_group
.
is_prompt
and
sampling_params
.
prompt_logprobs
is
not
None
):
and
sampling_params
.
prompt_logprobs
is
not
None
):
largest_num_logprobs
=
max
(
largest_num_logprobs
,
largest_num_logprobs
=
max
(
largest_num_logprobs
,
sampling_params
.
prompt_logprobs
)
sampling_params
.
prompt_logprobs
)
prompt_len
=
sampling_metadata
.
prompt_lens
[
i
]
next_prompt_tokens
=
_get_next_prompt_tokens
(
seq_group
)
prompt_tokens
=
sampling_metadata
.
seq_data
[
query_indices
.
extend
(
seq_group
.
prompt_logprob_indices
)
seq_ids
[
0
]].
prompt_token_ids
next_token_ids
.
extend
(
next_prompt_tokens
)
batched_logprobs_query_seq_indices
.
extend
(
sample_idx
+
j
for
j
in
range
(
prompt_len
-
1
))
# Update indices and next tokenes for sample logprob.
batched_logprobs_query_token_indices
.
extend
(
if
seq_group
.
do_sample
:
token_id
for
token_id
in
prompt_tokens
[
1
:])
token_ids
,
parent_seq_ids
=
sample_result
sample_idx
+=
prompt_len
-
1
# NOTE: We cannot directly use sample_indices because
batched_logprobs_query_seq_indices
.
extend
(
# sample_indices only contain parent seq_ids of a previous step.
[
sample_idx
+
parent_id
for
parent_id
in
parent_ids
])
# The current step may have different number of seq_ids, and
batched_logprobs_query_token_indices
.
extend
(
next_token_ids
)
# we can obtain it from `sample_result[1]`.
if
sampling_params
.
logprobs
is
not
None
:
query_idx
=
seq_group
.
sample_indices
[
0
]
largest_num_logprobs
=
max
(
largest_num_logprobs
,
query_indices
.
extend
(
sampling_params
.
logprobs
)
[
query_idx
+
parent_id
for
parent_id
in
parent_seq_ids
])
sample_idx
+=
num_parent_seqs
next_token_ids
.
extend
(
token_ids
)
assert
sample_idx
==
logprobs
.
size
(
0
)
if
sampling_params
.
logprobs
is
not
None
:
batched_logprobs_query_seq_indices_gpu
=
torch
.
tensor
(
largest_num_logprobs
=
max
(
largest_num_logprobs
,
batched_logprobs_query_seq_indices
,
device
=
logprobs
.
device
)
sampling_params
.
logprobs
)
batched_logprobs_query_token_indices_gpu
=
torch
.
tensor
(
batched_logprobs_query_token_indices
,
device
=
logprobs
.
device
)
assert
len
(
next_token_ids
)
==
len
(
query_indices
)
# Batched query for logprobs of selected token
if
len
(
query_indices
)
==
0
:
batched_logprobs_query_result
=
logprobs
[[
empty_sampled_logprob
:
SampleLogprobs
=
[]
batched_logprobs_query_seq_indices_gpu
,
empty_prompt_logprob
:
Optional
[
PromptLogprobs
]
=
None
batched_logprobs_query_token_indices_gpu
return
[
empty_prompt_logprob
],
[
empty_sampled_logprob
]
query_indices_gpu
=
torch
.
tensor
(
query_indices
,
device
=
logprobs
.
device
)
next_token_ids_gpu
=
torch
.
tensor
(
next_token_ids
,
device
=
logprobs
.
device
)
# (num_selected_query_tokens, num_logprobs). Note that query_indices can
# contain duplicates if beam search is enabled.
selected_logprobs
=
logprobs
[[
query_indices_gpu
,
next_token_ids_gpu
,
]]
]]
ranks
=
_get_ranks
(
logprobs
[
query_indices_gpu
],
next_token_ids_gpu
,
)
assert
selected_logprobs
.
shape
[
0
]
==
ranks
.
shape
[
0
]
batched_ranks_query_result
=
_get_ranks
(
# Logprobs of topk tokens for a batch of sequence groups.
logprobs
[
batched_logprobs_query_seq_indices_gpu
],
# (num_query_tokens_across_batch).
batched_logprobs_query_token_indices_gpu
)
# Batched query for logprobs of topk tokens
if
largest_num_logprobs
>
0
:
if
largest_num_logprobs
>
0
:
top_logprobs
,
top_token_ids
=
torch
.
topk
(
logprobs
,
top_logprobs
,
top_token_ids
=
torch
.
topk
(
logprobs
,
largest_num_logprobs
,
largest_num_logprobs
,
...
@@ -685,79 +787,136 @@ def _get_logprobs(
...
@@ -685,79 +787,136 @@ def _get_logprobs(
else
:
else
:
top_logprobs
,
top_token_ids
=
None
,
None
top_logprobs
,
top_token_ids
=
None
,
None
batched_logprobs_query_result
=
batched_logprobs_query_result
.
cpu
()
selected_logprobs
=
selected_logprobs
.
cpu
()
batched_ranks_query_result
=
batched_ranks_query_result
.
cpu
()
ranks
=
ranks
.
cpu
()
# Gather results
# Find prompt/sample logprobs.
result_prompt_logprobs
:
List
[
Optional
[
PromptLogprobs
]]
=
[]
prompt_logprobs_per_seq_group
:
List
[
Optional
[
PromptLogprobs
]]
=
[]
result_sample_logprobs
:
List
[
SampleLogprobs
]
=
[]
sample_logprobs_per_seq_group
:
List
[
SampleLogprobs
]
=
[]
sample_idx
=
0
top_logprob_idx
=
0
query_result_idx
=
0
selected_logprobs_idx
=
0
for
i
,
(
seq_group
,
sample_result
)
in
enumerate
(
zip
(
sampling_metadata
.
seq_groups
,
sample_results
)):
for
seq_group
,
sample_result
in
zip
(
sampling_metadata
.
seq_groups
,
seq_ids
,
sampling_params
=
seq_group
sample_results
):
next_token_ids
,
parent_ids
=
sample_result
(
prompt_logprobs
,
top_logprob_idx
,
selected_logprobs_idx
)
=
_get_prompt_logprob_if_needed
(
seq_group
,
selected_logprobs
,
ranks
,
top_token_ids
,
top_logprobs
,
selected_logprobs_idx
,
top_logprob_idx
)
prompt_logprobs_per_seq_group
.
append
(
prompt_logprobs
)
(
sampled_logprobs
,
top_logprob_idx
,
selected_logprobs_idx
)
=
_get_sampled_logprob_if_needed
(
seq_group
,
sample_result
,
selected_logprobs
,
ranks
,
top_token_ids
,
top_logprobs
,
selected_logprobs_idx
,
top_logprob_idx
)
sample_logprobs_per_seq_group
.
append
(
sampled_logprobs
)
return
prompt_logprobs_per_seq_group
,
sample_logprobs_per_seq_group
def
_get_prompt_logprob_if_needed
(
seq_group
:
SequenceGroupToSample
,
selected_logprobs
:
torch
.
Tensor
,
ranks
:
torch
.
Tensor
,
top_token_ids
:
torch
.
Tensor
,
top_logprobs
:
torch
.
Tensor
,
selected_logprobs_idx
:
int
,
top_logprob_idx
:
int
,
):
"""Compute the prompt logprob from a sequence group if needed."""
sampling_params
=
seq_group
.
sampling_params
is_prompt
=
seq_group
.
is_prompt
# Find prompt logprobs
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
if
(
is_prompt
and
sampling_params
.
prompt_logprobs
is
not
None
):
prompt_logprobs
=
[]
num_logprobs
=
sampling_params
.
prompt_logprobs
next_prompt_tokens
=
_get_next_prompt_tokens
(
seq_group
)
for
token_id
in
next_prompt_tokens
:
# Calculate the prompt logprob of the real prompt tokens.
# Use tuple here for performance (to use to_list()).
# {token_id: (logprob, rank_from_vocab)}
prompt_logprobs_dict
:
Dict
[
int
,
Tuple
[
float
,
int
]]
=
{
token_id
:
(
selected_logprobs
[
selected_logprobs_idx
].
item
(),
ranks
[
selected_logprobs_idx
].
item
())
}
# Prompt logprobs
# Add top K prompt logprobs along with its rank.
if
(
i
<
sampling_metadata
.
num_prompts
if
num_logprobs
>
0
:
and
sampling_params
.
prompt_logprobs
is
not
None
):
prompt_logprobs_dict
.
update
(
num_logprobs
=
sampling_params
.
prompt_logprobs
zip
(
prompt_tokens
=
sampling_metadata
.
seq_data
[
top_token_ids
[
top_logprob_idx
,
:
num_logprobs
].
tolist
(),
seq_ids
[
0
]].
prompt_token_ids
group_prompt_logprobs
:
PromptLogprobs
=
[
None
]
for
token_id
in
prompt_tokens
[
1
:]:
prompt_logprobs_dict
=
{
token_id
:
(
batched_logprobs_query_result
[
query_result_idx
].
item
(),
batched_ranks_query_result
[
query_result_idx
].
item
())
}
if
num_logprobs
>
0
:
prompt_logprobs_dict
.
update
(
zip
(
zip
(
top_token_ids
[
sample_idx
,
:
num_logprobs
].
tolist
(),
top_logprobs
[
zip
(
top_logprob_idx
,
:
num_logprobs
].
tolist
(),
top_logprobs
[
# This is ranks. Since top_logprob is sorted,
sample_idx
,
:
num_logprobs
].
tolist
(),
# we can just use a range here.
range
(
1
,
num_logprobs
+
1
))))
range
(
1
,
num_logprobs
+
1
))))
group_prompt_logprobs
.
append
({
prompt_logprobs
.
append
({
token_id
:
Logprob
(
*
logprob_rank
)
token_id
:
Logprob
(
*
logprob_and_rank
)
for
token_id
,
logprob_rank
in
prompt_logprobs_dict
.
items
()
for
token_id
,
logprob_and_rank
in
prompt_logprobs_dict
.
items
()
})
})
sample_idx
+=
1
# + 1 to go to the next prompt token.
query_result_idx
+=
1
top_logprob_idx
+=
1
result_prompt_logprobs
.
append
(
group_prompt_logprobs
)
selected_logprobs_idx
+=
1
else
:
return
prompt_logprobs
,
top_logprob_idx
,
selected_logprobs_idx
result_prompt_logprobs
.
append
(
None
)
# Sample logprobs
def
_get_sampled_logprob_if_needed
(
num_logprobs
=
sampling_params
.
logprobs
seq_group
:
SequenceGroupToSample
,
if
num_logprobs
is
None
:
sample_result
:
Tuple
[
List
[
int
],
List
[
int
]],
num_logprobs
=
0
selected_logprobs
:
torch
.
Tensor
,
group_sample_logprobs
:
SampleLogprobs
=
[]
ranks
:
torch
.
Tensor
,
for
next_token_id
,
parent_id
in
zip
(
next_token_ids
,
parent_ids
):
top_token_ids
:
torch
.
Tensor
,
sample_logprobs_dict
=
{
top_logprobs
:
torch
.
Tensor
,
selected_logprobs_idx
:
int
,
top_logprob_idx
:
int
,
):
"""Compute the sample logprob if needed."""
seq_ids
=
seq_group
.
seq_ids
num_logprobs
=
seq_group
.
sampling_params
.
logprobs
if
num_logprobs
is
None
:
num_logprobs
=
0
sampled_logprobs
:
SampleLogprobs
=
[]
next_token_ids
,
parent_seq_ids
=
sample_result
if
seq_group
.
do_sample
:
assert
len
(
next_token_ids
)
>
0
for
(
next_token_id
,
parent_id
)
in
zip
(
next_token_ids
,
parent_seq_ids
):
# Calculate the sample logprob of the real sampled tokens.
# Use tuple here for performance (to use to_list()).
# token_id: (logprob, rank_from_vocab)
sampled_logprobs_dict
:
Dict
[
int
,
Tuple
[
float
,
int
]]
=
{
next_token_id
:
next_token_id
:
(
batch
ed_logprobs
_query_result
[
query_result
_idx
].
item
(),
(
select
ed_logprobs
[
selected_logprobs
_idx
].
item
(),
batched_ranks_query_result
[
query_result
_idx
].
item
())
ranks
[
selected_logprobs
_idx
].
item
())
}
}
query_result_idx
+=
1
# +1 to go to the next sampled token. Note that
# selected_logprobs can contain duplicates unlike top_logprobs
# when beam search is enabled.
selected_logprobs_idx
+=
1
# Second, add top K logprobs along with its rank.
if
num_logprobs
>=
0
:
if
num_logprobs
>=
0
:
sample_logprobs_dict
.
update
(
sample
d
_logprobs_dict
.
update
(
zip
(
zip
(
top_token_ids
[
sample
_idx
+
top_token_ids
[
top_logprob
_idx
+
parent_id
,
:
num_logprobs
].
tolist
(),
parent_id
,
:
num_logprobs
].
tolist
(),
zip
(
zip
(
top_logprobs
[
sample
_idx
+
top_logprobs
[
top_logprob
_idx
+
parent_id
,
:
num_logprobs
].
tolist
(),
parent_id
,
:
num_logprobs
].
tolist
(),
# This is rank. Since top_logprob is sorted, we
# can just use a range here.
range
(
1
,
num_logprobs
+
1
))))
range
(
1
,
num_logprobs
+
1
))))
group_sample_logprobs
.
append
({
sampled_logprobs
.
append
({
token_id
:
Logprob
(
*
logprob_rank
)
token_id
:
Logprob
(
*
logprob_and_rank
)
for
token_id
,
logprob_rank
in
sample_logprobs_dict
.
items
()
for
token_id
,
logprob_and_rank
in
sampled_logprobs_dict
.
items
()
})
})
result_sample_logprobs
.
append
(
group_sample_logprobs
)
# There are len(seq_ids) number of sampled tokens for the current
sample_idx
+=
len
(
seq_ids
)
# sequence group in top_logprobs. Jump to the next seq_group.
top_logprob_idx
+=
len
(
seq_ids
)
return
result_prompt_logprobs
,
result_sample
_logprobs
return
sampled_logprobs
,
top_logprob_idx
,
selected
_logprobs
_idx
def
_modify_greedy_probs_inplace
(
logprobs
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
def
_modify_greedy_probs_inplace
(
logprobs
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
...
@@ -805,18 +964,18 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
...
@@ -805,18 +964,18 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
has implications on the overall design of the sampler, e.g. how to record
has implications on the overall design of the sampler, e.g. how to record
accurate logprobs for the user, so this improvement is deferred to later.
accurate logprobs for the user, so this improvement is deferred to later.
"""
"""
logprobs
[
sample_indices
,
:]
=
-
float
(
'inf'
)
# NOTE: logprobs are not modified so they can be returned to the user.
logprobs
[
sample_indices
,
greedy_samples
]
=
0.0
probs
[
sample_indices
,
:]
=
0
probs
[
sample_indices
,
:]
=
0
probs
[
sample_indices
,
greedy_samples
]
=
1.0
probs
[
sample_indices
,
greedy_samples
]
=
1.0
def
_build_sampler_output
(
def
_build_sampler_output
(
sample_results
:
List
[
Tuple
[
List
[
int
],
List
[
int
]]]
,
sample_results
:
SampleResultType
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
prompt_logprobs
:
List
[
Optional
[
PromptLogprobs
]],
prompt_logprobs
:
List
[
Optional
[
PromptLogprobs
]],
sample_logprobs
:
List
[
SampleLogprobs
],
sample_logprobs
:
List
[
SampleLogprobs
],
on_device_tensors
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
on_device_tensors
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]],
)
->
SamplerOutput
:
)
->
SamplerOutput
:
"""Construct Python objects with the output of sampling.
"""Construct Python objects with the output of sampling.
...
@@ -832,7 +991,7 @@ def _build_sampler_output(
...
@@ -832,7 +991,7 @@ def _build_sampler_output(
group_sample_logprobs
)
in
zip
(
sampling_metadata
.
seq_groups
,
group_sample_logprobs
)
in
zip
(
sampling_metadata
.
seq_groups
,
sample_results
,
prompt_logprobs
,
sample_results
,
prompt_logprobs
,
sample_logprobs
):
sample_logprobs
):
seq_ids
,
_
=
seq_group
seq_ids
=
seq_group
.
seq_ids
next_token_ids
,
parent_ids
=
sample_result
next_token_ids
,
parent_ids
=
sample_result
seq_outputs
=
[]
seq_outputs
=
[]
for
parent_id
,
next_token_id
,
logprobs
in
zip
(
parent_ids
,
for
parent_id
,
next_token_id
,
logprobs
in
zip
(
parent_ids
,
...
@@ -845,12 +1004,48 @@ def _build_sampler_output(
...
@@ -845,12 +1004,48 @@ def _build_sampler_output(
# If not specified, store None values in SamplerOutput.
# If not specified, store None values in SamplerOutput.
if
on_device_tensors
is
not
None
:
if
on_device_tensors
is
not
None
:
sampled_token_probs
,
sampled_token_ids
=
on_device_tensors
(
sampled_token_probs
,
logprobs_tensor
,
sampled_token_ids
)
=
on_device_tensors
else
:
else
:
sampled_token_probs
,
sampled_token_ids
=
(
None
,
None
)
sampled_token_probs
,
logprobs_tensor
,
sampled_token_ids
=
(
None
,
None
,
None
)
return
SamplerOutput
(
return
SamplerOutput
(
outputs
=
sampler_output
,
outputs
=
sampler_output
,
sampled_token_probs
=
sampled_token_probs
,
sampled_token_probs
=
sampled_token_probs
,
sampled_token_ids
=
sampled_token_ids
,
sampled_token_ids
=
sampled_token_ids
,
logprobs
=
logprobs_tensor
,
)
)
def
_get_next_prompt_tokens
(
seq_group
:
SequenceGroupToSample
)
->
List
[
int
]:
"""Get a list of next prompt tokens to compute logprob from a
given sequence group.
It is used to compute prompt logprob. Imagine you have logprob for each
query token. Query token needs to know the next prompt token id to compute
prompt logprob. This is a helper to obtain next prompt token ids.
This API has to be used only when the caller knows seq_group is in prefill
stage.
Returns:
A list of next prompt tokens to compute logprob.
"""
assert
seq_group
.
is_prompt
,
(
"Caller should ensure the sequence group is in a prefill stage."
)
seq_ids
=
seq_group
.
seq_ids
query_len
=
seq_group
.
query_len
assert
query_len
is
not
None
# prompt has only 1 seq id.
assert
len
(
seq_ids
)
==
1
seq_data
=
seq_group
.
seq_data
[
seq_ids
[
0
]]
computed_len
=
seq_data
.
get_num_computed_tokens
()
prompt_tokens
=
seq_data
.
prompt_token_ids
# +1 because we are looking for a next prompt token.
next_token_index_start
=
computed_len
+
1
next_token_index_end
=
min
(
computed_len
+
query_len
+
1
,
len
(
prompt_tokens
))
next_prompt_tokens
=
prompt_tokens
[
next_token_index_start
:
next_token_index_end
]
return
next_prompt_tokens
vllm/model_executor/layers/vocab_parallel_embedding.py
View file @
1591c68f
...
@@ -105,6 +105,14 @@ class VocabParallelEmbedding(torch.nn.Module):
...
@@ -105,6 +105,14 @@ class VocabParallelEmbedding(torch.nn.Module):
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
return
output
return
output
def
extra_repr
(
self
)
->
str
:
s
=
f
"num_embeddings=
{
self
.
num_embeddings_per_partition
}
"
s
+=
f
", embedding_dim=
{
self
.
embedding_dim
}
"
s
+=
f
", org_vocab_size=
{
self
.
org_vocab_size
}
"
s
+=
f
', num_embeddings_padded=
{
self
.
num_embeddings_padded
}
'
s
+=
f
', tp_size=
{
self
.
tp_size
}
'
return
s
class
ParallelLMHead
(
VocabParallelEmbedding
):
class
ParallelLMHead
(
VocabParallelEmbedding
):
"""Parallelized LM head.
"""Parallelized LM head.
...
...
vllm/model_executor/model_loader/loader.py
View file @
1591c68f
...
@@ -3,16 +3,19 @@ import copy
...
@@ -3,16 +3,19 @@ import copy
import
glob
import
glob
import
os
import
os
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
,
from
typing
import
Any
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
,
Type
Type
)
import
huggingface_hub
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
vllm.config
import
(
VLLM_USE_MODELSCOPE
,
DeviceConfig
,
LoadConfig
,
from
vllm.config
import
(
DeviceConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
LoadFormat
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
VisionLanguageConfig
)
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.model_loader.tensorizer
import
(
from
vllm.model_executor.model_loader.tensorizer
import
(
TensorizerConfig
,
is_vllm_serialized_tensorizer
,
load_with_tensorizer
,
TensorizerConfig
,
is_vllm_serialized_tensorizer
,
load_with_tensorizer
,
tensorizer_weights_iterator
)
tensorizer_weights_iterator
)
...
@@ -24,9 +27,6 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -24,9 +27,6 @@ from vllm.model_executor.model_loader.weight_utils import (
pt_weights_iterator
,
safetensors_weights_iterator
)
pt_weights_iterator
,
safetensors_weights_iterator
)
from
vllm.model_executor.models.llava
import
LlavaForConditionalGeneration
from
vllm.model_executor.models.llava
import
LlavaForConditionalGeneration
if
TYPE_CHECKING
:
from
vllm.model_executor.layers.linear
import
LinearMethodBase
_VISION_MODEL_CLASSES
=
[
_VISION_MODEL_CLASSES
=
[
LlavaForConditionalGeneration
,
LlavaForConditionalGeneration
,
]
]
...
@@ -34,11 +34,10 @@ _VISION_MODEL_CLASSES = [
...
@@ -34,11 +34,10 @@ _VISION_MODEL_CLASSES = [
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
def
_get_
linear_method
(
def
_get_
quantization_config
(
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
load_config
:
LoadConfig
)
->
Optional
[
"LinearMethodBase"
]:
load_config
:
LoadConfig
)
->
Optional
[
QuantizationConfig
]:
"""Get the (maybe quantized) linear method."""
"""Get the quantization config."""
linear_method
=
None
if
model_config
.
quantization
is
not
None
:
if
model_config
.
quantization
is
not
None
:
quant_config
=
get_quant_config
(
model_config
,
load_config
)
quant_config
=
get_quant_config
(
model_config
,
load_config
)
capability
=
torch
.
cuda
.
get_device_capability
()
capability
=
torch
.
cuda
.
get_device_capability
()
...
@@ -55,6 +54,7 @@ def _get_linear_method(
...
@@ -55,6 +54,7 @@ def _get_linear_method(
f
"
{
model_config
.
dtype
}
is not supported for quantization "
f
"
{
model_config
.
dtype
}
is not supported for quantization "
f
"method
{
model_config
.
quantization
}
. Supported dtypes: "
f
"method
{
model_config
.
quantization
}
. Supported dtypes: "
f
"
{
supported_dtypes
}
"
)
f
"
{
supported_dtypes
}
"
)
<<<<<<<
HEAD
linear_method
=
quant_config
.
get_linear_method
()
linear_method
=
quant_config
.
get_linear_method
()
...
@@ -62,6 +62,10 @@ def _get_linear_method(
...
@@ -62,6 +62,10 @@ def _get_linear_method(
os
.
environ
[
'LLAMA_NN'
]
=
'0'
os
.
environ
[
'LLAMA_NN'
]
=
'0'
return
linear_method
return
linear_method
=======
return
quant_config
return
None
>>>>>>>
v0
.
4.2
def
_get_model_initialization_kwargs
(
def
_get_model_initialization_kwargs
(
...
@@ -89,10 +93,10 @@ def _initialize_model(
...
@@ -89,10 +93,10 @@ def _initialize_model(
vision_language_config
:
Optional
[
VisionLanguageConfig
])
->
nn
.
Module
:
vision_language_config
:
Optional
[
VisionLanguageConfig
])
->
nn
.
Module
:
"""Initialize a model with the given configurations."""
"""Initialize a model with the given configurations."""
model_class
=
get_model_architecture
(
model_config
)[
0
]
model_class
=
get_model_architecture
(
model_config
)[
0
]
linear_method
=
_get_linear_method
(
model_config
,
load_config
)
quant_config
=
_get_quantization_config
(
model_config
,
load_config
)
return
model_class
(
config
=
model_config
.
hf_config
,
return
model_class
(
config
=
model_config
.
hf_config
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
**
_get_model_initialization_kwargs
(
**
_get_model_initialization_kwargs
(
model_class
,
lora_config
,
vision_language_config
))
model_class
,
lora_config
,
vision_language_config
))
...
@@ -139,7 +143,9 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -139,7 +143,9 @@ class DefaultModelLoader(BaseModelLoader):
model_path
=
snapshot_download
(
model_path
=
snapshot_download
(
model_id
=
model
,
model_id
=
model
,
cache_dir
=
self
.
load_config
.
download_dir
,
cache_dir
=
self
.
load_config
.
download_dir
,
revision
=
revision
)
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
revision
=
revision
,
)
else
:
else
:
model_path
=
model
model_path
=
model
return
model_path
return
model_path
...
@@ -233,9 +239,11 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -233,9 +239,11 @@ class DefaultModelLoader(BaseModelLoader):
"fall_back_to_pt_during_load"
,
"fall_back_to_pt_during_load"
,
True
)),
)
True
)),
)
for
_
,
module
in
model
.
named_modules
():
for
_
,
module
in
model
.
named_modules
():
linear_method
=
getattr
(
module
,
"linear_method"
,
None
)
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
linear_method
is
not
None
:
if
quant_method
is
not
None
:
linear_method
.
process_weights_after_loading
(
module
)
quant_method
.
process_weights_after_loading
(
module
)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if
hasattr
(
module
,
"process_weights_after_loading"
):
if
hasattr
(
module
,
"process_weights_after_loading"
):
module
.
process_weights_after_loading
()
module
.
process_weights_after_loading
()
return
model
.
eval
()
return
model
.
eval
()
...
@@ -318,11 +326,11 @@ class TensorizerLoader(BaseModelLoader):
...
@@ -318,11 +326,11 @@ class TensorizerLoader(BaseModelLoader):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
with
torch
.
device
(
device_config
.
device
):
model_class
=
get_model_architecture
(
model_config
)[
0
]
model_class
=
get_model_architecture
(
model_config
)[
0
]
linear_method
=
_get_linear_method
(
model
_config
,
quant_config
=
_get_quantization
_config
(
self
.
load_config
)
model_config
,
self
.
load_config
)
extra_kwargs
=
_get_model_initialization_kwargs
(
extra_kwargs
=
_get_model_initialization_kwargs
(
model_class
,
lora_config
,
vision_language_config
)
model_class
,
lora_config
,
vision_language_config
)
extra_kwargs
[
"
linear_method"
]
=
linear_method
extra_kwargs
[
"
quant_config"
]
=
quant_config
tensorizer_config
=
copy
.
copy
(
self
.
tensorizer_config
)
tensorizer_config
=
copy
.
copy
(
self
.
tensorizer_config
)
tensorizer_config
.
model_class
=
model_class
tensorizer_config
.
model_class
=
model_class
...
...
vllm/model_executor/model_loader/tensorizer.py
View file @
1591c68f
...
@@ -11,9 +11,11 @@ import torch
...
@@ -11,9 +11,11 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
import
vllm.envs
as
envs
from
vllm.config
import
ModelConfig
,
ParallelConfig
from
vllm.config
import
ModelConfig
,
ParallelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
...
@@ -43,7 +45,7 @@ class TensorizerConfig:
...
@@ -43,7 +45,7 @@ class TensorizerConfig:
str
,
bytes
,
os
.
PathLike
,
int
]
str
,
bytes
,
os
.
PathLike
,
int
]
vllm_tensorized
:
bool
vllm_tensorized
:
bool
verify_hash
:
Optional
[
bool
]
=
False
verify_hash
:
Optional
[
bool
]
=
False
num_readers
:
Optional
[
int
]
=
1
num_readers
:
Optional
[
int
]
=
None
encryption_keyfile
:
Optional
[
str
]
=
None
encryption_keyfile
:
Optional
[
str
]
=
None
s3_access_key_id
:
Optional
[
str
]
=
None
s3_access_key_id
:
Optional
[
str
]
=
None
s3_secret_access_key
:
Optional
[
str
]
=
None
s3_secret_access_key
:
Optional
[
str
]
=
None
...
@@ -63,7 +65,7 @@ class TensorizerConfig:
...
@@ -63,7 +65,7 @@ class TensorizerConfig:
"s3_secret_access_key"
:
self
.
s3_secret_access_key
,
"s3_secret_access_key"
:
self
.
s3_secret_access_key
,
"s3_endpoint"
:
self
.
s3_endpoint
,
"s3_endpoint"
:
self
.
s3_endpoint
,
}
}
return
TensorizerArgs
(
**
tensorizer_args
)
return
TensorizerArgs
(
**
tensorizer_args
)
# type: ignore
def
verify_with_parallel_config
(
def
verify_with_parallel_config
(
self
,
self
,
...
@@ -103,7 +105,7 @@ class TensorizerArgs:
...
@@ -103,7 +105,7 @@ class TensorizerArgs:
str
,
bytes
,
os
.
PathLike
,
int
]
str
,
bytes
,
os
.
PathLike
,
int
]
vllm_tensorized
:
bool
vllm_tensorized
:
bool
verify_hash
:
Optional
[
bool
]
=
False
verify_hash
:
Optional
[
bool
]
=
False
num_readers
:
Optional
[
int
]
=
1
num_readers
:
Optional
[
int
]
=
None
encryption_keyfile
:
Optional
[
str
]
=
None
encryption_keyfile
:
Optional
[
str
]
=
None
s3_access_key_id
:
Optional
[
str
]
=
None
s3_access_key_id
:
Optional
[
str
]
=
None
s3_secret_access_key
:
Optional
[
str
]
=
None
s3_secret_access_key
:
Optional
[
str
]
=
None
...
@@ -124,8 +126,9 @@ class TensorizerArgs:
...
@@ -124,8 +126,9 @@ class TensorizerArgs:
the hashes stored in the metadata. A `HashMismatchError` will be
the hashes stored in the metadata. A `HashMismatchError` will be
raised if any of the hashes do not match.
raised if any of the hashes do not match.
num_readers: Controls how many threads are allowed to read concurrently
num_readers: Controls how many threads are allowed to read concurrently
from the source file. Default is 1. This greatly increases
from the source file. Default is `None`, which will dynamically set
performance.
the number of readers based on the number of available
resources and model size. This greatly increases performance.
encryption_keyfile: File path to a binary file containing a
encryption_keyfile: File path to a binary file containing a
binary key to use for decryption. `None` (the default) means
binary key to use for decryption. `None` (the default) means
no decryption. See the example script in
no decryption. See the example script in
...
@@ -140,13 +143,10 @@ class TensorizerArgs:
...
@@ -140,13 +143,10 @@ class TensorizerArgs:
def
__post_init__
(
self
):
def
__post_init__
(
self
):
self
.
file_obj
=
self
.
tensorizer_uri
self
.
file_obj
=
self
.
tensorizer_uri
self
.
s3_access_key_id
=
(
self
.
s3_access_key_id
self
.
s3_access_key_id
=
self
.
s3_access_key_id
or
envs
.
S3_ACCESS_KEY_ID
or
os
.
environ
.
get
(
"S3_ACCESS_KEY_ID"
))
or
None
self
.
s3_secret_access_key
=
(
self
.
s3_secret_access_key
self
.
s3_secret_access_key
=
(
or
envs
.
S3_SECRET_ACCESS_KEY
)
self
.
s3_secret_access_key
self
.
s3_endpoint
=
self
.
s3_endpoint
or
envs
.
S3_ENDPOINT_URL
or
os
.
environ
.
get
(
"S3_SECRET_ACCESS_KEY"
))
or
None
self
.
s3_endpoint
=
(
self
.
s3_endpoint
or
os
.
environ
.
get
(
"S3_ENDPOINT_URL"
))
or
None
self
.
stream_params
=
{
self
.
stream_params
=
{
"s3_access_key_id"
:
self
.
s3_access_key_id
,
"s3_access_key_id"
:
self
.
s3_access_key_id
,
"s3_secret_access_key"
:
self
.
s3_secret_access_key
,
"s3_secret_access_key"
:
self
.
s3_secret_access_key
,
...
@@ -198,10 +198,12 @@ class TensorizerArgs:
...
@@ -198,10 +198,12 @@ class TensorizerArgs:
"use for decryption. Can be a file path or S3 network URI."
)
"use for decryption. Can be a file path or S3 network URI."
)
group
.
add_argument
(
group
.
add_argument
(
"--num-readers"
,
"--num-readers"
,
default
=
1
,
default
=
None
,
type
=
int
,
type
=
int
,
help
=
"Controls how many threads are allowed to read concurrently "
help
=
"Controls how many threads are allowed to read concurrently "
"from the source file."
)
"from the source file. Default is `None`, which will dynamically "
"set the number of readers based on the available resources "
"and model size. This greatly increases performance."
)
group
.
add_argument
(
group
.
add_argument
(
"--s3-access-key-id"
,
"--s3-access-key-id"
,
default
=
None
,
default
=
None
,
...
@@ -251,7 +253,7 @@ class TensorizerAgent:
...
@@ -251,7 +253,7 @@ class TensorizerAgent:
"""
"""
def
__init__
(
self
,
tensorizer_config
:
TensorizerConfig
,
def
__init__
(
self
,
tensorizer_config
:
TensorizerConfig
,
linear_method
:
LinearMethodBase
,
**
extra_kwargs
):
quant_config
:
QuantizationConfig
,
**
extra_kwargs
):
if
tensorizer_load_fail
is
not
None
:
if
tensorizer_load_fail
is
not
None
:
raise
ImportError
(
raise
ImportError
(
"Tensorizer is not installed. Please install tensorizer "
"Tensorizer is not installed. Please install tensorizer "
...
@@ -262,19 +264,21 @@ class TensorizerAgent:
...
@@ -262,19 +264,21 @@ class TensorizerAgent:
self
.
tensorizer_args
=
(
self
.
tensorizer_args
=
(
self
.
tensorizer_config
.
_construct_tensorizer_args
())
self
.
tensorizer_config
.
_construct_tensorizer_args
())
self
.
extra_kwargs
=
extra_kwargs
self
.
extra_kwargs
=
extra_kwargs
if
extra_kwargs
.
get
(
"
linear_method
"
,
None
)
is
not
None
:
if
extra_kwargs
.
get
(
"
quant_config
"
,
None
)
is
not
None
:
self
.
linear_method
=
extra_kwargs
[
"
linear_method
"
]
self
.
quant_config
=
extra_kwargs
[
"
quant_config
"
]
else
:
else
:
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
model
=
self
.
_init_model
()
self
.
model
=
self
.
_init_model
()
def
_init_model
(
self
):
def
_init_model
(
self
):
assert
self
.
tensorizer_config
.
hf_config
is
not
None
model_args
=
self
.
tensorizer_config
.
hf_config
model_args
=
self
.
tensorizer_config
.
hf_config
model_args
.
torch_dtype
=
self
.
tensorizer_config
.
dtype
model_args
.
torch_dtype
=
self
.
tensorizer_config
.
dtype
assert
self
.
tensorizer_config
.
model_class
is
not
None
with
no_init_or_tensor
():
with
no_init_or_tensor
():
return
self
.
tensorizer_config
.
model_class
(
return
self
.
tensorizer_config
.
model_class
(
config
=
model_args
,
config
=
model_args
,
linear_method
=
self
.
linear_method
,
quant_config
=
self
.
quant_config
,
**
self
.
extra_kwargs
)
**
self
.
extra_kwargs
)
def
_resize_lora_embeddings
(
self
):
def
_resize_lora_embeddings
(
self
):
...
@@ -334,10 +338,10 @@ class TensorizerAgent:
...
@@ -334,10 +338,10 @@ class TensorizerAgent:
per_second
=
convert_bytes
(
deserializer
.
total_tensor_bytes
/
duration
)
per_second
=
convert_bytes
(
deserializer
.
total_tensor_bytes
/
duration
)
after_mem
=
get_mem_usage
()
after_mem
=
get_mem_usage
()
deserializer
.
close
()
deserializer
.
close
()
logger
.
info
(
f
"Deserialized
{
total_bytes_str
}
in "
logger
.
info
(
"Deserialized
%s in %0.2fs, %s/s"
,
total_bytes_str
,
f
"
{
end
-
start
:
0.2
f
}
s
,
{
per_second
}
/s"
)
end
-
start
,
per_second
)
logger
.
info
(
f
"Memory usage before:
{
before_mem
}
"
)
logger
.
info
(
"Memory usage before:
%s"
,
before_mem
)
logger
.
info
(
f
"Memory usage after:
{
after_mem
}
"
)
logger
.
info
(
"Memory usage after:
%s"
,
after_mem
)
self
.
_check_tensors_on_meta_device
()
self
.
_check_tensors_on_meta_device
()
self
.
_resize_lora_embeddings
()
self
.
_resize_lora_embeddings
()
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
1591c68f
...
@@ -127,11 +127,14 @@ def get_quant_config(model_config: ModelConfig,
...
@@ -127,11 +127,14 @@ def get_quant_config(model_config: ModelConfig,
if
not
is_local
:
if
not
is_local
:
# Download the config files.
# Download the config files.
with
get_lock
(
model_name_or_path
,
load_config
.
download_dir
):
with
get_lock
(
model_name_or_path
,
load_config
.
download_dir
):
hf_folder
=
snapshot_download
(
model_name_or_path
,
hf_folder
=
snapshot_download
(
revision
=
model_config
.
revision
,
model_name_or_path
,
allow_patterns
=
"*.json"
,
revision
=
model_config
.
revision
,
cache_dir
=
load_config
.
download_dir
,
allow_patterns
=
"*.json"
,
tqdm_class
=
DisabledTqdm
)
cache_dir
=
load_config
.
download_dir
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
tqdm_class
=
DisabledTqdm
,
)
else
:
else
:
hf_folder
=
model_name_or_path
hf_folder
=
model_name_or_path
...
@@ -161,12 +164,14 @@ def get_quant_config(model_config: ModelConfig,
...
@@ -161,12 +164,14 @@ def get_quant_config(model_config: ModelConfig,
return
quant_cls
.
from_config
(
config
)
return
quant_cls
.
from_config
(
config
)
def
download_weights_from_hf
(
model_name_or_path
:
str
,
def
download_weights_from_hf
(
cache_dir
:
Optional
[
str
],
model_name_or_path
:
str
,
allow_patterns
:
List
[
str
],
cache_dir
:
Optional
[
str
],
revision
:
Optional
[
str
]
=
None
)
->
str
:
allow_patterns
:
List
[
str
],
revision
:
Optional
[
str
]
=
None
,
)
->
str
:
"""Download model weights from Hugging Face Hub.
"""Download model weights from Hugging Face Hub.
Args:
Args:
model_name_or_path (str): The model name or path.
model_name_or_path (str): The model name or path.
cache_dir (Optional[str]): The cache directory to store the model
cache_dir (Optional[str]): The cache directory to store the model
...
@@ -179,26 +184,30 @@ def download_weights_from_hf(model_name_or_path: str,
...
@@ -179,26 +184,30 @@ def download_weights_from_hf(model_name_or_path: str,
Returns:
Returns:
str: The path to the downloaded model weights.
str: The path to the downloaded model weights.
"""
"""
# Before we download we look at that is available:
if
not
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
:
fs
=
HfFileSystem
()
# Before we download we look at that is available:
file_list
=
fs
.
ls
(
model_name_or_path
,
detail
=
False
,
revision
=
revision
)
fs
=
HfFileSystem
()
file_list
=
fs
.
ls
(
model_name_or_path
,
detail
=
False
,
revision
=
revision
)
# depending on what is available we download different things
for
pattern
in
allow_patterns
:
# depending on what is available we download different things
matching
=
fnmatch
.
filter
(
file_list
,
pattern
)
for
pattern
in
allow_patterns
:
if
len
(
matching
)
>
0
:
matching
=
fnmatch
.
filter
(
file_list
,
pattern
)
allow_patterns
=
[
pattern
]
if
len
(
matching
)
>
0
:
break
allow_patterns
=
[
pattern
]
break
logger
.
info
(
f
"Using model weights format
{
allow_patterns
}
"
)
logger
.
info
(
"Using model weights format %s"
,
allow_patterns
)
# Use file lock to prevent multiple processes from
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
# downloading the same model weights at the same time.
with
get_lock
(
model_name_or_path
,
cache_dir
):
with
get_lock
(
model_name_or_path
,
cache_dir
):
hf_folder
=
snapshot_download
(
model_name_or_path
,
hf_folder
=
snapshot_download
(
allow_patterns
=
allow_patterns
,
model_name_or_path
,
cache_dir
=
cache_dir
,
allow_patterns
=
allow_patterns
,
tqdm_class
=
DisabledTqdm
,
cache_dir
=
cache_dir
,
revision
=
revision
)
tqdm_class
=
DisabledTqdm
,
revision
=
revision
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
)
return
hf_folder
return
hf_folder
...
@@ -310,17 +319,17 @@ def kv_cache_scales_loader(
...
@@ -310,17 +319,17 @@ def kv_cache_scales_loader(
return
layer_scales_map
.
items
()
return
layer_scales_map
.
items
()
except
FileNotFoundError
:
except
FileNotFoundError
:
logger
.
error
(
f
"File or directory '
{
filename
}
' not found."
)
logger
.
error
(
"File or directory '
%s
' not found."
,
filename
)
except
json
.
JSONDecodeError
:
except
json
.
JSONDecodeError
:
logger
.
error
(
f
"Error decoding JSON in file '
{
filename
}
'."
)
logger
.
error
(
"Error decoding JSON in file '
%s'."
,
filename
)
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
error
(
f
"An error occurred while reading '
{
filename
}
':
{
e
}
"
)
logger
.
error
(
"An error occurred while reading '
%s': %s"
,
filename
,
e
)
# This section is reached if and only if any of the excepts are hit
# This section is reached if and only if any of the excepts are hit
# Return an empty iterable (list) => no KV cache scales are loaded
# Return an empty iterable (list) => no KV cache scales are loaded
# which ultimately defaults to 1.0 scales
# which ultimately defaults to 1.0 scales
logger
.
warning
(
"Defaulting to KV cache scaling factors = 1.0 "
logger
.
warning
(
f
"for all layers in TP rank
{
tp_rank
}
"
"Defaulting to KV cache scaling factors = 1.0 for all
"
"
as an error occurred during loading."
)
"layers in TP rank %d
as an error occurred during loading."
,
tp_rank
)
return
[]
return
[]
...
...
vllm/model_executor/models/__init__.py
View file @
1591c68f
...
@@ -42,10 +42,11 @@ _MODELS = {
...
@@ -42,10 +42,11 @@ _MODELS = {
"MptForCausalLM"
:
(
"mpt"
,
"MPTForCausalLM"
),
"MptForCausalLM"
:
(
"mpt"
,
"MPTForCausalLM"
),
"MPTForCausalLM"
:
(
"mpt"
,
"MPTForCausalLM"
),
"MPTForCausalLM"
:
(
"mpt"
,
"MPTForCausalLM"
),
"MiniCPMForCausalLM"
:
(
"minicpm"
,
"MiniCPMForCausalLM"
),
"MiniCPMForCausalLM"
:
(
"minicpm"
,
"MiniCPMForCausalLM"
),
"O
LM
oForCausalLM"
:
(
"olmo"
,
"O
LM
oForCausalLM"
),
"O
lm
oForCausalLM"
:
(
"olmo"
,
"O
lm
oForCausalLM"
),
"OPTForCausalLM"
:
(
"opt"
,
"OPTForCausalLM"
),
"OPTForCausalLM"
:
(
"opt"
,
"OPTForCausalLM"
),
"OrionForCausalLM"
:
(
"orion"
,
"OrionForCausalLM"
),
"OrionForCausalLM"
:
(
"orion"
,
"OrionForCausalLM"
),
"PhiForCausalLM"
:
(
"phi"
,
"PhiForCausalLM"
),
"PhiForCausalLM"
:
(
"phi"
,
"PhiForCausalLM"
),
"Phi3ForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"QWenLMHeadModel"
:
(
"qwen"
,
"QWenLMHeadModel"
),
"QWenLMHeadModel"
:
(
"qwen"
,
"QWenLMHeadModel"
),
"Qwen2ForCausalLM"
:
(
"qwen2"
,
"Qwen2ForCausalLM"
),
"Qwen2ForCausalLM"
:
(
"qwen2"
,
"Qwen2ForCausalLM"
),
"Qwen2MoeForCausalLM"
:
(
"qwen2_moe"
,
"Qwen2MoeForCausalLM"
),
"Qwen2MoeForCausalLM"
:
(
"qwen2_moe"
,
"Qwen2MoeForCausalLM"
),
...
@@ -90,8 +91,8 @@ class ModelRegistry:
...
@@ -90,8 +91,8 @@ class ModelRegistry:
"ROCm for now."
)
"ROCm for now."
)
if
model_arch
in
_ROCM_PARTIALLY_SUPPORTED_MODELS
:
if
model_arch
in
_ROCM_PARTIALLY_SUPPORTED_MODELS
:
logger
.
warning
(
logger
.
warning
(
f
"Model architecture
{
model_arch
}
is partially supported
"
"Model architecture
%s
is partially supported
by ROCm: %s"
,
"by ROCm: "
+
_ROCM_PARTIALLY_SUPPORTED_MODELS
[
model_arch
])
model_arch
,
_ROCM_PARTIALLY_SUPPORTED_MODELS
[
model_arch
])
module_name
,
model_cls_name
=
_MODELS
[
model_arch
]
module_name
,
model_cls_name
=
_MODELS
[
model_arch
]
module
=
importlib
.
import_module
(
module
=
importlib
.
import_module
(
...
@@ -106,9 +107,9 @@ class ModelRegistry:
...
@@ -106,9 +107,9 @@ class ModelRegistry:
def
register_model
(
model_arch
:
str
,
model_cls
:
Type
[
nn
.
Module
]):
def
register_model
(
model_arch
:
str
,
model_cls
:
Type
[
nn
.
Module
]):
if
model_arch
in
_MODELS
:
if
model_arch
in
_MODELS
:
logger
.
warning
(
logger
.
warning
(
f
"Model architecture
{
model_arch
}
is already registered, "
"Model architecture
%s
is already registered,
and will be
"
"
and will be
overwritten by the new model
"
"overwritten by the new model
class %s."
,
model_arch
,
f
"class
{
model_cls
.
__name__
}
."
)
model_cls
.
__name__
)
global
_OOT_MODELS
global
_OOT_MODELS
_OOT_MODELS
[
model_arch
]
=
model_cls
_OOT_MODELS
[
model_arch
]
=
model_cls
...
...
vllm/model_executor/models/baichuan.py
View file @
1591c68f
...
@@ -31,11 +31,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
...
@@ -31,11 +31,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -77,17 +78,17 @@ class BaiChuanMLP(nn.Module):
...
@@ -77,17 +78,17 @@ class BaiChuanMLP(nn.Module):
hidden_size
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
hidden_act
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
if
hidden_act
!=
"silu"
:
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
"Only silu is supported for now."
)
...
@@ -110,7 +111,7 @@ class BaiChuanAttention(nn.Module):
...
@@ -110,7 +111,7 @@ class BaiChuanAttention(nn.Module):
position_embedding
:
str
,
position_embedding
:
str
,
rope_theta
:
float
=
10000
,
rope_theta
:
float
=
10000
,
max_position_embeddings
:
int
=
8192
,
max_position_embeddings
:
int
=
8192
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -132,13 +133,13 @@ class BaiChuanAttention(nn.Module):
...
@@ -132,13 +133,13 @@ class BaiChuanAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_heads
,
self
.
total_num_heads
,
self
.
total_num_heads
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
o_proj
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
# Create the alibi slopes and slice them.
# Create the alibi slopes and slice them.
if
self
.
postion_embedding
==
"ALIBI"
:
if
self
.
postion_embedding
==
"ALIBI"
:
...
@@ -184,7 +185,7 @@ class BaiChuanDecoderLayer(nn.Module):
...
@@ -184,7 +185,7 @@ class BaiChuanDecoderLayer(nn.Module):
def
__init__
(
self
,
def
__init__
(
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
position_embedding
:
str
,
position_embedding
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
...
@@ -196,13 +197,13 @@ class BaiChuanDecoderLayer(nn.Module):
...
@@ -196,13 +197,13 @@ class BaiChuanDecoderLayer(nn.Module):
position_embedding
=
position_embedding
,
position_embedding
=
position_embedding
,
rope_theta
=
rope_theta
,
rope_theta
=
rope_theta
,
max_position_embeddings
=
max_position_embeddings
,
max_position_embeddings
=
max_position_embeddings
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
mlp
=
BaiChuanMLP
(
self
.
mlp
=
BaiChuanMLP
(
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
...
@@ -243,7 +244,7 @@ class BaiChuanModel(nn.Module):
...
@@ -243,7 +244,7 @@ class BaiChuanModel(nn.Module):
def
__init__
(
self
,
def
__init__
(
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
position_embedding
:
str
,
position_embedding
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
padding_idx
=
config
.
pad_token_id
self
.
padding_idx
=
config
.
pad_token_id
...
@@ -254,7 +255,7 @@ class BaiChuanModel(nn.Module):
...
@@ -254,7 +255,7 @@ class BaiChuanModel(nn.Module):
config
.
hidden_size
,
config
.
hidden_size
,
)
)
self
.
layers
=
nn
.
ModuleList
([
self
.
layers
=
nn
.
ModuleList
([
BaiChuanDecoderLayer
(
config
,
position_embedding
,
linear_method
)
BaiChuanDecoderLayer
(
config
,
position_embedding
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
@@ -303,13 +304,13 @@ class BaiChuanBaseForCausalLM(nn.Module):
...
@@ -303,13 +304,13 @@ class BaiChuanBaseForCausalLM(nn.Module):
self
,
self
,
config
,
config
,
position_embedding
:
str
,
position_embedding
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
model
=
BaiChuanModel
(
config
,
position_embedding
,
linear_method
)
self
.
model
=
BaiChuanModel
(
config
,
position_embedding
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
@@ -388,13 +389,13 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
...
@@ -388,13 +389,13 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
def
__init__
(
def
__init__
(
self
,
self
,
config
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
):
if
config
.
hidden_size
==
4096
:
# baichuan2 7b
if
config
.
hidden_size
==
4096
:
# baichuan2 7b
super
().
__init__
(
config
,
"ROPE"
,
linear_method
,
lora_config
)
super
().
__init__
(
config
,
"ROPE"
,
quant_config
,
lora_config
)
else
:
# baichuan 13b, baichuan2 13b
else
:
# baichuan 13b, baichuan2 13b
super
().
__init__
(
config
,
"ALIBI"
,
linear_method
,
lora_config
)
super
().
__init__
(
config
,
"ALIBI"
,
quant_config
,
lora_config
)
class
BaiChuanForCausalLM
(
BaiChuanBaseForCausalLM
):
class
BaiChuanForCausalLM
(
BaiChuanBaseForCausalLM
):
...
@@ -403,7 +404,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
...
@@ -403,7 +404,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
def
__init__
(
def
__init__
(
self
,
self
,
config
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
):
super
().
__init__
(
config
,
"ROPE"
,
linear_method
,
lora_config
)
super
().
__init__
(
config
,
"ROPE"
,
quant_config
,
lora_config
)
vllm/model_executor/models/bloom.py
View file @
1591c68f
...
@@ -28,10 +28,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
...
@@ -28,10 +28,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
...
@@ -70,7 +71,7 @@ class BloomAttention(nn.Module):
...
@@ -70,7 +71,7 @@ class BloomAttention(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
BloomConfig
,
config
:
BloomConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -87,13 +88,13 @@ class BloomAttention(nn.Module):
...
@@ -87,13 +88,13 @@ class BloomAttention(nn.Module):
self
.
head_dim
,
self
.
head_dim
,
self
.
total_num_heads
,
self
.
total_num_heads
,
bias
=
True
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
dense
=
RowParallelLinear
(
self
.
dense
=
RowParallelLinear
(
self
.
hidden_size
,
self
.
hidden_size
,
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
True
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
# Create the alibi slopes and slice them.
# Create the alibi slopes and slice them.
...
@@ -129,21 +130,20 @@ class BloomMLP(nn.Module):
...
@@ -129,21 +130,20 @@ class BloomMLP(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
BloomConfig
,
config
:
BloomConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
hidden_size
=
config
.
hidden_size
hidden_size
=
config
.
hidden_size
self
.
dense_h_to_4h
=
ColumnParallelLinear
(
self
.
dense_h_to_4h
=
ColumnParallelLinear
(
hidden_size
,
hidden_size
,
4
*
hidden_size
,
4
*
hidden_size
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
quant_config
=
getattr
(
linear_method
,
"quant_config"
,
None
)
self
.
gelu_impl
=
get_act_fn
(
"gelu"
,
quant_config
,
4
*
hidden_size
)
self
.
gelu_impl
=
get_act_fn
(
"gelu"
,
quant_config
,
4
*
hidden_size
)
self
.
dense_4h_to_h
=
RowParallelLinear
(
self
.
dense_4h_to_h
=
RowParallelLinear
(
4
*
hidden_size
,
4
*
hidden_size
,
hidden_size
,
hidden_size
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -158,17 +158,17 @@ class BloomBlock(nn.Module):
...
@@ -158,17 +158,17 @@ class BloomBlock(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
BloomConfig
,
config
:
BloomConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
hidden_size
=
config
.
hidden_size
hidden_size
=
config
.
hidden_size
self
.
input_layernorm
=
nn
.
LayerNorm
(
hidden_size
,
self
.
input_layernorm
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
eps
=
config
.
layer_norm_epsilon
)
self
.
self_attention
=
BloomAttention
(
config
,
linear_method
)
self
.
self_attention
=
BloomAttention
(
config
,
quant_config
)
self
.
post_attention_layernorm
=
nn
.
LayerNorm
(
self
.
post_attention_layernorm
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
mlp
=
BloomMLP
(
config
,
linear_method
)
self
.
mlp
=
BloomMLP
(
config
,
quant_config
)
self
.
apply_residual_connection_post_layernorm
=
(
self
.
apply_residual_connection_post_layernorm
=
(
config
.
apply_residual_connection_post_layernorm
)
config
.
apply_residual_connection_post_layernorm
)
...
@@ -214,7 +214,7 @@ class BloomModel(nn.Module):
...
@@ -214,7 +214,7 @@ class BloomModel(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
BloomConfig
,
config
:
BloomConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
embed_dim
=
config
.
hidden_size
self
.
embed_dim
=
config
.
hidden_size
...
@@ -229,7 +229,7 @@ class BloomModel(nn.Module):
...
@@ -229,7 +229,7 @@ class BloomModel(nn.Module):
# Transformer blocks
# Transformer blocks
self
.
h
=
nn
.
ModuleList
([
self
.
h
=
nn
.
ModuleList
([
BloomBlock
(
config
,
linear_method
)
BloomBlock
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
])
...
@@ -262,12 +262,12 @@ class BloomForCausalLM(nn.Module):
...
@@ -262,12 +262,12 @@ class BloomForCausalLM(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
BloomConfig
,
config
:
BloomConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
transformer
=
BloomModel
(
config
,
linear_method
)
self
.
transformer
=
BloomModel
(
config
,
quant_config
)
self
.
lm_head_weight
=
self
.
transformer
.
word_embeddings
.
weight
self
.
lm_head_weight
=
self
.
transformer
.
word_embeddings
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/chatglm.py
View file @
1591c68f
...
@@ -13,11 +13,12 @@ from vllm.config import LoRAConfig
...
@@ -13,11 +13,12 @@ from vllm.config import LoRAConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -33,7 +34,7 @@ class GLMAttention(nn.Module):
...
@@ -33,7 +34,7 @@ class GLMAttention(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -65,13 +66,13 @@ class GLMAttention(nn.Module):
...
@@ -65,13 +66,13 @@ class GLMAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
self
.
total_num_kv_heads
,
bias
=
config
.
add_bias_linear
or
config
.
add_qkv_bias
,
bias
=
config
.
add_bias_linear
or
config
.
add_qkv_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
dense
=
RowParallelLinear
(
self
.
dense
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
total_num_heads
*
self
.
head_dim
,
config
.
hidden_size
,
config
.
hidden_size
,
bias
=
config
.
add_bias_linear
,
bias
=
config
.
add_bias_linear
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
...
@@ -123,7 +124,7 @@ class GLMMLP(nn.Module):
...
@@ -123,7 +124,7 @@ class GLMMLP(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -134,7 +135,7 @@ class GLMMLP(nn.Module):
...
@@ -134,7 +135,7 @@ class GLMMLP(nn.Module):
config
.
hidden_size
,
config
.
hidden_size
,
[
config
.
ffn_hidden_size
]
*
2
,
[
config
.
ffn_hidden_size
]
*
2
,
bias
=
config
.
add_bias_linear
,
bias
=
config
.
add_bias_linear
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
activation_func
=
SiluAndMul
()
self
.
activation_func
=
SiluAndMul
()
...
@@ -144,7 +145,7 @@ class GLMMLP(nn.Module):
...
@@ -144,7 +145,7 @@ class GLMMLP(nn.Module):
config
.
ffn_hidden_size
,
config
.
ffn_hidden_size
,
config
.
hidden_size
,
config
.
hidden_size
,
bias
=
config
.
add_bias_linear
,
bias
=
config
.
add_bias_linear
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
def
forward
(
self
,
hidden_states
):
def
forward
(
self
,
hidden_states
):
...
@@ -166,7 +167,7 @@ class GLMBlock(nn.Module):
...
@@ -166,7 +167,7 @@ class GLMBlock(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
apply_residual_connection_post_layernorm
=
(
self
.
apply_residual_connection_post_layernorm
=
(
...
@@ -180,7 +181,7 @@ class GLMBlock(nn.Module):
...
@@ -180,7 +181,7 @@ class GLMBlock(nn.Module):
eps
=
config
.
layernorm_epsilon
)
eps
=
config
.
layernorm_epsilon
)
# Self attention.
# Self attention.
self
.
self_attention
=
GLMAttention
(
config
,
linear_method
)
self
.
self_attention
=
GLMAttention
(
config
,
quant_config
)
self
.
hidden_dropout
=
config
.
hidden_dropout
self
.
hidden_dropout
=
config
.
hidden_dropout
# Layernorm on the attention output
# Layernorm on the attention output
...
@@ -188,7 +189,7 @@ class GLMBlock(nn.Module):
...
@@ -188,7 +189,7 @@ class GLMBlock(nn.Module):
config
.
hidden_size
,
eps
=
config
.
layernorm_epsilon
)
config
.
hidden_size
,
eps
=
config
.
layernorm_epsilon
)
# MLP
# MLP
self
.
mlp
=
GLMMLP
(
config
,
linear_method
)
self
.
mlp
=
GLMMLP
(
config
,
quant_config
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -236,7 +237,7 @@ class GLMTransformer(nn.Module):
...
@@ -236,7 +237,7 @@ class GLMTransformer(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
post_layer_norm
=
config
.
post_layer_norm
self
.
post_layer_norm
=
config
.
post_layer_norm
...
@@ -246,7 +247,7 @@ class GLMTransformer(nn.Module):
...
@@ -246,7 +247,7 @@ class GLMTransformer(nn.Module):
# Transformer layers.
# Transformer layers.
self
.
layers
=
nn
.
ModuleList
(
self
.
layers
=
nn
.
ModuleList
(
[
GLMBlock
(
config
,
linear_method
)
for
i
in
range
(
self
.
num_layers
)])
[
GLMBlock
(
config
,
quant_config
)
for
i
in
range
(
self
.
num_layers
)])
if
self
.
post_layer_norm
:
if
self
.
post_layer_norm
:
layer_norm_func
=
RMSNorm
if
config
.
rmsnorm
else
LayerNorm
layer_norm_func
=
RMSNorm
if
config
.
rmsnorm
else
LayerNorm
...
@@ -281,7 +282,7 @@ class ChatGLMModel(nn.Module):
...
@@ -281,7 +282,7 @@ class ChatGLMModel(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -291,7 +292,7 @@ class ChatGLMModel(nn.Module):
...
@@ -291,7 +292,7 @@ class ChatGLMModel(nn.Module):
self
.
num_layers
=
config
.
num_layers
self
.
num_layers
=
config
.
num_layers
self
.
multi_query_group_num
=
config
.
multi_query_group_num
self
.
multi_query_group_num
=
config
.
multi_query_group_num
self
.
kv_channels
=
config
.
kv_channels
self
.
kv_channels
=
config
.
kv_channels
self
.
encoder
=
GLMTransformer
(
config
,
linear_method
)
self
.
encoder
=
GLMTransformer
(
config
,
quant_config
)
self
.
output_layer
=
ParallelLMHead
(
config
.
padded_vocab_size
,
self
.
output_layer
=
ParallelLMHead
(
config
.
padded_vocab_size
,
config
.
hidden_size
)
config
.
hidden_size
)
...
@@ -333,13 +334,13 @@ class ChatGLMForCausalLM(nn.Module):
...
@@ -333,13 +334,13 @@ class ChatGLMForCausalLM(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
ChatGLMConfig
,
config
:
ChatGLMConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
:
ChatGLMConfig
=
config
self
.
config
:
ChatGLMConfig
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
transformer
=
ChatGLMModel
(
config
,
linear_method
)
self
.
transformer
=
ChatGLMModel
(
config
,
quant_config
)
self
.
lm_head_weight
=
self
.
transformer
.
output_layer
.
weight
self
.
lm_head_weight
=
self
.
transformer
.
output_layer
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
padded_vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
padded_vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/commandr.py
View file @
1591c68f
...
@@ -32,11 +32,12 @@ from vllm.attention import Attention, AttentionMetadata
...
@@ -32,11 +32,12 @@ from vllm.attention import Attention, AttentionMetadata
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -91,7 +92,7 @@ class CohereMLP(nn.Module):
...
@@ -91,7 +92,7 @@ class CohereMLP(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -101,13 +102,13 @@ class CohereMLP(nn.Module):
...
@@ -101,13 +102,13 @@ class CohereMLP(nn.Module):
self
.
hidden_size
,
self
.
hidden_size
,
[
self
.
intermediate_size
]
*
2
,
[
self
.
intermediate_size
]
*
2
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
down_proj
=
RowParallelLinear
(
self
.
down_proj
=
RowParallelLinear
(
self
.
intermediate_size
,
self
.
intermediate_size
,
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
act_fn
=
SiluAndMul
()
self
.
act_fn
=
SiluAndMul
()
...
@@ -123,7 +124,7 @@ class CohereAttention(nn.Module):
...
@@ -123,7 +124,7 @@ class CohereAttention(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
CohereConfig
,
config
:
CohereConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
...
@@ -158,13 +159,13 @@ class CohereAttention(nn.Module):
...
@@ -158,13 +159,13 @@ class CohereAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
o_proj
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
total_num_heads
*
self
.
head_dim
,
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
rotary_emb
=
get_rope
(
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
self
.
head_dim
,
...
@@ -218,13 +219,13 @@ class CohereDecoderLayer(nn.Module):
...
@@ -218,13 +219,13 @@ class CohereDecoderLayer(nn.Module):
def
__init__
(
self
,
def
__init__
(
self
,
config
:
CohereConfig
,
config
:
CohereConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
self
.
self_attn
=
CohereAttention
(
config
,
linear_method
=
linear_method
)
self
.
self_attn
=
CohereAttention
(
config
,
quant_config
=
quant_config
)
self
.
mlp
=
CohereMLP
(
config
,
linear_method
=
linear_method
)
self
.
mlp
=
CohereMLP
(
config
,
quant_config
=
quant_config
)
self
.
input_layernorm
=
LayerNorm
(
param_shape
=
(
config
.
hidden_size
),
self
.
input_layernorm
=
LayerNorm
(
param_shape
=
(
config
.
hidden_size
),
eps
=
config
.
layer_norm_eps
)
eps
=
config
.
layer_norm_eps
)
...
@@ -257,7 +258,7 @@ class CohereModel(nn.Module):
...
@@ -257,7 +258,7 @@ class CohereModel(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
CohereConfig
,
config
:
CohereConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -265,7 +266,7 @@ class CohereModel(nn.Module):
...
@@ -265,7 +266,7 @@ class CohereModel(nn.Module):
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
)
config
.
hidden_size
)
self
.
layers
=
nn
.
ModuleList
([
self
.
layers
=
nn
.
ModuleList
([
CohereDecoderLayer
(
config
,
linear_method
=
linear_method
)
CohereDecoderLayer
(
config
,
quant_config
=
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
])
self
.
norm
=
LayerNorm
(
param_shape
=
(
config
.
hidden_size
),
self
.
norm
=
LayerNorm
(
param_shape
=
(
config
.
hidden_size
),
...
@@ -298,14 +299,14 @@ class CohereForCausalLM(nn.Module):
...
@@ -298,14 +299,14 @@ class CohereForCausalLM(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
CohereConfig
,
config
:
CohereConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
,
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
,
scale
=
config
.
logit_scale
)
scale
=
config
.
logit_scale
)
self
.
model
=
CohereModel
(
config
,
linear_method
)
self
.
model
=
CohereModel
(
config
,
quant_config
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
@
torch
.
no_grad
()
@
torch
.
no_grad
()
...
...
vllm/model_executor/models/dbrx.py
View file @
1591c68f
...
@@ -9,11 +9,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
...
@@ -9,11 +9,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
ReplicatedLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -44,7 +45,7 @@ class DbrxRouter(nn.Module):
...
@@ -44,7 +45,7 @@ class DbrxRouter(nn.Module):
self
.
num_total_experts
,
self
.
num_total_experts
,
bias
=
False
,
bias
=
False
,
params_dtype
=
params_dtype
,
params_dtype
=
params_dtype
,
linear_method
=
None
,
quant_config
=
None
,
)
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -63,7 +64,7 @@ class DbrxExperts(nn.Module):
...
@@ -63,7 +64,7 @@ class DbrxExperts(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
DbrxConfig
,
config
:
DbrxConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -165,7 +166,7 @@ class DbrxAttention(nn.Module):
...
@@ -165,7 +166,7 @@ class DbrxAttention(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
DbrxConfig
,
config
:
DbrxConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
d_model
=
config
.
d_model
self
.
d_model
=
config
.
d_model
...
@@ -183,13 +184,13 @@ class DbrxAttention(nn.Module):
...
@@ -183,13 +184,13 @@ class DbrxAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
out_proj
=
RowParallelLinear
(
self
.
out_proj
=
RowParallelLinear
(
self
.
d_model
,
self
.
d_model
,
self
.
d_model
,
self
.
d_model
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
rotary_emb
=
get_rope
(
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
self
.
head_dim
,
...
@@ -244,11 +245,11 @@ class DbrxFusedNormAttention(nn.Module):
...
@@ -244,11 +245,11 @@ class DbrxFusedNormAttention(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
DbrxConfig
,
config
:
DbrxConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
d_model
=
config
.
d_model
self
.
d_model
=
config
.
d_model
self
.
attn
=
DbrxAttention
(
config
,
linear_method
)
self
.
attn
=
DbrxAttention
(
config
,
quant_config
)
self
.
norm_1
=
nn
.
LayerNorm
(
self
.
d_model
)
self
.
norm_1
=
nn
.
LayerNorm
(
self
.
d_model
)
self
.
norm_2
=
nn
.
LayerNorm
(
self
.
d_model
)
self
.
norm_2
=
nn
.
LayerNorm
(
self
.
d_model
)
...
@@ -278,11 +279,11 @@ class DbrxBlock(nn.Module):
...
@@ -278,11 +279,11 @@ class DbrxBlock(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
DbrxConfig
,
config
:
DbrxConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
norm_attn_norm
=
DbrxFusedNormAttention
(
config
,
linear_method
)
self
.
norm_attn_norm
=
DbrxFusedNormAttention
(
config
,
quant_config
)
self
.
ffn
=
DbrxExperts
(
config
,
linear_method
)
self
.
ffn
=
DbrxExperts
(
config
,
quant_config
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -307,7 +308,7 @@ class DbrxModel(nn.Module):
...
@@ -307,7 +308,7 @@ class DbrxModel(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
DbrxConfig
,
config
:
DbrxConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
wte
=
VocabParallelEmbedding
(
self
.
wte
=
VocabParallelEmbedding
(
...
@@ -315,7 +316,7 @@ class DbrxModel(nn.Module):
...
@@ -315,7 +316,7 @@ class DbrxModel(nn.Module):
config
.
d_model
,
config
.
d_model
,
)
)
self
.
blocks
=
nn
.
ModuleList
(
self
.
blocks
=
nn
.
ModuleList
(
[
DbrxBlock
(
config
,
linear_method
)
for
_
in
range
(
config
.
n_layers
)])
[
DbrxBlock
(
config
,
quant_config
)
for
_
in
range
(
config
.
n_layers
)])
self
.
norm_f
=
nn
.
LayerNorm
(
config
.
d_model
,
eps
=
1e-5
)
self
.
norm_f
=
nn
.
LayerNorm
(
config
.
d_model
,
eps
=
1e-5
)
for
module
in
self
.
modules
():
for
module
in
self
.
modules
():
if
hasattr
(
module
,
"bias"
)
and
isinstance
(
module
.
bias
,
if
hasattr
(
module
,
"bias"
)
and
isinstance
(
module
.
bias
,
...
@@ -348,13 +349,13 @@ class DbrxForCausalLM(nn.Module):
...
@@ -348,13 +349,13 @@ class DbrxForCausalLM(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
DbrxConfig
,
config
:
DbrxConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
transformer
=
DbrxModel
(
config
,
linear_method
)
self
.
transformer
=
DbrxModel
(
config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
vocab_size
,
config
.
d_model
,
config
.
d_model
,
...
...
vllm/model_executor/models/decilm.py
View file @
1591c68f
...
@@ -29,7 +29,8 @@ import torch
...
@@ -29,7 +29,8 @@ import torch
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.config
import
LoRAConfig
from
vllm.config
import
LoRAConfig
from
vllm.model_executor.layers.linear
import
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
...
@@ -55,13 +56,13 @@ class DeciLMForCausalLM(LlamaForCausalLM):
...
@@ -55,13 +56,13 @@ class DeciLMForCausalLM(LlamaForCausalLM):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
Optional
[
PretrainedConfig
]
=
None
,
config
:
Optional
[
PretrainedConfig
]
=
None
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
)
->
None
:
config
.
num_key_value_heads
=
max
(
config
.
num_key_value_heads_per_layer
)
config
.
num_key_value_heads
=
max
(
config
.
num_key_value_heads_per_layer
)
delattr
(
config
,
"num_key_value_heads_per_layer"
)
delattr
(
config
,
"num_key_value_heads_per_layer"
)
super
().
__init__
(
config
=
config
,
super
().
__init__
(
config
=
config
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
lora_config
=
lora_config
)
lora_config
=
lora_config
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
...
...
vllm/model_executor/models/deepseek.py
View file @
1591c68f
...
@@ -34,12 +34,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
...
@@ -34,12 +34,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
ReplicatedLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -56,18 +57,18 @@ class DeepseekMLP(nn.Module):
...
@@ -56,18 +57,18 @@ class DeepseekMLP(nn.Module):
hidden_size
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
hidden_act
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
reduce_results
:
bool
=
True
,
reduce_results
:
bool
=
True
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
reduce_results
=
reduce_results
)
reduce_results
=
reduce_results
)
if
hidden_act
!=
"silu"
:
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
...
@@ -86,7 +87,7 @@ class DeepseekMoE(nn.Module):
...
@@ -86,7 +87,7 @@ class DeepseekMoE(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -103,7 +104,7 @@ class DeepseekMoE(nn.Module):
...
@@ -103,7 +104,7 @@ class DeepseekMoE(nn.Module):
DeepseekMLP
(
hidden_size
=
config
.
hidden_size
,
DeepseekMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
hidden_act
=
config
.
hidden_act
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
reduce_results
=
False
)
reduce_results
=
False
)
for
idx
in
range
(
self
.
n_routed_experts
)
for
idx
in
range
(
self
.
n_routed_experts
)
])
])
...
@@ -112,7 +113,7 @@ class DeepseekMoE(nn.Module):
...
@@ -112,7 +113,7 @@ class DeepseekMoE(nn.Module):
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
self
.
n_routed_experts
,
self
.
n_routed_experts
,
bias
=
False
,
bias
=
False
,
linear_method
=
None
)
quant_config
=
None
)
if
config
.
n_shared_experts
is
not
None
:
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
intermediate_size
=
(
config
.
moe_intermediate_size
*
...
@@ -121,7 +122,7 @@ class DeepseekMoE(nn.Module):
...
@@ -121,7 +122,7 @@ class DeepseekMoE(nn.Module):
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
intermediate_size
=
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
reduce_results
=
False
,
reduce_results
=
False
,
)
)
...
@@ -177,7 +178,7 @@ class DeepseekAttention(nn.Module):
...
@@ -177,7 +178,7 @@ class DeepseekAttention(nn.Module):
rope_theta
:
float
=
10000
,
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
max_position_embeddings
:
int
=
8192
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -208,14 +209,14 @@ class DeepseekAttention(nn.Module):
...
@@ -208,14 +209,14 @@ class DeepseekAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
o_proj
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
rotary_emb
=
get_rope
(
self
.
rotary_emb
=
get_rope
(
...
@@ -251,7 +252,7 @@ class DeepseekDecoderLayer(nn.Module):
...
@@ -251,7 +252,7 @@ class DeepseekDecoderLayer(nn.Module):
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
layer_idx
:
int
,
layer_idx
:
int
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -266,18 +267,18 @@ class DeepseekDecoderLayer(nn.Module):
...
@@ -266,18 +267,18 @@ class DeepseekDecoderLayer(nn.Module):
rope_theta
=
rope_theta
,
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
max_position_embeddings
=
max_position_embeddings
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
if
(
config
.
n_routed_experts
is
not
None
if
(
config
.
n_routed_experts
is
not
None
and
layer_idx
>=
config
.
first_k_dense_replace
and
layer_idx
>=
config
.
first_k_dense_replace
and
layer_idx
%
config
.
moe_layer_freq
==
0
):
and
layer_idx
%
config
.
moe_layer_freq
==
0
):
self
.
mlp
=
DeepseekMoE
(
config
=
config
,
linear_method
=
linear_method
)
self
.
mlp
=
DeepseekMoE
(
config
=
config
,
quant_config
=
quant_config
)
else
:
else
:
self
.
mlp
=
DeepseekMLP
(
self
.
mlp
=
DeepseekMLP
(
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
...
@@ -320,7 +321,7 @@ class DeepseekModel(nn.Module):
...
@@ -320,7 +321,7 @@ class DeepseekModel(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
padding_idx
=
config
.
pad_token_id
self
.
padding_idx
=
config
.
pad_token_id
...
@@ -331,9 +332,7 @@ class DeepseekModel(nn.Module):
...
@@ -331,9 +332,7 @@ class DeepseekModel(nn.Module):
config
.
hidden_size
,
config
.
hidden_size
,
)
)
self
.
layers
=
nn
.
ModuleList
([
self
.
layers
=
nn
.
ModuleList
([
DeepseekDecoderLayer
(
config
,
DeepseekDecoderLayer
(
config
,
layer_idx
,
quant_config
=
quant_config
)
layer_idx
,
linear_method
=
linear_method
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
])
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
@@ -361,12 +360,12 @@ class DeepseekForCausalLM(nn.Module):
...
@@ -361,12 +360,12 @@ class DeepseekForCausalLM(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
model
=
DeepseekModel
(
config
,
linear_method
)
self
.
model
=
DeepseekModel
(
config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/falcon.py
View file @
1591c68f
...
@@ -32,10 +32,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
...
@@ -32,10 +32,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -76,7 +77,7 @@ class FalconAttention(nn.Module):
...
@@ -76,7 +77,7 @@ class FalconAttention(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
FalconConfig
,
config
:
FalconConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -115,7 +116,7 @@ class FalconAttention(nn.Module):
...
@@ -115,7 +116,7 @@ class FalconAttention(nn.Module):
self
.
total_num_kv_heads
,
self
.
total_num_kv_heads
,
bias
=
config
.
bias
,
bias
=
config
.
bias
,
skip_bias_add
=
True
,
skip_bias_add
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
...
@@ -129,7 +130,7 @@ class FalconAttention(nn.Module):
...
@@ -129,7 +130,7 @@ class FalconAttention(nn.Module):
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
config
.
bias
,
bias
=
config
.
bias
,
skip_bias_add
=
True
,
skip_bias_add
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
reduce_results
=
self
.
reduce_row_parallel_results
)
reduce_results
=
self
.
reduce_row_parallel_results
)
self
.
use_rotary
=
config
.
rotary
self
.
use_rotary
=
config
.
rotary
...
@@ -192,7 +193,7 @@ class FalconMLP(nn.Module):
...
@@ -192,7 +193,7 @@ class FalconMLP(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
FalconConfig
,
config
:
FalconConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
hidden_size
=
config
.
hidden_size
hidden_size
=
config
.
hidden_size
...
@@ -201,8 +202,7 @@ class FalconMLP(nn.Module):
...
@@ -201,8 +202,7 @@ class FalconMLP(nn.Module):
4
*
hidden_size
,
4
*
hidden_size
,
bias
=
config
.
bias
,
bias
=
config
.
bias
,
skip_bias_add
=
True
,
skip_bias_add
=
True
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
quant_config
=
getattr
(
linear_method
,
"quant_config"
,
None
)
self
.
act
=
get_act_fn
(
"gelu"
,
quant_config
,
4
*
hidden_size
)
self
.
act
=
get_act_fn
(
"gelu"
,
quant_config
,
4
*
hidden_size
)
self
.
reduce_row_parallel_results
=
not
(
config
.
new_decoder_architecture
self
.
reduce_row_parallel_results
=
not
(
config
.
new_decoder_architecture
or
config
.
parallel_attn
)
or
config
.
parallel_attn
)
...
@@ -212,7 +212,7 @@ class FalconMLP(nn.Module):
...
@@ -212,7 +212,7 @@ class FalconMLP(nn.Module):
bias
=
config
.
bias
,
bias
=
config
.
bias
,
skip_bias_add
=
True
,
skip_bias_add
=
True
,
reduce_results
=
self
.
reduce_row_parallel_results
,
reduce_results
=
self
.
reduce_row_parallel_results
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
# NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
...
@@ -229,13 +229,13 @@ class FalconDecoderLayer(nn.Module):
...
@@ -229,13 +229,13 @@ class FalconDecoderLayer(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
FalconConfig
,
config
:
FalconConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
hidden_size
=
config
.
hidden_size
hidden_size
=
config
.
hidden_size
self
.
num_heads
=
config
.
num_attention_heads
self
.
num_heads
=
config
.
num_attention_heads
self
.
self_attention
=
FalconAttention
(
config
,
linear_method
)
self
.
self_attention
=
FalconAttention
(
config
,
quant_config
)
self
.
mlp
=
FalconMLP
(
config
,
linear_method
)
self
.
mlp
=
FalconMLP
(
config
,
quant_config
)
self
.
config
=
config
self
.
config
=
config
if
config
.
new_decoder_architecture
:
if
config
.
new_decoder_architecture
:
...
@@ -311,7 +311,7 @@ class FalconModel(nn.Module):
...
@@ -311,7 +311,7 @@ class FalconModel(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
FalconConfig
,
config
:
FalconConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -327,7 +327,7 @@ class FalconModel(nn.Module):
...
@@ -327,7 +327,7 @@ class FalconModel(nn.Module):
# Transformer blocks
# Transformer blocks
self
.
h
=
nn
.
ModuleList
([
self
.
h
=
nn
.
ModuleList
([
FalconDecoderLayer
(
config
,
linear_method
)
FalconDecoderLayer
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
])
...
@@ -359,12 +359,12 @@ class FalconForCausalLM(nn.Module):
...
@@ -359,12 +359,12 @@ class FalconForCausalLM(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
FalconConfig
,
config
:
FalconConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
transformer
=
FalconModel
(
config
,
linear_method
)
self
.
transformer
=
FalconModel
(
config
,
quant_config
)
self
.
lm_head_weight
=
self
.
transformer
.
word_embeddings
.
weight
self
.
lm_head_weight
=
self
.
transformer
.
word_embeddings
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
Prev
1
…
7
8
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment