Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1591c68f
Commit
1591c68f
authored
May 25, 2024
by
zhuwenwen
Browse files
merge v0.4.2
parents
09bcf00b
c7f2cf2b
Changes
265
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1447 additions
and
509 deletions
+1447
-509
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+186
-60
vllm/model_executor/layers/quantization/gptq.py
vllm/model_executor/layers/quantization/gptq.py
+11
-8
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+438
-0
vllm/model_executor/layers/quantization/marlin.py
vllm/model_executor/layers/quantization/marlin.py
+8
-5
vllm/model_executor/layers/quantization/squeezellm.py
vllm/model_executor/layers/quantization/squeezellm.py
+13
-10
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+139
-3
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+404
-209
vllm/model_executor/layers/vocab_parallel_embedding.py
vllm/model_executor/layers/vocab_parallel_embedding.py
+8
-0
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+29
-21
vllm/model_executor/model_loader/tensorizer.py
vllm/model_executor/model_loader/tensorizer.py
+28
-24
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+42
-33
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+7
-6
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+22
-21
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+16
-16
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+19
-18
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+17
-16
vllm/model_executor/models/dbrx.py
vllm/model_executor/models/dbrx.py
+18
-17
vllm/model_executor/models/decilm.py
vllm/model_executor/models/decilm.py
+4
-3
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+22
-23
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+16
-16
No files found.
vllm/model_executor/layers/quantization/fp8.py
View file @
1591c68f
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
torch.nn
import
Module
from
torch.nn.parameter
import
Parameter
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
set_weight_attrs
)
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
ACTIVATION_SCHEMES
=
[
"static"
,
"dynamic"
]
class
FP8Config
(
QuantizationConfig
):
logger
=
init_logger
(
__name__
)
class
Fp8Config
(
QuantizationConfig
):
"""Config class for FP8."""
def
__init__
(
self
,
is_checkpoint_fp8_serialized
:
bool
=
False
,
activation_scheme
:
str
=
"dynamic"
,
)
->
None
:
self
.
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
if
is_checkpoint_fp8_serialized
:
logger
.
warning
(
"Detected fp8 checkpoint. Please note that the "
"format is experimental and subject to change."
)
if
activation_scheme
not
in
ACTIVATION_SCHEMES
:
raise
ValueError
(
f
"Unsupported activation scheme
{
activation_scheme
}
"
)
self
.
activation_scheme
=
activation_scheme
@
classmethod
def
get_name
(
cls
)
->
str
:
return
"fp8"
...
...
@@ -23,21 +43,25 @@ class FP8Config(QuantizationConfig):
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
# TODO: PyTorch 2.3.0+ is required to run FP8 on
# SM 89 (e.g. Ada) GPUs. Specifically, this PR has to
# be included: https://github.com/pytorch/pytorch/pull/118881
return
90
return
89
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"FP8Config"
:
return
cls
()
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"Fp8Config"
:
quant_method
=
cls
.
get_from_keys
(
config
,
[
"quant_method"
])
is_checkpoint_fp8_serialized
=
(
"fp8"
in
quant_method
)
activation_scheme
=
cls
.
get_from_keys
(
config
,
[
"activation_scheme"
])
return
cls
(
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
,
activation_scheme
=
activation_scheme
)
def
get_linear_method
(
self
)
->
"Fp8LinearMethod"
:
return
Fp8LinearMethod
(
self
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"Fp8LinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
return
Fp8LinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
...
...
@@ -45,8 +69,12 @@ class FP8Config(QuantizationConfig):
class
Fp8LinearMethod
(
LinearMethodBase
):
"""Linear method for FP8.
We now support common FP16/BF16 model checkpoints ONLY. The weight
scaling factor will be initialized after the model weights are loaded.
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Limitations:
1. Only support per-tensor quantization due to torch._scaled_mm support.
...
...
@@ -57,9 +85,27 @@ class Fp8LinearMethod(LinearMethodBase):
quant_config: The quantization config.
"""
def
__init__
(
self
,
quant_config
:
F
P
8Config
):
def
__init__
(
self
,
quant_config
:
F
p
8Config
):
self
.
quant_config
=
quant_config
def
_create_scale_param
(
self
,
scale_name
:
str
,
layer
:
torch
.
nn
.
Module
,
output_partition_sizes
:
List
[
int
],
**
extra_weight_attrs
,
)
->
None
:
scale
=
Parameter
(
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
scale_name
,
scale
)
set_weight_attrs
(
scale
,
{
**
extra_weight_attrs
,
"fp8_scales_shard_indexer"
:
self
.
scales_shard_indexer
,
})
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -70,70 +116,150 @@ class Fp8LinearMethod(LinearMethodBase):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
del
input_size
,
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
layer
.
process_after_load
=
True
layer
.
logical_widths
=
output_partition_sizes
# WEIGHT
weight_dtype
=
(
torch
.
float8_e4m3fn
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
else
params_dtype
)
weight
=
Parameter
(
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
params
_dtype
),
dtype
=
weight
_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
})
set_weight_attrs
(
weight
,
extra_weight_attrs
)
set_weight_attrs
(
weight
,
{
**
extra_weight_attrs
,
"input_dim"
:
1
,
"output_dim"
:
0
,
})
w_scale
=
Parameter
(
torch
.
empty
(
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"weight_scaling_factor"
,
w_scale
)
# If checkpoint is serialized fp8, load them.
# Otherwise, wait until process_weights_after_loading.
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# WEIGHT SCALE
self
.
_create_scale_param
(
scale_name
=
"weight_scale"
,
layer
=
layer
,
output_partition_sizes
=
output_partition_sizes
,
**
extra_weight_attrs
)
# ACTIVATION SCALE
if
self
.
quant_config
.
activation_scheme
==
"static"
:
self
.
_create_scale_param
(
scale_name
=
"act_scale"
,
layer
=
layer
,
output_partition_sizes
=
output_partition_sizes
,
**
extra_weight_attrs
)
def
scales_shard_indexer
(
self
,
param
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
,
shard_id
:
Union
[
str
,
int
])
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
qkv_idxs
=
{
"q"
:
0
,
"k"
:
1
,
"v"
:
2
}
if
isinstance
(
shard_id
,
int
):
pass
elif
isinstance
(
shard_id
,
str
):
if
shard_id
not
in
qkv_idxs
:
raise
ValueError
(
f
"Unknown shard_id:
{
shard_id
}
"
)
shard_id
=
qkv_idxs
[
shard_id
]
else
:
ValueError
(
f
"Shard id must be int or str but got
{
type
(
shard_id
)
}
"
)
return
param
[
shard_id
],
loaded_weight
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# Although the linear_method is propagated to all layers,
# only linear layers invoke "create_weights". So we check
# whether "weight_scaling_facor" is registered to determine
# whether the layer is a linear layer that requires quantization.
if
not
hasattr
(
layer
,
"weight_scaling_factor"
):
if
(
not
hasattr
(
layer
,
"process_after_load"
)
or
not
layer
.
process_after_load
):
return
# If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
layer
.
weight
,
scale
=
None
)
layer
.
weight
=
Parameter
(
qweight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
logical_widths
=
None
layer
.
act_scale
=
None
return
qweight
,
weight_scale
=
per_tensor_quantize
(
layer
.
weight
)
# torch._scaled_mm requires column-major in the second
# input (weight), so we transpose the quantized weight.
layer
.
weight
=
Parameter
(
qweight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scaling_factor
.
data
.
copy_
(
weight_scale
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
qinput
,
x_scale
=
per_tensor_quantize
(
x
)
# If checkpoint is fp8, requantize the separately quantized logical
# weights into a single fp8 weight with a single weight scale.
else
:
# WEIGHT_SCALE / WEIGHT
# Loop over logical weights, requantizing with single scale.
max_w_scale
=
layer
.
weight_scale
.
max
()
start
=
0
for
idx
,
logical_width
in
enumerate
(
layer
.
logical_widths
):
end
=
start
+
logical_width
weight_dq
=
per_tensor_dequantize
(
layer
.
weight
[
start
:
end
,
:],
layer
.
weight_scale
[
idx
])
layer
.
weight
[
start
:
end
,
:]
=
per_tensor_quantize
(
weight_dq
,
layer
.
weight_scale
.
max
())
start
=
end
layer
.
weight_scale
=
Parameter
(
max_w_scale
,
requires_grad
=
False
)
# WEIGHT
# Transpose weight for passing to torch._scaled_mm
weight
=
layer
.
weight
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
# ACT_SCALE
# Dynamic: set to None (required input to ops.scaled_fp8_quant).
# Static: set to max of the act_scales (since they are equal).
if
self
.
quant_config
.
activation_scheme
==
"dynamic"
:
layer
.
act_scale
=
None
elif
self
.
quant_config
.
activation_scheme
==
"static"
:
if
not
all_close_1d
(
layer
.
act_scale
):
raise
ValueError
(
"All the act_scales for the logical weights of a layer "
f
"must be equal. But got
{
layer
.
act_scale
}
"
)
layer
.
act_scale
=
Parameter
(
layer
.
act_scale
.
max
(),
requires_grad
=
False
)
else
:
raise
ValueError
(
f
"Unknown scheme
{
self
.
quant_config
.
activation_scheme
}
"
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.act_scale is None and x_scale computed from x.
# If static, layer.act_scale is scalar and x_scale set to act_scale.
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
x
,
layer
.
act_scale
)
# Fused GEMM_DQ
output
,
_
=
torch
.
_scaled_mm
(
qinput
,
layer
.
weight
,
out_dtype
=
x
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scal
ing_factor
,
scale_b
=
layer
.
weight_scal
e
,
bias
=
bias
,
)
return
output
def
per_tensor_quantize
(
tensor
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
float
]:
"""Quantize a tensor using per-tensor static scaling factor.
def
all_close_1d
(
x
:
torch
.
Tensor
)
->
bool
:
assert
len
(
x
.
shape
)
==
1
return
all
(
torch
.
allclose
(
x
[
0
],
x
[
i
])
for
i
in
range
(
x
.
shape
[
0
]))
Args:
tensor: The input t
ensor
.
"""
def
per_tensor_quantize
(
tensor
:
torch
.
T
ensor
,
inv_scale
:
float
)
->
torch
.
Tensor
:
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
# Calculate the scale as dtype max divided by absmax.
# Since .abs() creates a new tensor, we use aminmax to get
# the min and max first and then calculate the absmax.
min_val
,
max_val
=
tensor
.
aminmax
()
amax
=
min_val
.
abs
().
max
(
max_val
.
abs
())
scale
=
finfo
.
max
/
amax
.
clamp
(
min
=
1e-12
)
# scale and clamp the tensor to bring it to
# the representative range of float8 data type
# (as default cast is unsaturated)
qweight
=
(
tensor
*
scale
).
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)
# Return both float8 data and the inverse scale (as float),
# as both required as inputs to torch._scaled_mm
qweight
=
qweight
.
to
(
torch
.
float8_e4m3fn
)
scale
=
scale
.
float
().
reciprocal
()
return
qweight
,
scale
qweight
=
(
tensor
/
inv_scale
).
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)
return
qweight
.
to
(
torch
.
float8_e4m3fn
)
def
per_tensor_dequantize
(
tensor
:
torch
.
Tensor
,
inv_scale
:
float
)
->
torch
.
Tensor
:
fake_qweight
=
tensor
.
to
(
torch
.
float16
)
dq_weight
=
fake_qweight
*
inv_scale
return
dq_weight
vllm/model_executor/layers/quantization/gptq.py
View file @
1591c68f
...
...
@@ -7,10 +7,10 @@ import torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
class
GPTQConfig
(
QuantizationConfig
):
...
...
@@ -63,8 +63,11 @@ class GPTQConfig(QuantizationConfig):
desc_act
=
cls
.
get_from_keys
(
config
,
[
"desc_act"
])
return
cls
(
weight_bits
,
group_size
,
desc_act
)
def
get_linear_method
(
self
)
->
"GPTQLinearMethod"
:
return
GPTQLinearMethod
(
self
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"GPTQLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
return
GPTQLinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
...
...
@@ -194,10 +197,10 @@ class GPTQLinearMethod(LinearMethodBase):
layer
.
exllama_state
=
exllama_state
def
apply
_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
qweight
=
layer
.
qweight
out_shape
=
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
-
1
],
)
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
0 → 100644
View file @
1591c68f
import
enum
from
enum
import
Enum
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
GPTQ_MARLIN_TILE
=
16
GPTQ_MARLIN_MIN_THREAD_N
=
64
GPTQ_MARLIN_MIN_THREAD_K
=
128
GPTQ_MARLIN_MAX_PARALLEL
=
16
GPTQ_MARLIN_SUPPORTED_NUM_BITS
=
[
4
,
8
]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
GPTQ_MARLIN_SUPPORTED_SYM
=
[
True
]
# Permutations for Marlin scale shuffling
def
get_scale_perms
(
num_bits
):
scale_perm
=
[]
for
i
in
range
(
8
):
scale_perm
.
extend
([
i
+
8
*
j
for
j
in
range
(
8
)])
scale_perm_single
=
[]
for
i
in
range
(
4
):
scale_perm_single
.
extend
(
[
2
*
i
+
j
for
j
in
[
0
,
1
,
8
,
9
,
16
,
17
,
24
,
25
]])
return
scale_perm
,
scale_perm_single
def
get_pack_factor
(
num_bits
):
assert
(
num_bits
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
),
f
"Unsupported num_bits =
{
num_bits
}
"
return
32
//
num_bits
def
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
,
num_bits
):
scale_perm
,
scale_perm_single
=
get_scale_perms
(
num_bits
)
if
group_size
<
size_k
and
group_size
!=
-
1
:
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm
)))[:,
scale_perm
]
else
:
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm_single
)))[:,
scale_perm_single
]
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
return
s
class
GPTQMarlinConfig
(
QuantizationConfig
):
"""Config class for GPTQ Marlin"""
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
desc_act
:
bool
,
is_sym
:
bool
)
->
None
:
if
desc_act
and
group_size
==
-
1
:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
desc_act
=
False
self
.
weight_bits
=
weight_bits
self
.
group_size
=
group_size
self
.
desc_act
=
desc_act
self
.
is_sym
=
is_sym
# Verify
if
self
.
weight_bits
not
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
:
raise
ValueError
(
f
"Marlin does not support weight_bits =
{
self
.
weight_bits
}
. "
f
"Only weight_bits =
{
GPTQ_MARLIN_SUPPORTED_NUM_BITS
}
"
"are supported."
)
if
self
.
group_size
not
in
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
:
raise
ValueError
(
f
"Marlin does not support group_size =
{
self
.
group_size
}
. "
f
"Only group_sizes =
{
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
}
"
"are supported."
)
if
self
.
is_sym
not
in
GPTQ_MARLIN_SUPPORTED_SYM
:
raise
ValueError
(
f
"Marlin does not support is_sym =
{
self
.
is_sym
}
. "
f
"Only sym =
{
GPTQ_MARLIN_SUPPORTED_SYM
}
are supported."
)
# Init
self
.
pack_factor
=
get_pack_factor
(
weight_bits
)
self
.
tile_size
=
GPTQ_MARLIN_TILE
self
.
min_thread_n
=
GPTQ_MARLIN_MIN_THREAD_N
self
.
min_thread_k
=
GPTQ_MARLIN_MIN_THREAD_K
self
.
max_parallel
=
GPTQ_MARLIN_MAX_PARALLEL
def
__repr__
(
self
)
->
str
:
return
(
f
"GPTQMarlinConfig(weight_bits=
{
self
.
weight_bits
}
, "
f
"group_size=
{
self
.
group_size
}
, "
f
"desc_act=
{
self
.
desc_act
}
)"
)
@
classmethod
def
get_name
(
cls
)
->
str
:
return
"gptq_marlin"
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
half
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
80
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[
"quantize_config.json"
]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"GPTQMarlinConfig"
:
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"bits"
])
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
desc_act
=
cls
.
get_from_keys
(
config
,
[
"desc_act"
])
is_sym
=
cls
.
get_from_keys
(
config
,
[
"sym"
])
return
cls
(
weight_bits
,
group_size
,
desc_act
,
is_sym
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"GPTQMarlinLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
return
GPTQMarlinLinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
@
classmethod
def
is_marlin_compatible
(
cls
,
quant_config
:
Dict
[
str
,
Any
]):
# Extract data from quant config.
num_bits
=
quant_config
.
get
(
"bits"
,
None
)
group_size
=
quant_config
.
get
(
"group_size"
,
None
)
sym
=
quant_config
.
get
(
"sym"
,
None
)
desc_act
=
quant_config
.
get
(
"desc_act"
,
None
)
# If we cannot find the info needed in the config, cannot convert.
if
(
num_bits
is
None
or
group_size
is
None
or
sym
is
None
or
desc_act
is
None
):
return
False
# If the capability of the device is too low, cannot convert.
major
,
minor
=
torch
.
cuda
.
get_device_capability
()
device_capability
=
major
*
10
+
minor
if
device_capability
<
cls
.
get_min_capability
():
return
False
# Otherwise, can convert if model satisfies marlin constraints.
return
(
num_bits
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
and
group_size
in
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
and
sym
in
GPTQ_MARLIN_SUPPORTED_SYM
)
class
GPTQMarlinState
(
Enum
):
REPACK
=
enum
.
auto
()
READY
=
enum
.
auto
()
class
GPTQMarlinLinearMethod
(
LinearMethodBase
):
"""Linear method for GPTQ Marlin.
Args:
quant_config: The GPTQ Marlin quantization config.
"""
def
__init__
(
self
,
quant_config
:
GPTQMarlinConfig
)
->
None
:
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
)
->
None
:
del
output_size
# Normalize group_size
if
self
.
quant_config
.
group_size
!=
-
1
:
group_size
=
self
.
quant_config
.
group_size
else
:
group_size
=
input_size
# Validate dtype
if
params_dtype
!=
torch
.
float16
:
raise
ValueError
(
f
"The params dtype must be float16, but got
{
params_dtype
}
"
)
# Validate output_size_per_partition
output_size_per_partition
=
sum
(
output_partition_sizes
)
if
output_size_per_partition
%
self
.
quant_config
.
min_thread_n
!=
0
:
raise
ValueError
(
f
"Weight output_size_per_partition = "
f
"
{
output_size_per_partition
}
is not divisible by "
f
" min_thread_n =
{
self
.
quant_config
.
min_thread_n
}
."
)
# Validate input_size_per_partition
if
input_size_per_partition
%
self
.
quant_config
.
min_thread_k
!=
0
:
raise
ValueError
(
f
"Weight input_size_per_partition = "
f
"
{
input_size_per_partition
}
is not divisible "
f
"by min_thread_k =
{
self
.
quant_config
.
min_thread_k
}
."
)
if
(
group_size
<
input_size
and
input_size_per_partition
%
group_size
!=
0
):
raise
ValueError
(
f
"Weight input_size_per_partition =
{
input_size_per_partition
}
"
f
" is not divisible by group_size =
{
group_size
}
."
)
# Detect sharding of scales/zp
# By default, no sharding over "input dim"
scales_and_zp_size
=
input_size
//
group_size
scales_and_zp_input_dim
=
None
if
self
.
quant_config
.
desc_act
:
# Act-order case
assert
self
.
quant_config
.
group_size
!=
-
1
is_k_full
=
input_size_per_partition
==
input_size
else
:
# No act-order case
# K is always full due to full alignment with
# group-size and shard of scales/zp
is_k_full
=
True
# If this is a row-parallel case, then shard scales/zp
if
(
input_size
!=
input_size_per_partition
and
self
.
quant_config
.
group_size
!=
-
1
):
scales_and_zp_size
=
input_size_per_partition
//
group_size
scales_and_zp_input_dim
=
0
# Init buffers
# Quantized weights
qweight
=
Parameter
(
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
pack_factor
,
output_size_per_partition
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qweight
,
{
**
extra_weight_attrs
,
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
0
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
},
)
# Activation order
g_idx
=
Parameter
(
torch
.
empty
(
input_size_per_partition
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
# Ignore warning from fused linear layers such as QKVParallelLinear.
set_weight_attrs
(
g_idx
,
{
**
extra_weight_attrs
,
"input_dim"
:
0
,
"ignore_warning"
:
True
},
)
g_idx_sort_indices
=
Parameter
(
torch
.
empty
(
g_idx
.
shape
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
g_idx_sort_indices
,
extra_weight_attrs
)
# Scales
scales
=
Parameter
(
torch
.
empty
(
scales_and_zp_size
,
output_size_per_partition
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
scales
,
{
**
extra_weight_attrs
,
"input_dim"
:
scales_and_zp_input_dim
,
"output_dim"
:
1
,
},
)
# Quantized zero-points
qzeros
=
Parameter
(
torch
.
empty
(
scales_and_zp_size
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
device
=
"meta"
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qzeros
,
{
**
extra_weight_attrs
,
"input_dim"
:
scales_and_zp_input_dim
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
},
)
# Allocate marlin workspace
max_workspace_size
=
(
output_size_per_partition
//
self
.
quant_config
.
min_thread_n
)
*
self
.
quant_config
.
max_parallel
workspace
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
requires_grad
=
False
)
layer
.
register_parameter
(
"qweight"
,
qweight
)
layer
.
register_parameter
(
"g_idx"
,
g_idx
)
layer
.
register_parameter
(
"g_idx_sort_indices"
,
g_idx_sort_indices
)
layer
.
register_parameter
(
"scales"
,
scales
)
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
layer
.
workspace
=
workspace
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
input_size
=
input_size
layer
.
is_k_full
=
is_k_full
layer
.
marlin_state
=
GPTQMarlinState
.
REPACK
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
size_m
=
reshaped_x
.
shape
[
0
]
part_size_n
=
layer
.
output_size_per_partition
part_size_k
=
layer
.
input_size_per_partition
full_size_k
=
layer
.
input_size
out_shape
=
x
.
shape
[:
-
1
]
+
(
part_size_n
,
)
if
layer
.
marlin_state
==
GPTQMarlinState
.
REPACK
:
layer
.
marlin_state
=
GPTQMarlinState
.
READY
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def
replace_tensor
(
name
,
new_t
):
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr
(
layer
,
name
).
resize_
(
new_t
.
shape
)
getattr
(
layer
,
name
).
copy_
(
new_t
)
del
new_t
cur_device
=
layer
.
qweight
.
device
# Process act_order
if
self
.
quant_config
.
desc_act
:
# Get sorting based on g_idx
g_idx_sort_indices
=
torch
.
argsort
(
layer
.
g_idx
).
to
(
torch
.
int
)
sorted_g_idx
=
layer
.
g_idx
[
g_idx_sort_indices
]
replace_tensor
(
"g_idx"
,
sorted_g_idx
)
replace_tensor
(
"g_idx_sort_indices"
,
g_idx_sort_indices
)
else
:
# Reset g_idx related tensors
layer
.
g_idx
=
Parameter
(
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
cur_device
),
requires_grad
=
False
,
)
layer
.
g_idx_sort_indices
=
Parameter
(
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
cur_device
),
requires_grad
=
False
,
)
# Repack weights
marlin_qweight
=
ops
.
gptq_marlin_repack
(
layer
.
qweight
,
layer
.
g_idx_sort_indices
,
part_size_k
,
part_size_n
,
self
.
quant_config
.
weight_bits
,
)
replace_tensor
(
"qweight"
,
marlin_qweight
)
# Permute scales
scales_size_k
=
part_size_k
scales_size_n
=
part_size_n
if
self
.
quant_config
.
desc_act
:
scales_size_k
=
full_size_k
marlin_scales
=
marlin_permute_scales
(
layer
.
scales
,
scales_size_k
,
scales_size_n
,
self
.
quant_config
.
group_size
,
self
.
quant_config
.
weight_bits
,
)
replace_tensor
(
"scales"
,
marlin_scales
)
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
layer
.
qweight
,
layer
.
scales
,
layer
.
g_idx
,
layer
.
g_idx_sort_indices
,
layer
.
workspace
,
self
.
quant_config
.
weight_bits
,
size_m
,
part_size_n
,
part_size_k
,
layer
.
is_k_full
,
)
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
return
output
.
reshape
(
out_shape
)
vllm/model_executor/layers/quantization/marlin.py
View file @
1591c68f
...
...
@@ -4,10 +4,10 @@ import torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
class
MarlinConfig
(
QuantizationConfig
):
...
...
@@ -72,8 +72,11 @@ class MarlinConfig(QuantizationConfig):
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
return
cls
(
group_size
)
def
get_linear_method
(
self
)
->
"MarlinLinearMethod"
:
return
MarlinLinearMethod
(
self
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"MarlinLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
return
MarlinLinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
...
...
@@ -197,7 +200,7 @@ class MarlinLinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"workspace"
,
workspace
)
set_weight_attrs
(
workspace
,
extra_weight_attrs
)
def
apply
_weights
(
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/squeezellm.py
View file @
1591c68f
...
...
@@ -4,10 +4,10 @@ import torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.linear
import
LinearBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
is_hip
...
...
@@ -51,14 +51,17 @@ class SqueezeLLMConfig(QuantizationConfig):
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"wbits"
])
return
cls
(
weight_bits
)
def
get_linear_method
(
self
)
->
"SqueezeLLMLinearMethod"
:
return
SqueezeLLMLinearMethod
(
self
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
QuantizeMethodBase
]:
if
isinstance
(
layer
,
LinearBase
):
return
SqueezeLLMLinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
class
SqueezeLLMLinearMethod
(
Linear
MethodBase
):
class
SqueezeLLMLinearMethod
(
Quantize
MethodBase
):
"""Linear method for SqueezeLLM.
Args:
...
...
@@ -112,10 +115,10 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"lookup_table"
,
lookup_table
)
set_weight_attrs
(
lookup_table
,
extra_weight_attrs
)
def
apply
_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
qweight
=
layer
.
qweight
lookup_table
=
layer
.
lookup_table
out_shape
=
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
-
1
],
)
...
...
vllm/model_executor/layers/rotary_embedding.py
View file @
1591c68f
...
...
@@ -156,6 +156,12 @@ class RotaryEmbedding(nn.Module):
self
.
cos_sin_cache
,
self
.
is_neox_style
)
return
query
,
key
def
extra_repr
(
self
)
->
str
:
s
=
f
"head_size=
{
self
.
head_size
}
, rotary_dim=
{
self
.
rotary_dim
}
"
s
+=
f
", max_position_embeddings=
{
self
.
max_position_embeddings
}
"
s
+=
f
", base=
{
self
.
base
}
, is_neox_style=
{
self
.
is_neox_style
}
"
return
s
class
LinearScalingRotaryEmbedding
(
RotaryEmbedding
):
"""RotaryEmbedding extended with linear scaling.
...
...
@@ -338,6 +344,114 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
return
cache
class
Phi3SuScaledRotaryEmbedding
(
nn
.
Module
):
"""Phi3 family of models scaled rotary embedding.
Based on the original RotaryEmbedding implementation.
"""
def
__init__
(
self
,
head_size
:
int
,
rotary_dim
:
int
,
max_position_embeddings
:
int
,
original_max_position_embeddings
:
int
,
base
:
int
,
is_neox_style
:
bool
,
short_factor
:
List
[
float
],
long_factor
:
List
[
float
],
short_mscale
:
float
=
1.1
,
long_mscale
:
float
=
1.225
,
):
super
().
__init__
()
if
rotary_dim
!=
head_size
:
raise
ValueError
(
f
"`Phi3SuScaledRotaryEmbedding` does not support rotary_dim !=
\
head_size (
{
rotary_dim
}
!=
{
head_size
}
)."
)
if
is_neox_style
is
False
:
raise
ValueError
(
"`Phi3SuScaledRotaryEmbedding` only supports neox_style."
)
self
.
head_size
=
head_size
self
.
max_position_embeddings
=
max_position_embeddings
self
.
original_max_position_embeddings
=
original_max_position_embeddings
self
.
base
=
base
self
.
short_factor
=
short_factor
self
.
long_factor
=
long_factor
self
.
short_mscale
=
short_mscale
self
.
long_mscale
=
long_mscale
short_cache
=
self
.
_compute_cos_sin_cache
(
original_max_position_embeddings
,
short_factor
,
short_mscale
)
short_cache
=
short_cache
.
to
(
torch
.
get_default_dtype
())
self
.
register_buffer
(
"short_cos_sin_cache"
,
short_cache
,
persistent
=
False
)
long_cache
=
self
.
_compute_cos_sin_cache
(
max_position_embeddings
,
long_factor
,
long_mscale
)
long_cache
=
long_cache
.
to
(
torch
.
get_default_dtype
())
self
.
register_buffer
(
"long_cos_sin_cache"
,
long_cache
,
persistent
=
False
)
long_short_cache
=
torch
.
cat
(
[
self
.
short_cos_sin_cache
,
self
.
long_cos_sin_cache
],
dim
=
0
)
self
.
register_buffer
(
"long_short_cos_sin_cache"
,
long_short_cache
,
persistent
=
False
)
def
_compute_inv_freq
(
self
,
rescale_factors
:
List
[
float
])
->
torch
.
Tensor
:
rescale_factors
=
torch
.
tensor
(
rescale_factors
,
dtype
=
torch
.
float32
)
inv_freq
=
1.0
/
(
rescale_factors
*
(
self
.
base
**
(
torch
.
arange
(
0
,
self
.
head_size
,
2
,
dtype
=
torch
.
float
)
/
self
.
head_size
)))
return
inv_freq
def
_compute_cos_sin_cache
(
self
,
max_position_embeddings
:
int
,
rescale_factors
:
List
[
float
],
mscale
:
float
,
)
->
torch
.
Tensor
:
inv_freq
=
self
.
_compute_inv_freq
(
rescale_factors
)
t
=
torch
.
arange
(
max_position_embeddings
,
dtype
=
torch
.
float
)
freqs
=
torch
.
einsum
(
"i,j -> ij"
,
t
,
inv_freq
)
cos
=
freqs
.
cos
()
*
mscale
sin
=
freqs
.
sin
()
*
mscale
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
return
cache
def
forward
(
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
query
=
query
.
view
(
*
query
.
shape
[:
-
1
],
-
1
,
self
.
head_size
)
key
=
key
.
view
(
*
key
.
shape
[:
-
1
],
-
1
,
self
.
head_size
)
k
=
self
.
original_max_position_embeddings
long_prompt_offset
=
(
torch
.
any
(
positions
>
k
).
float
()
*
torch
.
full_like
(
positions
,
k
)).
long
()
idx
=
(
torch
.
add
(
positions
,
long_prompt_offset
)
if
long_prompt_offset
is
not
None
else
positions
)
self
.
long_short_cos_sin_cache
:
torch
.
Tensor
=
(
self
.
long_short_cos_sin_cache
.
to
(
idx
.
device
))
idx
=
torch
.
add
(
idx
,
offsets
)
if
offsets
is
not
None
else
idx
cos_sin
=
torch
.
index_select
(
self
.
long_short_cos_sin_cache
,
0
,
idx
)
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
cos
=
cos
.
repeat
(
1
,
2
).
unsqueeze
(
-
2
)
sin
=
sin
.
repeat
(
1
,
2
).
unsqueeze
(
-
2
)
query
=
query
*
cos
+
_rotate_neox
(
query
)
*
sin
key
=
key
*
cos
+
_rotate_neox
(
key
)
*
sin
return
query
.
flatten
(
-
2
),
key
.
flatten
(
-
2
)
_ROPE_DICT
:
Dict
[
Tuple
,
RotaryEmbedding
]
=
{}
...
...
@@ -349,17 +463,26 @@ def get_rope(
is_neox_style
:
bool
=
True
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
RotaryEmbedding
:
if
rope_scaling
is
not
None
:
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple
=
{
k
:
tuple
(
v
)
if
isinstance
(
v
,
list
)
else
v
for
k
,
v
in
rope_scaling
.
items
()
}
rope_scaling_args
=
tuple
(
rope_scaling_tuple
.
items
())
else
:
rope_scaling_args
=
None
key
=
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
tuple
(
rope_scaling
.
items
())
if
rope_scaling
is
not
None
else
None
)
rope_scaling
_args
)
if
key
in
_ROPE_DICT
:
return
_ROPE_DICT
[
key
]
if
rope_scaling
is
None
:
rotary_emb
=
RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
)
else
:
scaling_type
=
rope_scaling
[
"type"
]
scaling_factor
=
rope_scaling
[
"factor"
]
if
scaling_type
!=
"su"
:
scaling_factor
=
rope_scaling
[
"factor"
]
if
scaling_type
==
"linear"
:
rotary_emb
=
LinearScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
...
...
@@ -383,6 +506,19 @@ def get_rope(
base
,
is_neox_style
,
scaling_factor
,
**
extra_kwargs
)
elif
scaling_type
==
"su"
:
short_factor
=
rope_scaling
[
"short_factor"
]
long_factor
=
rope_scaling
[
"long_factor"
]
original_max_position
=
rope_scaling
[
"original_max_position_embeddings"
]
extra_kwargs
=
{
k
:
v
for
k
,
v
in
rope_scaling
.
items
()
if
k
in
(
"short_mscale"
,
"long_mscale"
)
}
rotary_emb
=
Phi3SuScaledRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
original_max_position
,
base
,
is_neox_style
,
short_factor
,
long_factor
,
**
extra_kwargs
)
else
:
raise
ValueError
(
f
"Unknown RoPE scaling type
{
scaling_type
}
"
)
_ROPE_DICT
[
key
]
=
rotary_emb
...
...
vllm/model_executor/layers/sampler.py
View file @
1591c68f
...
...
@@ -7,11 +7,14 @@ import torch.nn as nn
from
vllm.model_executor.layers.ops.sample
import
sample
as
sample_triton
from
vllm.model_executor.sampling_metadata
import
(
SamplingMetadata
,
SamplingTensors
)
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
SamplingTensors
,
SequenceGroupToSample
)
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
(
Logprob
,
PromptLogprobs
,
SampleLogprobs
,
SamplerOutput
,
SequenceData
,
SequenceGroupOutput
,
SequenceOutput
)
SamplerOutput
,
SequenceGroupOutput
,
SequenceOutput
)
# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType
=
List
[
Tuple
[
List
[
int
],
List
[
int
]]]
class
Sampler
(
nn
.
Module
):
...
...
@@ -48,11 +51,14 @@ class Sampler(nn.Module):
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
"""
Args:
logits: (num_tokens, vocab_size).
sampling_metadata: Metadata for sampling.
"""
assert
logits
is
not
None
_
,
vocab_size
=
logits
.
shape
# Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
# have not been generated yet
logits
=
_apply_min_tokens_penalty
(
logits
,
sampling_metadata
)
# Prepare sampling tensors with pinned memory to avoid blocking.
...
...
@@ -83,7 +89,6 @@ class Sampler(nn.Module):
# Compute the probabilities.
probs
=
torch
.
softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float
)
# Compute the log probabilities.
# Use log_softmax to ensure numerical stability.
logprobs
=
torch
.
log_softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float
)
# Sample the next tokens.
...
...
@@ -98,8 +103,7 @@ class Sampler(nn.Module):
if
self
.
include_gpu_probs_tensor
:
assert
maybe_sampled_tokens_tensor
is
not
None
sampled_tokens_tensor
=
maybe_sampled_tokens_tensor
on_device_tensors
=
(
probs
,
sampled_tokens_tensor
)
on_device_tensors
=
(
probs
,
logprobs
,
maybe_sampled_tokens_tensor
)
else
:
on_device_tensors
=
None
...
...
@@ -149,46 +153,46 @@ def _apply_min_tokens_penalty(
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
"""Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
have not been generated yet
"""
# list of indices in logits that will be set to -inf
logits_to_penalize
=
[]
start_idx
=
0
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
)
:
seq_ids
,
sampling_params
=
seq_group
# handle prompt_logprobs by skipping rows in logits added for the prompt
# tokens (prompt logprobs are not penalized)
if
(
i
<
sampling_metadata
.
num_prompts
and
sampling_params
.
prompt_logprobs
is
not
None
):
assert
len
(
seq_ids
)
==
1
start_idx
+=
sampling_metadata
.
prompt_lens
[
i
]
-
1
logits_to_penalize
:
List
[
Tuple
[
int
,
int
]]
=
[]
logits_applied
=
0
for
seq_group
in
sampling_metadata
.
seq_groups
:
seq_ids
=
seq_group
.
seq_ids
sampling_params
=
seq_group
.
sampling_params
sample_indices
=
seq_group
.
sample_indices
logits_applied
+=
len
(
sample_indices
)
+
len
(
seq_group
.
prompt_logprob_indices
)
if
not
seq_group
.
do_sample
:
continue
start_idx
=
sample_indices
[
0
]
min_tokens
=
sampling_params
.
min_tokens
if
min_tokens
>
0
:
token_ids_to_penalize
=
sampling_params
.
all_stop_token_ids
if
min_tokens
>
0
and
token_ids_to_penalize
:
seqs_to_penalize
=
[]
for
i
,
seq_id
in
enumerate
(
seq_ids
):
seq_data
=
s
ampling_metadata
.
seq_data
[
seq_id
]
for
j
,
seq_id
in
enumerate
(
seq_ids
):
seq_data
=
s
eq_group
.
seq_data
[
seq_id
]
if
len
(
seq_data
.
output_token_ids
)
<
min_tokens
:
seqs_to_penalize
.
append
(
i
)
seqs_to_penalize
.
append
(
j
)
if
seqs_to_penalize
:
# convert to the index into logits
seqs_to_penalize
=
[
start_idx
+
i
for
i
in
seqs_to_penalize
]
# use set() to remove any duplicates
token_ids_to_penalize
=
set
(
sampling_params
.
stop_token_ids
+
[
sampling_params
.
eos_token_id
])
seqs_to_penalize
=
[
start_idx
+
j
for
j
in
seqs_to_penalize
]
# itertools.product pairs each seq index with every token id
logits_to_penalize
.
extend
(
itertools
.
product
(
seqs_to_penalize
,
token_ids_to_penalize
))
start_idx
+=
len
(
seq_ids
)
if
logits_to_penalize
:
# use zip and * to group indices along each dimension
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
logits
[
tuple
(
zip
(
*
logits_to_penalize
))]
=
-
float
(
"inf"
)
# verifies that no rows in logits were missed unexpectedly
assert
start_idx
==
logits
.
shape
[
0
]
assert
logits_applied
==
logits
.
shape
[
0
]
return
logits
...
...
@@ -265,14 +269,30 @@ def _apply_min_p(
def
_greedy_sample
(
selected_seq_groups
:
List
[
Tuple
[
List
[
int
],
SamplingParams
]
],
selected_seq_groups
:
List
[
SequenceGroupToSample
],
samples
:
torch
.
Tensor
,
)
->
List
[
Tuple
[
List
[
int
],
List
[
int
]]]:
)
->
SampleResultType
:
"""Run greedy sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
samples: (num_selected_samples,) A tensor of samples. The length of
samples could be smaller than selected_seq_groups if
seq_group.do_sample is False.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
samples
=
samples
.
tolist
()
sample_idx
=
0
results
=
[]
results
:
SampleResultType
=
[]
for
seq_group
in
selected_seq_groups
:
seq_ids
,
_
=
seq_group
if
not
seq_group
.
do_sample
:
results
.
append
(([],
[]))
continue
seq_ids
=
seq_group
.
seq_ids
num_parent_seqs
=
len
(
seq_ids
)
assert
num_parent_seqs
==
1
,
(
"Greedy sampling should have only one seq."
)
...
...
@@ -284,16 +304,33 @@ def _greedy_sample(
def
_random_sample
(
selected_seq_groups
:
List
[
Tuple
[
List
[
int
],
SamplingParams
]],
is_prompts
:
List
[
bool
],
selected_seq_groups
:
List
[
SequenceGroupToSample
],
random_samples
:
torch
.
Tensor
,
)
->
List
[
Tuple
[
List
[
int
],
List
[
int
]]]:
)
->
SampleResultType
:
"""Run random sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
random_samples: (num_selected_samples,) A tensor of samples. The
length of samples could be smaller than selected_seq_groups if
seq_group.do_sample is False.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
# Find the maximum best_of value of the prompt phase requests.
random_samples
=
random_samples
.
cpu
()
sample_idx
=
0
results
=
[]
for
seq_group
,
is_prompt
in
zip
(
selected_seq_groups
,
is_prompts
):
seq_ids
,
sampling_params
=
seq_group
results
:
SampleResultType
=
[]
for
seq_group
in
selected_seq_groups
:
if
not
seq_group
.
do_sample
:
results
.
append
(([],
[]))
continue
seq_ids
=
seq_group
.
seq_ids
sampling_params
=
seq_group
.
sampling_params
is_prompt
=
seq_group
.
is_prompt
num_parent_seqs
=
len
(
seq_ids
)
if
is_prompt
:
# Prompt phase.
...
...
@@ -311,11 +348,20 @@ def _random_sample(
def
_beam_search_sample
(
selected_seq_groups
:
List
[
Tuple
[
List
[
int
],
SamplingParams
]],
is_prompts
:
List
[
bool
],
seq_data
:
Dict
[
int
,
SequenceData
],
selected_seq_groups
:
List
[
SequenceGroupToSample
],
logprobs
:
torch
.
Tensor
,
)
->
List
[
Tuple
[
List
[
int
],
List
[
int
]]]:
)
->
SampleResultType
:
"""Run beam sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
on selected sample indices.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
# We sample 2 * beam_width candidates to make sure that with high
# probability we can get `beam_width` candidates in addition to
# the finished sequences for the next iteration. See
...
...
@@ -326,9 +372,14 @@ def _beam_search_sample(
# NOTE: Beam search is not vectorized, so its speed can be slower than
# other sampling methods.
sample_idx
=
0
results
=
[]
for
seq_group
,
is_prompt
in
zip
(
selected_seq_groups
,
is_prompts
):
seq_ids
,
sampling_params
=
seq_group
results
:
SampleResultType
=
[]
for
seq_group
in
selected_seq_groups
:
if
not
seq_group
.
do_sample
:
results
.
append
(([],
[]))
continue
is_prompt
=
seq_group
.
is_prompt
seq_ids
,
sampling_params
=
seq_group
.
seq_ids
,
seq_group
.
sampling_params
num_parent_seqs
=
len
(
seq_ids
)
beam_width
=
sampling_params
.
best_of
seq_group_logprobs
=
logprobs
[
sample_idx
:
sample_idx
+
num_parent_seqs
]
...
...
@@ -342,15 +393,16 @@ def _beam_search_sample(
next_token_ids
=
next_token_ids
.
tolist
()
else
:
# Generation phase.
cumulative_logprobs
=
[
seq_data
[
seq_id
].
cumulative_logprob
for
seq_id
in
seq_ids
cumulative_logprobs
:
List
[
int
]
=
[
seq_group
.
seq_data
[
seq_id
].
cumulative_logprob
for
seq_id
in
seq_ids
]
cumulative_logprobs
=
torch
.
tensor
(
cumulative_logprobs
_tensor
=
torch
.
tensor
(
cumulative_logprobs
,
dtype
=
torch
.
float
,
device
=
seq_group_logprobs
.
device
)
seq_group_logprobs
=
(
seq_group_logprobs
+
cumulative_logprobs
.
unsqueeze
(
dim
=
1
))
cumulative_logprobs
_tensor
.
unsqueeze
(
dim
=
1
))
_
,
topk_ids
=
torch
.
topk
(
seq_group_logprobs
.
flatten
(),
2
*
beam_width
)
topk_ids
=
topk_ids
.
tolist
()
...
...
@@ -371,8 +423,7 @@ def _beam_search_sample(
def
_multinomial
(
probs
:
torch
.
Tensor
,
num_samples
:
int
,
seq_groups
:
Optional
[
List
[
Tuple
[
List
[
int
],
SamplingParams
]]]
=
None
,
generators
:
Optional
[
List
[
torch
.
Generator
]]
=
None
,
seq_groups
:
Optional
[
List
[
SequenceGroupToSample
]]
=
None
,
)
->
torch
.
Tensor
:
if
num_samples
>
1
:
# This is equivalent to torch.repeat_interleaved (which also
...
...
@@ -388,9 +439,11 @@ def _multinomial(
q
.
exponential_
()
else
:
sample_idx
=
0
for
(
seq_ids
,
_
),
generator
in
zip
(
seq_groups
,
generators
):
for
seq_group
in
seq_groups
:
seq_ids
=
seq_group
.
seq_ids
next_sample_idx
=
sample_idx
+
len
(
seq_ids
)
*
num_samples
q
[
sample_idx
:
next_sample_idx
].
exponential_
(
generator
=
generator
)
q
[
sample_idx
:
next_sample_idx
].
exponential_
(
generator
=
seq_group
.
generator
)
sample_idx
=
next_sample_idx
return
probs
.
div_
(
q
).
argmax
(
dim
=
1
).
view
(
-
1
,
num_samples
)
...
...
@@ -401,11 +454,13 @@ def _sample_with_torch(
sampling_metadata
:
SamplingMetadata
,
include_gpu_probs_tensor
:
bool
,
modify_greedy_probs
:
bool
,
)
->
Tuple
[
List
[
Tuple
[
List
[
int
],
List
[
int
]]],
Optional
[
torch
.
Tensor
]]:
categorized_seq_group_ids
=
{
t
:
[]
for
t
in
SamplingType
}
)
->
Tuple
[
SampleResultType
,
Optional
[
torch
.
Tensor
]]:
categorized_seq_group_ids
:
Dict
[
SamplingType
,
List
[
int
]]
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
_
,
sampling_params
=
seq_group
sampling_params
=
seq_group
.
sampling_params
sampling_type
=
sampling_params
.
sampling_type
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
...
...
@@ -429,13 +484,11 @@ def _sample_with_torch(
num_tokens
=
len
(
sample_indices
)
if
num_tokens
==
0
:
continue
seq_group_ids
=
categorized_seq_group_ids
[
sampling_type
]
seq_groups
=
[
sampling_metadata
.
seq_groups
[
i
]
for
i
in
seq_group_ids
]
is_prompts
=
[
i
<
sampling_metadata
.
num_prompts
for
i
in
seq_group_ids
]
sample_metadata
[
sampling_type
]
=
(
seq_group_ids
,
seq_groups
,
is_prompts
,
sample_indices
)
long_sample_indices
=
sample_indices
.
long
()
seq_group_id
=
categorized_seq_group_ids
[
sampling_type
]
seq_groups
=
[
sampling_metadata
.
seq_groups
[
i
]
for
i
in
seq_group_id
]
sample_metadata
[
sampling_type
]
=
(
seq_group_id
,
seq_groups
)
long_sample_indices
=
sample_indices
.
long
()
if
sampling_type
==
SamplingType
.
GREEDY
:
greedy_samples
=
torch
.
argmax
(
logprobs
[
long_sample_indices
],
dim
=-
1
)
...
...
@@ -455,14 +508,13 @@ def _sample_with_torch(
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
max_best_of_in_batch
=
1
for
seq_group
,
is_prompt
in
zip
(
seq_groups
,
is_prompts
)
:
if
is_prompt
:
_
,
sampling_params
=
seq_group
for
seq_group
in
seq_groups
:
if
seq_group
.
is_prompt
:
sampling_params
=
seq_group
.
sampling_params
max_best_of_in_batch
=
max
(
max_best_of_in_batch
,
sampling_params
.
best_of
)
seeded_args
=
{}
if
sampling_type
==
SamplingType
.
RANDOM
else
{
"seq_groups"
:
seq_groups
,
"generators"
:
sampling_metadata
.
generators
,
}
multinomial_samples
[
sampling_type
]
=
_multinomial
(
...
...
@@ -481,25 +533,22 @@ def _sample_with_torch(
# GPU<->CPU sync happens in the loop below.
# This also converts the sample output to Python objects.
for
sampling_type
in
SamplingType
:
if
sampling_type
not
in
sample_metadata
:
continue
seq_group_ids
,
seq_groups
,
is_prompts
,
sample_indices
=
sample_metadata
[
sampling_type
]
(
seq_group_id
,
seq_groups
)
=
sample_metadata
[
sampling_type
]
if
sampling_type
==
SamplingType
.
GREEDY
:
sample_results
=
_greedy_sample
(
seq_groups
,
greedy_samples
)
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
sample_results
=
_random_sample
(
seq_groups
,
is_prompts
,
sample_results
=
_random_sample
(
seq_groups
,
multinomial_samples
[
sampling_type
])
elif
sampling_type
==
SamplingType
.
BEAM
:
sample_results
=
_beam_search_sample
(
seq_groups
,
is_prompts
,
sampling_metadata
.
seq_data
,
sample_results
=
_beam_search_sample
(
seq_groups
,
beam_search_logprobs
)
sample_results_dict
.
update
(
zip
(
seq_group_id
s
,
sample_results
))
sample_results_dict
.
update
(
zip
(
seq_group_id
,
sample_results
))
sample_results
=
[
sample_results_dict
[
i
]
sample_results_dict
.
get
(
i
,
([],
[]))
for
i
in
range
(
len
(
sampling_metadata
.
seq_groups
))
]
return
sample_results
,
sampled_token_ids_tensor
...
...
@@ -510,11 +559,13 @@ def _sample_with_triton_kernel(
logprobs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_tensors
:
SamplingTensors
,
)
->
List
[
Tuple
[
List
[
int
],
List
[
int
]]]:
categorized_seq_group_ids
=
{
t
:
[]
for
t
in
SamplingType
}
)
->
SampleResultType
:
categorized_seq_group_ids
:
Dict
[
SamplingType
,
List
[
int
]]
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
_
,
sampling_params
=
seq_group
sampling_params
=
seq_group
.
sampling_params
sampling_type
=
sampling_params
.
sampling_type
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
...
...
@@ -530,17 +581,16 @@ def _sample_with_triton_kernel(
num_tokens
=
len
(
sample_indices
)
if
num_tokens
==
0
:
continue
seq_group_ids
=
categorized_seq_group_ids
[
sampling_type
]
seq_groups
=
[
sampling_metadata
.
seq_groups
[
i
]
for
i
in
seq_group_ids
]
is_prompts
=
[
i
<
sampling_metadata
.
num_prompts
for
i
in
seq_group_ids
]
sample_metadata
[
sampling_type
]
=
(
seq_group_ids
,
seq_groups
,
is_prompts
,
sample_indices
,
seq_group_id
=
categorized_seq_group_ids
[
sampling_type
]
seq_groups
=
[
sampling_metadata
.
seq_groups
[
i
]
for
i
in
seq_group_id
]
sample_metadata
[
sampling_type
]
=
(
seq_group_id
,
seq_groups
,
sample_indices
,
sampled_token_indices
)
if
sampling_type
in
(
SamplingType
.
GREEDY
,
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
for
seq_group
,
is_prompt
in
zip
(
seq_groups
,
is_prompts
)
:
if
is_prompt
:
_
,
sampling_params
=
seq_group
for
seq_group
in
seq_groups
:
if
seq_group
.
is_prompt
:
sampling_params
=
seq_group
.
sampling_params
max_best_of_in_batch
=
max
(
max_best_of_in_batch
,
sampling_params
.
best_of
)
elif
sampling_type
==
SamplingType
.
BEAM
:
...
...
@@ -564,22 +614,21 @@ def _sample_with_triton_kernel(
for
sampling_type
in
SamplingType
:
if
sampling_type
not
in
sample_metadata
:
continue
(
seq_group_id
s
,
seq_groups
,
is_prompts
,
sample_indices
,
(
seq_group_id
,
seq_groups
,
sample_indices
,
sampled_token_indices
)
=
sample_metadata
[
sampling_type
]
if
sampling_type
==
SamplingType
.
GREEDY
:
sample_results
=
_greedy_sample
(
seq_groups
,
sampled_tokens
[
sampled_token_indices
][:,
0
])
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
sample_results
=
_random_sample
(
seq_groups
,
is_prompts
,
sampled_tokens
[
sampled_token_indices
])
seq_groups
,
sampled_tokens
[
sampled_token_indices
])
elif
sampling_type
==
SamplingType
.
BEAM
:
sample_results
=
_beam_search_sample
(
seq_groups
,
is_prompts
,
sampling_metadata
.
seq_data
,
sample_results
=
_beam_search_sample
(
seq_groups
,
beam_search_logprobs
)
sample_results_dict
.
update
(
zip
(
seq_group_id
s
,
sample_results
))
sample_results_dict
.
update
(
zip
(
seq_group_id
,
sample_results
))
sample_results
=
[
sample_results_dict
[
i
]
sample_results_dict
.
get
(
i
,
([],
[]))
for
i
in
range
(
len
(
sampling_metadata
.
seq_groups
))
]
return
sample_results
...
...
@@ -589,7 +638,19 @@ def _sample(
probs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_tensors
:
SamplingTensors
,
include_gpu_probs_tensor
:
bool
,
modify_greedy_probs
:
bool
)
->
Tuple
[
List
[
Tuple
[
List
[
int
],
List
[
int
]]],
Optional
[
torch
.
Tensor
]]:
)
->
Tuple
[
SampleResultType
,
Optional
[
torch
.
Tensor
]]:
"""
Args:
probs: (num_query_tokens_in_batch, num_vocab)
logprobs: (num_query_tokens_in_batch, num_vocab)
sampling_metadata: The metadata for a batch for sampling.
sampling_tensors: Tensors that include sampling related metadata.
Returns:
(next_token_ids, parent_seq_ids) for each seq group in a batch.
If sampling is skipped, it returns ([], [])
sampled_token_ids_tensor: A tensor of sampled token ids.
"""
return
_sample_with_torch
(
probs
,
logprobs
,
...
...
@@ -625,57 +686,98 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
def
_get_logprobs
(
logprobs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sample_results
:
List
[
Tuple
[
List
[
int
],
List
[
int
]]],
)
->
Tuple
[
List
[
Optional
[
List
[
Optional
[
Dict
[
int
,
float
]]]]],
List
[
List
[
Dict
[
int
,
float
]]]]:
# Prepare query indices
batched_logprobs_query_seq_indices
:
List
[
int
]
=
[]
batched_logprobs_query_token_indices
:
List
[
int
]
=
[]
# at least get one logprob for each token
sample_results
:
SampleResultType
,
)
->
Tuple
[
List
[
Optional
[
PromptLogprobs
]],
List
[
SampleLogprobs
]]:
"""Return sample lobprobs and prompt logprobs.
The logic consists of 3 parts.
- Select indices to compute logprob from, ranks of token ids, and
the top k token ids from logprobs.
- Compute prompt logprobs if required.
- Compute sample logprobs if required.
Args:
logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
logprob per vocab. Sequence groups' query tokens are batched in a
single flattened tensor. For example, assuming there are N
seq groups, it is sorted by prefill tokens for seq_group_1 (if
prompt logprob is enabled), decode tokens for seq_group_1 (if
sampling is required), prefill tokens for seq_group_2, ...
sampling_metadata: The sampling metadata.
sample_results: (num_seq_groups) The tuple of (next_token_ids,
parent_ids) for each sequence group. When beam search is enabled,
sample_results can contain different number of seq_ids from
sampling_metadata.seq_groups. It is because beam search creates
2 * BEAM_WIDTH number of samples (whereas there are only up to
BEAM_WIDTH number of seq_ids).
Returns:
A tuple of prompt and sample logprobs per sequence group in a batch.
"""
# The index of query token to calculate logprobs. It includes both
# prompt and sample logprob indices.
query_indices
:
List
[
int
]
=
[]
# The next token ids to get the logprob value from.
next_token_ids
:
List
[
int
]
=
[]
# The largest requested number of logprobs. We find logprobs as many as the
# largest num logprobs in this API.
largest_num_logprobs
=
1
sample_idx
=
0
for
i
,
(
seq_group
,
sample_result
)
in
enumerate
(
zip
(
sampling_metadata
.
seq_groups
,
sample_results
)):
seq_ids
,
sampling_params
=
seq_group
next_token_ids
,
parent_ids
=
sample_result
num_parent_seqs
=
len
(
seq_ids
)
if
(
i
<
sampling_metadata
.
num_prompts
# Select indices to compute logprob from, ranks of token ids, and the top
# k token ids from logprobs.
for
(
seq_group
,
sample_result
)
in
zip
(
sampling_metadata
.
seq_groups
,
sample_results
):
sampling_params
=
seq_group
.
sampling_params
# Update indices and tokens for prompt logprobs.
if
(
seq_group
.
is_prompt
and
sampling_params
.
prompt_logprobs
is
not
None
):
largest_num_logprobs
=
max
(
largest_num_logprobs
,
sampling_params
.
prompt_logprobs
)
prompt_len
=
sampling_metadata
.
prompt_lens
[
i
]
prompt_tokens
=
sampling_metadata
.
seq_data
[
seq_ids
[
0
]].
prompt_token_ids
batched_logprobs_query_seq_indices
.
extend
(
sample_idx
+
j
for
j
in
range
(
prompt_len
-
1
))
batched_logprobs_query_token_indices
.
extend
(
token_id
for
token_id
in
prompt_tokens
[
1
:])
sample_idx
+=
prompt_len
-
1
batched_logprobs_query_seq_indices
.
extend
(
[
sample_idx
+
parent_id
for
parent_id
in
parent_ids
])
batched_logprobs_query_token_indices
.
extend
(
next_token_ids
)
if
sampling_params
.
logprobs
is
not
None
:
largest_num_logprobs
=
max
(
largest_num_logprobs
,
sampling_params
.
logprobs
)
sample_idx
+=
num_parent_seqs
assert
sample_idx
==
logprobs
.
size
(
0
)
batched_logprobs_query_seq_indices_gpu
=
torch
.
tensor
(
batched_logprobs_query_seq_indices
,
device
=
logprobs
.
device
)
batched_logprobs_query_token_indices_gpu
=
torch
.
tensor
(
batched_logprobs_query_token_indices
,
device
=
logprobs
.
device
)
# Batched query for logprobs of selected token
batched_logprobs_query_result
=
logprobs
[[
batched_logprobs_query_seq_indices_gpu
,
batched_logprobs_query_token_indices_gpu
next_prompt_tokens
=
_get_next_prompt_tokens
(
seq_group
)
query_indices
.
extend
(
seq_group
.
prompt_logprob_indices
)
next_token_ids
.
extend
(
next_prompt_tokens
)
# Update indices and next tokenes for sample logprob.
if
seq_group
.
do_sample
:
token_ids
,
parent_seq_ids
=
sample_result
# NOTE: We cannot directly use sample_indices because
# sample_indices only contain parent seq_ids of a previous step.
# The current step may have different number of seq_ids, and
# we can obtain it from `sample_result[1]`.
query_idx
=
seq_group
.
sample_indices
[
0
]
query_indices
.
extend
(
[
query_idx
+
parent_id
for
parent_id
in
parent_seq_ids
])
next_token_ids
.
extend
(
token_ids
)
if
sampling_params
.
logprobs
is
not
None
:
largest_num_logprobs
=
max
(
largest_num_logprobs
,
sampling_params
.
logprobs
)
assert
len
(
next_token_ids
)
==
len
(
query_indices
)
if
len
(
query_indices
)
==
0
:
empty_sampled_logprob
:
SampleLogprobs
=
[]
empty_prompt_logprob
:
Optional
[
PromptLogprobs
]
=
None
return
[
empty_prompt_logprob
],
[
empty_sampled_logprob
]
query_indices_gpu
=
torch
.
tensor
(
query_indices
,
device
=
logprobs
.
device
)
next_token_ids_gpu
=
torch
.
tensor
(
next_token_ids
,
device
=
logprobs
.
device
)
# (num_selected_query_tokens, num_logprobs). Note that query_indices can
# contain duplicates if beam search is enabled.
selected_logprobs
=
logprobs
[[
query_indices_gpu
,
next_token_ids_gpu
,
]]
ranks
=
_get_ranks
(
logprobs
[
query_indices_gpu
],
next_token_ids_gpu
,
)
assert
selected_logprobs
.
shape
[
0
]
==
ranks
.
shape
[
0
]
batched_ranks_query_result
=
_get_ranks
(
logprobs
[
batched_logprobs_query_seq_indices_gpu
],
batched_logprobs_query_token_indices_gpu
)
# Batched query for logprobs of topk tokens
# Logprobs of topk tokens for a batch of sequence groups.
# (num_query_tokens_across_batch).
if
largest_num_logprobs
>
0
:
top_logprobs
,
top_token_ids
=
torch
.
topk
(
logprobs
,
largest_num_logprobs
,
...
...
@@ -685,79 +787,136 @@ def _get_logprobs(
else
:
top_logprobs
,
top_token_ids
=
None
,
None
batched_logprobs_query_result
=
batched_logprobs_query_result
.
cpu
()
batched_ranks_query_result
=
batched_ranks_query_result
.
cpu
()
# Gather results
result_prompt_logprobs
:
List
[
Optional
[
PromptLogprobs
]]
=
[]
result_sample_logprobs
:
List
[
SampleLogprobs
]
=
[]
sample_idx
=
0
query_result_idx
=
0
for
i
,
(
seq_group
,
sample_result
)
in
enumerate
(
zip
(
sampling_metadata
.
seq_groups
,
sample_results
)):
seq_ids
,
sampling_params
=
seq_group
next_token_ids
,
parent_ids
=
sample_result
selected_logprobs
=
selected_logprobs
.
cpu
()
ranks
=
ranks
.
cpu
()
# Find prompt/sample logprobs.
prompt_logprobs_per_seq_group
:
List
[
Optional
[
PromptLogprobs
]]
=
[]
sample_logprobs_per_seq_group
:
List
[
SampleLogprobs
]
=
[]
top_logprob_idx
=
0
selected_logprobs_idx
=
0
for
seq_group
,
sample_result
in
zip
(
sampling_metadata
.
seq_groups
,
sample_results
):
(
prompt_logprobs
,
top_logprob_idx
,
selected_logprobs_idx
)
=
_get_prompt_logprob_if_needed
(
seq_group
,
selected_logprobs
,
ranks
,
top_token_ids
,
top_logprobs
,
selected_logprobs_idx
,
top_logprob_idx
)
prompt_logprobs_per_seq_group
.
append
(
prompt_logprobs
)
(
sampled_logprobs
,
top_logprob_idx
,
selected_logprobs_idx
)
=
_get_sampled_logprob_if_needed
(
seq_group
,
sample_result
,
selected_logprobs
,
ranks
,
top_token_ids
,
top_logprobs
,
selected_logprobs_idx
,
top_logprob_idx
)
sample_logprobs_per_seq_group
.
append
(
sampled_logprobs
)
return
prompt_logprobs_per_seq_group
,
sample_logprobs_per_seq_group
def
_get_prompt_logprob_if_needed
(
seq_group
:
SequenceGroupToSample
,
selected_logprobs
:
torch
.
Tensor
,
ranks
:
torch
.
Tensor
,
top_token_ids
:
torch
.
Tensor
,
top_logprobs
:
torch
.
Tensor
,
selected_logprobs_idx
:
int
,
top_logprob_idx
:
int
,
):
"""Compute the prompt logprob from a sequence group if needed."""
sampling_params
=
seq_group
.
sampling_params
is_prompt
=
seq_group
.
is_prompt
# Find prompt logprobs
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
if
(
is_prompt
and
sampling_params
.
prompt_logprobs
is
not
None
):
prompt_logprobs
=
[]
num_logprobs
=
sampling_params
.
prompt_logprobs
next_prompt_tokens
=
_get_next_prompt_tokens
(
seq_group
)
for
token_id
in
next_prompt_tokens
:
# Calculate the prompt logprob of the real prompt tokens.
# Use tuple here for performance (to use to_list()).
# {token_id: (logprob, rank_from_vocab)}
prompt_logprobs_dict
:
Dict
[
int
,
Tuple
[
float
,
int
]]
=
{
token_id
:
(
selected_logprobs
[
selected_logprobs_idx
].
item
(),
ranks
[
selected_logprobs_idx
].
item
())
}
# Prompt logprobs
if
(
i
<
sampling_metadata
.
num_prompts
and
sampling_params
.
prompt_logprobs
is
not
None
):
num_logprobs
=
sampling_params
.
prompt_logprobs
prompt_tokens
=
sampling_metadata
.
seq_data
[
seq_ids
[
0
]].
prompt_token_ids
group_prompt_logprobs
:
PromptLogprobs
=
[
None
]
for
token_id
in
prompt_tokens
[
1
:]:
prompt_logprobs_dict
=
{
token_id
:
(
batched_logprobs_query_result
[
query_result_idx
].
item
(),
batched_ranks_query_result
[
query_result_idx
].
item
())
}
if
num_logprobs
>
0
:
prompt_logprobs_dict
.
update
(
# Add top K prompt logprobs along with its rank.
if
num_logprobs
>
0
:
prompt_logprobs_dict
.
update
(
zip
(
top_token_ids
[
top_logprob_idx
,
:
num_logprobs
].
tolist
(),
zip
(
top_token_ids
[
sample_idx
,
:
num_logprobs
].
tolist
(),
zip
(
top_logprobs
[
sample_idx
,
:
num_logprobs
].
tolist
(),
range
(
1
,
num_logprobs
+
1
))))
group_prompt_logprobs
.
append
({
token_id
:
Logprob
(
*
logprob_rank
)
for
token_id
,
logprob_rank
in
prompt_logprobs_dict
.
items
()
})
sample_idx
+=
1
query_result_idx
+=
1
result_prompt_logprobs
.
append
(
group_prompt_logprobs
)
else
:
result_prompt_logprobs
.
append
(
None
)
# Sample logprobs
num_logprobs
=
sampling_params
.
logprobs
if
num_logprobs
is
None
:
num_logprobs
=
0
group_sample_logprobs
:
SampleLogprobs
=
[]
for
next_token_id
,
parent_id
in
zip
(
next_token_ids
,
parent_ids
):
sample_logprobs_dict
=
{
top_logprobs
[
top_logprob_idx
,
:
num_logprobs
].
tolist
(),
# This is ranks. Since top_logprob is sorted,
# we can just use a range here.
range
(
1
,
num_logprobs
+
1
))))
prompt_logprobs
.
append
({
token_id
:
Logprob
(
*
logprob_and_rank
)
for
token_id
,
logprob_and_rank
in
prompt_logprobs_dict
.
items
()
})
# + 1 to go to the next prompt token.
top_logprob_idx
+=
1
selected_logprobs_idx
+=
1
return
prompt_logprobs
,
top_logprob_idx
,
selected_logprobs_idx
def
_get_sampled_logprob_if_needed
(
seq_group
:
SequenceGroupToSample
,
sample_result
:
Tuple
[
List
[
int
],
List
[
int
]],
selected_logprobs
:
torch
.
Tensor
,
ranks
:
torch
.
Tensor
,
top_token_ids
:
torch
.
Tensor
,
top_logprobs
:
torch
.
Tensor
,
selected_logprobs_idx
:
int
,
top_logprob_idx
:
int
,
):
"""Compute the sample logprob if needed."""
seq_ids
=
seq_group
.
seq_ids
num_logprobs
=
seq_group
.
sampling_params
.
logprobs
if
num_logprobs
is
None
:
num_logprobs
=
0
sampled_logprobs
:
SampleLogprobs
=
[]
next_token_ids
,
parent_seq_ids
=
sample_result
if
seq_group
.
do_sample
:
assert
len
(
next_token_ids
)
>
0
for
(
next_token_id
,
parent_id
)
in
zip
(
next_token_ids
,
parent_seq_ids
):
# Calculate the sample logprob of the real sampled tokens.
# Use tuple here for performance (to use to_list()).
# token_id: (logprob, rank_from_vocab)
sampled_logprobs_dict
:
Dict
[
int
,
Tuple
[
float
,
int
]]
=
{
next_token_id
:
(
batch
ed_logprobs
_query_result
[
query_result
_idx
].
item
(),
batched_ranks_query_result
[
query_result
_idx
].
item
())
(
select
ed_logprobs
[
selected_logprobs
_idx
].
item
(),
ranks
[
selected_logprobs
_idx
].
item
())
}
query_result_idx
+=
1
# +1 to go to the next sampled token. Note that
# selected_logprobs can contain duplicates unlike top_logprobs
# when beam search is enabled.
selected_logprobs_idx
+=
1
# Second, add top K logprobs along with its rank.
if
num_logprobs
>=
0
:
sample_logprobs_dict
.
update
(
sample
d
_logprobs_dict
.
update
(
zip
(
top_token_ids
[
sample
_idx
+
top_token_ids
[
top_logprob
_idx
+
parent_id
,
:
num_logprobs
].
tolist
(),
zip
(
top_logprobs
[
sample
_idx
+
top_logprobs
[
top_logprob
_idx
+
parent_id
,
:
num_logprobs
].
tolist
(),
# This is rank. Since top_logprob is sorted, we
# can just use a range here.
range
(
1
,
num_logprobs
+
1
))))
group_sample_logprobs
.
append
({
token_id
:
Logprob
(
*
logprob_rank
)
for
token_id
,
logprob_rank
in
sample_logprobs_dict
.
items
()
sampled_logprobs
.
append
({
token_id
:
Logprob
(
*
logprob_and_rank
)
for
token_id
,
logprob_and_rank
in
sampled_logprobs_dict
.
items
()
})
result_sample_logprobs
.
append
(
group_sample_logprobs
)
sample_idx
+=
len
(
seq_ids
)
return
result_prompt_logprobs
,
result_sample
_logprobs
# There are len(seq_ids) number of sampled tokens for the current
# sequence group in top_logprobs. Jump to the next seq_group.
top_logprob_idx
+=
len
(
seq_ids
)
return
sampled_logprobs
,
top_logprob_idx
,
selected
_logprobs
_idx
def
_modify_greedy_probs_inplace
(
logprobs
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
...
...
@@ -805,18 +964,18 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
has implications on the overall design of the sampler, e.g. how to record
accurate logprobs for the user, so this improvement is deferred to later.
"""
logprobs
[
sample_indices
,
:]
=
-
float
(
'inf'
)
logprobs
[
sample_indices
,
greedy_samples
]
=
0.0
# NOTE: logprobs are not modified so they can be returned to the user.
probs
[
sample_indices
,
:]
=
0
probs
[
sample_indices
,
greedy_samples
]
=
1.0
def
_build_sampler_output
(
sample_results
:
List
[
Tuple
[
List
[
int
],
List
[
int
]]]
,
sample_results
:
SampleResultType
,
sampling_metadata
:
SamplingMetadata
,
prompt_logprobs
:
List
[
Optional
[
PromptLogprobs
]],
sample_logprobs
:
List
[
SampleLogprobs
],
on_device_tensors
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
on_device_tensors
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]],
)
->
SamplerOutput
:
"""Construct Python objects with the output of sampling.
...
...
@@ -832,7 +991,7 @@ def _build_sampler_output(
group_sample_logprobs
)
in
zip
(
sampling_metadata
.
seq_groups
,
sample_results
,
prompt_logprobs
,
sample_logprobs
):
seq_ids
,
_
=
seq_group
seq_ids
=
seq_group
.
seq_ids
next_token_ids
,
parent_ids
=
sample_result
seq_outputs
=
[]
for
parent_id
,
next_token_id
,
logprobs
in
zip
(
parent_ids
,
...
...
@@ -845,12 +1004,48 @@ def _build_sampler_output(
# If not specified, store None values in SamplerOutput.
if
on_device_tensors
is
not
None
:
sampled_token_probs
,
sampled_token_ids
=
on_device_tensors
(
sampled_token_probs
,
logprobs_tensor
,
sampled_token_ids
)
=
on_device_tensors
else
:
sampled_token_probs
,
sampled_token_ids
=
(
None
,
None
)
sampled_token_probs
,
logprobs_tensor
,
sampled_token_ids
=
(
None
,
None
,
None
)
return
SamplerOutput
(
outputs
=
sampler_output
,
sampled_token_probs
=
sampled_token_probs
,
sampled_token_ids
=
sampled_token_ids
,
logprobs
=
logprobs_tensor
,
)
def
_get_next_prompt_tokens
(
seq_group
:
SequenceGroupToSample
)
->
List
[
int
]:
"""Get a list of next prompt tokens to compute logprob from a
given sequence group.
It is used to compute prompt logprob. Imagine you have logprob for each
query token. Query token needs to know the next prompt token id to compute
prompt logprob. This is a helper to obtain next prompt token ids.
This API has to be used only when the caller knows seq_group is in prefill
stage.
Returns:
A list of next prompt tokens to compute logprob.
"""
assert
seq_group
.
is_prompt
,
(
"Caller should ensure the sequence group is in a prefill stage."
)
seq_ids
=
seq_group
.
seq_ids
query_len
=
seq_group
.
query_len
assert
query_len
is
not
None
# prompt has only 1 seq id.
assert
len
(
seq_ids
)
==
1
seq_data
=
seq_group
.
seq_data
[
seq_ids
[
0
]]
computed_len
=
seq_data
.
get_num_computed_tokens
()
prompt_tokens
=
seq_data
.
prompt_token_ids
# +1 because we are looking for a next prompt token.
next_token_index_start
=
computed_len
+
1
next_token_index_end
=
min
(
computed_len
+
query_len
+
1
,
len
(
prompt_tokens
))
next_prompt_tokens
=
prompt_tokens
[
next_token_index_start
:
next_token_index_end
]
return
next_prompt_tokens
vllm/model_executor/layers/vocab_parallel_embedding.py
View file @
1591c68f
...
...
@@ -105,6 +105,14 @@ class VocabParallelEmbedding(torch.nn.Module):
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
return
output
def
extra_repr
(
self
)
->
str
:
s
=
f
"num_embeddings=
{
self
.
num_embeddings_per_partition
}
"
s
+=
f
", embedding_dim=
{
self
.
embedding_dim
}
"
s
+=
f
", org_vocab_size=
{
self
.
org_vocab_size
}
"
s
+=
f
', num_embeddings_padded=
{
self
.
num_embeddings_padded
}
'
s
+=
f
', tp_size=
{
self
.
tp_size
}
'
return
s
class
ParallelLMHead
(
VocabParallelEmbedding
):
"""Parallelized LM head.
...
...
vllm/model_executor/model_loader/loader.py
View file @
1591c68f
...
...
@@ -3,16 +3,19 @@ import copy
import
glob
import
os
from
abc
import
ABC
,
abstractmethod
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
,
Type
)
from
typing
import
Any
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
,
Type
import
huggingface_hub
import
torch
from
torch
import
nn
from
vllm.config
import
(
VLLM_USE_MODELSCOPE
,
DeviceConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.config
import
(
DeviceConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.model_loader.tensorizer
import
(
TensorizerConfig
,
is_vllm_serialized_tensorizer
,
load_with_tensorizer
,
tensorizer_weights_iterator
)
...
...
@@ -24,9 +27,6 @@ from vllm.model_executor.model_loader.weight_utils import (
pt_weights_iterator
,
safetensors_weights_iterator
)
from
vllm.model_executor.models.llava
import
LlavaForConditionalGeneration
if
TYPE_CHECKING
:
from
vllm.model_executor.layers.linear
import
LinearMethodBase
_VISION_MODEL_CLASSES
=
[
LlavaForConditionalGeneration
,
]
...
...
@@ -34,11 +34,10 @@ _VISION_MODEL_CLASSES = [
logger
=
init_logger
(
__name__
)
def
_get_
linear_method
(
def
_get_
quantization_config
(
model_config
:
ModelConfig
,
load_config
:
LoadConfig
)
->
Optional
[
"LinearMethodBase"
]:
"""Get the (maybe quantized) linear method."""
linear_method
=
None
load_config
:
LoadConfig
)
->
Optional
[
QuantizationConfig
]:
"""Get the quantization config."""
if
model_config
.
quantization
is
not
None
:
quant_config
=
get_quant_config
(
model_config
,
load_config
)
capability
=
torch
.
cuda
.
get_device_capability
()
...
...
@@ -55,6 +54,7 @@ def _get_linear_method(
f
"
{
model_config
.
dtype
}
is not supported for quantization "
f
"method
{
model_config
.
quantization
}
. Supported dtypes: "
f
"
{
supported_dtypes
}
"
)
<<<<<<<
HEAD
linear_method
=
quant_config
.
get_linear_method
()
...
...
@@ -62,6 +62,10 @@ def _get_linear_method(
os
.
environ
[
'LLAMA_NN'
]
=
'0'
return
linear_method
=======
return
quant_config
return
None
>>>>>>>
v0
.
4.2
def
_get_model_initialization_kwargs
(
...
...
@@ -89,10 +93,10 @@ def _initialize_model(
vision_language_config
:
Optional
[
VisionLanguageConfig
])
->
nn
.
Module
:
"""Initialize a model with the given configurations."""
model_class
=
get_model_architecture
(
model_config
)[
0
]
linear_method
=
_get_linear_method
(
model_config
,
load_config
)
quant_config
=
_get_quantization_config
(
model_config
,
load_config
)
return
model_class
(
config
=
model_config
.
hf_config
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
**
_get_model_initialization_kwargs
(
model_class
,
lora_config
,
vision_language_config
))
...
...
@@ -139,7 +143,9 @@ class DefaultModelLoader(BaseModelLoader):
model_path
=
snapshot_download
(
model_id
=
model
,
cache_dir
=
self
.
load_config
.
download_dir
,
revision
=
revision
)
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
revision
=
revision
,
)
else
:
model_path
=
model
return
model_path
...
...
@@ -233,9 +239,11 @@ class DefaultModelLoader(BaseModelLoader):
"fall_back_to_pt_during_load"
,
True
)),
)
for
_
,
module
in
model
.
named_modules
():
linear_method
=
getattr
(
module
,
"linear_method"
,
None
)
if
linear_method
is
not
None
:
linear_method
.
process_weights_after_loading
(
module
)
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
quant_method
is
not
None
:
quant_method
.
process_weights_after_loading
(
module
)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if
hasattr
(
module
,
"process_weights_after_loading"
):
module
.
process_weights_after_loading
()
return
model
.
eval
()
...
...
@@ -318,11 +326,11 @@ class TensorizerLoader(BaseModelLoader):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model_class
=
get_model_architecture
(
model_config
)[
0
]
linear_method
=
_get_linear_method
(
model
_config
,
self
.
load_config
)
quant_config
=
_get_quantization
_config
(
model_config
,
self
.
load_config
)
extra_kwargs
=
_get_model_initialization_kwargs
(
model_class
,
lora_config
,
vision_language_config
)
extra_kwargs
[
"
linear_method"
]
=
linear_method
extra_kwargs
[
"
quant_config"
]
=
quant_config
tensorizer_config
=
copy
.
copy
(
self
.
tensorizer_config
)
tensorizer_config
.
model_class
=
model_class
...
...
vllm/model_executor/model_loader/tensorizer.py
View file @
1591c68f
...
...
@@ -11,9 +11,11 @@ import torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
import
vllm.envs
as
envs
from
vllm.config
import
ModelConfig
,
ParallelConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
...
...
@@ -43,7 +45,7 @@ class TensorizerConfig:
str
,
bytes
,
os
.
PathLike
,
int
]
vllm_tensorized
:
bool
verify_hash
:
Optional
[
bool
]
=
False
num_readers
:
Optional
[
int
]
=
1
num_readers
:
Optional
[
int
]
=
None
encryption_keyfile
:
Optional
[
str
]
=
None
s3_access_key_id
:
Optional
[
str
]
=
None
s3_secret_access_key
:
Optional
[
str
]
=
None
...
...
@@ -63,7 +65,7 @@ class TensorizerConfig:
"s3_secret_access_key"
:
self
.
s3_secret_access_key
,
"s3_endpoint"
:
self
.
s3_endpoint
,
}
return
TensorizerArgs
(
**
tensorizer_args
)
return
TensorizerArgs
(
**
tensorizer_args
)
# type: ignore
def
verify_with_parallel_config
(
self
,
...
...
@@ -103,7 +105,7 @@ class TensorizerArgs:
str
,
bytes
,
os
.
PathLike
,
int
]
vllm_tensorized
:
bool
verify_hash
:
Optional
[
bool
]
=
False
num_readers
:
Optional
[
int
]
=
1
num_readers
:
Optional
[
int
]
=
None
encryption_keyfile
:
Optional
[
str
]
=
None
s3_access_key_id
:
Optional
[
str
]
=
None
s3_secret_access_key
:
Optional
[
str
]
=
None
...
...
@@ -124,8 +126,9 @@ class TensorizerArgs:
the hashes stored in the metadata. A `HashMismatchError` will be
raised if any of the hashes do not match.
num_readers: Controls how many threads are allowed to read concurrently
from the source file. Default is 1. This greatly increases
performance.
from the source file. Default is `None`, which will dynamically set
the number of readers based on the number of available
resources and model size. This greatly increases performance.
encryption_keyfile: File path to a binary file containing a
binary key to use for decryption. `None` (the default) means
no decryption. See the example script in
...
...
@@ -140,13 +143,10 @@ class TensorizerArgs:
def
__post_init__
(
self
):
self
.
file_obj
=
self
.
tensorizer_uri
self
.
s3_access_key_id
=
(
self
.
s3_access_key_id
or
os
.
environ
.
get
(
"S3_ACCESS_KEY_ID"
))
or
None
self
.
s3_secret_access_key
=
(
self
.
s3_secret_access_key
or
os
.
environ
.
get
(
"S3_SECRET_ACCESS_KEY"
))
or
None
self
.
s3_endpoint
=
(
self
.
s3_endpoint
or
os
.
environ
.
get
(
"S3_ENDPOINT_URL"
))
or
None
self
.
s3_access_key_id
=
self
.
s3_access_key_id
or
envs
.
S3_ACCESS_KEY_ID
self
.
s3_secret_access_key
=
(
self
.
s3_secret_access_key
or
envs
.
S3_SECRET_ACCESS_KEY
)
self
.
s3_endpoint
=
self
.
s3_endpoint
or
envs
.
S3_ENDPOINT_URL
self
.
stream_params
=
{
"s3_access_key_id"
:
self
.
s3_access_key_id
,
"s3_secret_access_key"
:
self
.
s3_secret_access_key
,
...
...
@@ -198,10 +198,12 @@ class TensorizerArgs:
"use for decryption. Can be a file path or S3 network URI."
)
group
.
add_argument
(
"--num-readers"
,
default
=
1
,
default
=
None
,
type
=
int
,
help
=
"Controls how many threads are allowed to read concurrently "
"from the source file."
)
"from the source file. Default is `None`, which will dynamically "
"set the number of readers based on the available resources "
"and model size. This greatly increases performance."
)
group
.
add_argument
(
"--s3-access-key-id"
,
default
=
None
,
...
...
@@ -251,7 +253,7 @@ class TensorizerAgent:
"""
def
__init__
(
self
,
tensorizer_config
:
TensorizerConfig
,
linear_method
:
LinearMethodBase
,
**
extra_kwargs
):
quant_config
:
QuantizationConfig
,
**
extra_kwargs
):
if
tensorizer_load_fail
is
not
None
:
raise
ImportError
(
"Tensorizer is not installed. Please install tensorizer "
...
...
@@ -262,19 +264,21 @@ class TensorizerAgent:
self
.
tensorizer_args
=
(
self
.
tensorizer_config
.
_construct_tensorizer_args
())
self
.
extra_kwargs
=
extra_kwargs
if
extra_kwargs
.
get
(
"
linear_method
"
,
None
)
is
not
None
:
self
.
linear_method
=
extra_kwargs
[
"
linear_method
"
]
if
extra_kwargs
.
get
(
"
quant_config
"
,
None
)
is
not
None
:
self
.
quant_config
=
extra_kwargs
[
"
quant_config
"
]
else
:
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
model
=
self
.
_init_model
()
def
_init_model
(
self
):
assert
self
.
tensorizer_config
.
hf_config
is
not
None
model_args
=
self
.
tensorizer_config
.
hf_config
model_args
.
torch_dtype
=
self
.
tensorizer_config
.
dtype
assert
self
.
tensorizer_config
.
model_class
is
not
None
with
no_init_or_tensor
():
return
self
.
tensorizer_config
.
model_class
(
config
=
model_args
,
linear_method
=
self
.
linear_method
,
quant_config
=
self
.
quant_config
,
**
self
.
extra_kwargs
)
def
_resize_lora_embeddings
(
self
):
...
...
@@ -334,10 +338,10 @@ class TensorizerAgent:
per_second
=
convert_bytes
(
deserializer
.
total_tensor_bytes
/
duration
)
after_mem
=
get_mem_usage
()
deserializer
.
close
()
logger
.
info
(
f
"Deserialized
{
total_bytes_str
}
in "
f
"
{
end
-
start
:
0.2
f
}
s
,
{
per_second
}
/s"
)
logger
.
info
(
f
"Memory usage before:
{
before_mem
}
"
)
logger
.
info
(
f
"Memory usage after:
{
after_mem
}
"
)
logger
.
info
(
"Deserialized
%s in %0.2fs, %s/s"
,
total_bytes_str
,
end
-
start
,
per_second
)
logger
.
info
(
"Memory usage before:
%s"
,
before_mem
)
logger
.
info
(
"Memory usage after:
%s"
,
after_mem
)
self
.
_check_tensors_on_meta_device
()
self
.
_resize_lora_embeddings
()
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
1591c68f
...
...
@@ -127,11 +127,14 @@ def get_quant_config(model_config: ModelConfig,
if
not
is_local
:
# Download the config files.
with
get_lock
(
model_name_or_path
,
load_config
.
download_dir
):
hf_folder
=
snapshot_download
(
model_name_or_path
,
revision
=
model_config
.
revision
,
allow_patterns
=
"*.json"
,
cache_dir
=
load_config
.
download_dir
,
tqdm_class
=
DisabledTqdm
)
hf_folder
=
snapshot_download
(
model_name_or_path
,
revision
=
model_config
.
revision
,
allow_patterns
=
"*.json"
,
cache_dir
=
load_config
.
download_dir
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
tqdm_class
=
DisabledTqdm
,
)
else
:
hf_folder
=
model_name_or_path
...
...
@@ -161,12 +164,14 @@ def get_quant_config(model_config: ModelConfig,
return
quant_cls
.
from_config
(
config
)
def
download_weights_from_hf
(
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
],
allow_patterns
:
List
[
str
],
revision
:
Optional
[
str
]
=
None
)
->
str
:
def
download_weights_from_hf
(
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
],
allow_patterns
:
List
[
str
],
revision
:
Optional
[
str
]
=
None
,
)
->
str
:
"""Download model weights from Hugging Face Hub.
Args:
model_name_or_path (str): The model name or path.
cache_dir (Optional[str]): The cache directory to store the model
...
...
@@ -179,26 +184,30 @@ def download_weights_from_hf(model_name_or_path: str,
Returns:
str: The path to the downloaded model weights.
"""
# Before we download we look at that is available:
fs
=
HfFileSystem
()
file_list
=
fs
.
ls
(
model_name_or_path
,
detail
=
False
,
revision
=
revision
)
# depending on what is available we download different things
for
pattern
in
allow_patterns
:
matching
=
fnmatch
.
filter
(
file_list
,
pattern
)
if
len
(
matching
)
>
0
:
allow_patterns
=
[
pattern
]
break
logger
.
info
(
f
"Using model weights format
{
allow_patterns
}
"
)
if
not
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
:
# Before we download we look at that is available:
fs
=
HfFileSystem
()
file_list
=
fs
.
ls
(
model_name_or_path
,
detail
=
False
,
revision
=
revision
)
# depending on what is available we download different things
for
pattern
in
allow_patterns
:
matching
=
fnmatch
.
filter
(
file_list
,
pattern
)
if
len
(
matching
)
>
0
:
allow_patterns
=
[
pattern
]
break
logger
.
info
(
"Using model weights format %s"
,
allow_patterns
)
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with
get_lock
(
model_name_or_path
,
cache_dir
):
hf_folder
=
snapshot_download
(
model_name_or_path
,
allow_patterns
=
allow_patterns
,
cache_dir
=
cache_dir
,
tqdm_class
=
DisabledTqdm
,
revision
=
revision
)
hf_folder
=
snapshot_download
(
model_name_or_path
,
allow_patterns
=
allow_patterns
,
cache_dir
=
cache_dir
,
tqdm_class
=
DisabledTqdm
,
revision
=
revision
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
)
return
hf_folder
...
...
@@ -310,17 +319,17 @@ def kv_cache_scales_loader(
return
layer_scales_map
.
items
()
except
FileNotFoundError
:
logger
.
error
(
f
"File or directory '
{
filename
}
' not found."
)
logger
.
error
(
"File or directory '
%s
' not found."
,
filename
)
except
json
.
JSONDecodeError
:
logger
.
error
(
f
"Error decoding JSON in file '
{
filename
}
'."
)
logger
.
error
(
"Error decoding JSON in file '
%s'."
,
filename
)
except
Exception
as
e
:
logger
.
error
(
f
"An error occurred while reading '
{
filename
}
':
{
e
}
"
)
logger
.
error
(
"An error occurred while reading '
%s': %s"
,
filename
,
e
)
# This section is reached if and only if any of the excepts are hit
# Return an empty iterable (list) => no KV cache scales are loaded
# which ultimately defaults to 1.0 scales
logger
.
warning
(
"Defaulting to KV cache scaling factors = 1.0 "
f
"for all layers in TP rank
{
tp_rank
}
"
"
as an error occurred during loading."
)
logger
.
warning
(
"Defaulting to KV cache scaling factors = 1.0 for all
"
"layers in TP rank %d
as an error occurred during loading."
,
tp_rank
)
return
[]
...
...
vllm/model_executor/models/__init__.py
View file @
1591c68f
...
...
@@ -42,10 +42,11 @@ _MODELS = {
"MptForCausalLM"
:
(
"mpt"
,
"MPTForCausalLM"
),
"MPTForCausalLM"
:
(
"mpt"
,
"MPTForCausalLM"
),
"MiniCPMForCausalLM"
:
(
"minicpm"
,
"MiniCPMForCausalLM"
),
"O
LM
oForCausalLM"
:
(
"olmo"
,
"O
LM
oForCausalLM"
),
"O
lm
oForCausalLM"
:
(
"olmo"
,
"O
lm
oForCausalLM"
),
"OPTForCausalLM"
:
(
"opt"
,
"OPTForCausalLM"
),
"OrionForCausalLM"
:
(
"orion"
,
"OrionForCausalLM"
),
"PhiForCausalLM"
:
(
"phi"
,
"PhiForCausalLM"
),
"Phi3ForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"QWenLMHeadModel"
:
(
"qwen"
,
"QWenLMHeadModel"
),
"Qwen2ForCausalLM"
:
(
"qwen2"
,
"Qwen2ForCausalLM"
),
"Qwen2MoeForCausalLM"
:
(
"qwen2_moe"
,
"Qwen2MoeForCausalLM"
),
...
...
@@ -90,8 +91,8 @@ class ModelRegistry:
"ROCm for now."
)
if
model_arch
in
_ROCM_PARTIALLY_SUPPORTED_MODELS
:
logger
.
warning
(
f
"Model architecture
{
model_arch
}
is partially supported
"
"by ROCm: "
+
_ROCM_PARTIALLY_SUPPORTED_MODELS
[
model_arch
])
"Model architecture
%s
is partially supported
by ROCm: %s"
,
model_arch
,
_ROCM_PARTIALLY_SUPPORTED_MODELS
[
model_arch
])
module_name
,
model_cls_name
=
_MODELS
[
model_arch
]
module
=
importlib
.
import_module
(
...
...
@@ -106,9 +107,9 @@ class ModelRegistry:
def
register_model
(
model_arch
:
str
,
model_cls
:
Type
[
nn
.
Module
]):
if
model_arch
in
_MODELS
:
logger
.
warning
(
f
"Model architecture
{
model_arch
}
is already registered, "
"
and will be
overwritten by the new model
"
f
"class
{
model_cls
.
__name__
}
."
)
"Model architecture
%s
is already registered,
and will be
"
"overwritten by the new model
class %s."
,
model_arch
,
model_cls
.
__name__
)
global
_OOT_MODELS
_OOT_MODELS
[
model_arch
]
=
model_cls
...
...
vllm/model_executor/models/baichuan.py
View file @
1591c68f
...
...
@@ -31,11 +31,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -77,17 +78,17 @@ class BaiChuanMLP(nn.Module):
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
...
...
@@ -110,7 +111,7 @@ class BaiChuanAttention(nn.Module):
position_embedding
:
str
,
rope_theta
:
float
=
10000
,
max_position_embeddings
:
int
=
8192
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
hidden_size
=
hidden_size
...
...
@@ -132,13 +133,13 @@ class BaiChuanAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_heads
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
# Create the alibi slopes and slice them.
if
self
.
postion_embedding
==
"ALIBI"
:
...
...
@@ -184,7 +185,7 @@ class BaiChuanDecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
position_embedding
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
...
...
@@ -196,13 +197,13 @@ class BaiChuanDecoderLayer(nn.Module):
position_embedding
=
position_embedding
,
rope_theta
=
rope_theta
,
max_position_embeddings
=
max_position_embeddings
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
mlp
=
BaiChuanMLP
(
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -243,7 +244,7 @@ class BaiChuanModel(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
position_embedding
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
config
=
config
self
.
padding_idx
=
config
.
pad_token_id
...
...
@@ -254,7 +255,7 @@ class BaiChuanModel(nn.Module):
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
([
BaiChuanDecoderLayer
(
config
,
position_embedding
,
linear_method
)
BaiChuanDecoderLayer
(
config
,
position_embedding
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -303,13 +304,13 @@ class BaiChuanBaseForCausalLM(nn.Module):
self
,
config
,
position_embedding
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
model
=
BaiChuanModel
(
config
,
position_embedding
,
linear_method
)
self
.
quant_config
=
quant_config
self
.
model
=
BaiChuanModel
(
config
,
position_embedding
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
@@ -388,13 +389,13 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
if
config
.
hidden_size
==
4096
:
# baichuan2 7b
super
().
__init__
(
config
,
"ROPE"
,
linear_method
,
lora_config
)
super
().
__init__
(
config
,
"ROPE"
,
quant_config
,
lora_config
)
else
:
# baichuan 13b, baichuan2 13b
super
().
__init__
(
config
,
"ALIBI"
,
linear_method
,
lora_config
)
super
().
__init__
(
config
,
"ALIBI"
,
quant_config
,
lora_config
)
class
BaiChuanForCausalLM
(
BaiChuanBaseForCausalLM
):
...
...
@@ -403,7 +404,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
super
().
__init__
(
config
,
"ROPE"
,
linear_method
,
lora_config
)
super
().
__init__
(
config
,
"ROPE"
,
quant_config
,
lora_config
)
vllm/model_executor/models/bloom.py
View file @
1591c68f
...
...
@@ -28,10 +28,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
...
...
@@ -70,7 +71,7 @@ class BloomAttention(nn.Module):
def
__init__
(
self
,
config
:
BloomConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -87,13 +88,13 @@ class BloomAttention(nn.Module):
self
.
head_dim
,
self
.
total_num_heads
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
dense
=
RowParallelLinear
(
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
# Create the alibi slopes and slice them.
...
...
@@ -129,21 +130,20 @@ class BloomMLP(nn.Module):
def
__init__
(
self
,
config
:
BloomConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
hidden_size
=
config
.
hidden_size
self
.
dense_h_to_4h
=
ColumnParallelLinear
(
hidden_size
,
4
*
hidden_size
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
quant_config
=
getattr
(
linear_method
,
"quant_config"
,
None
)
self
.
gelu_impl
=
get_act_fn
(
"gelu"
,
quant_config
,
4
*
hidden_size
)
self
.
dense_4h_to_h
=
RowParallelLinear
(
4
*
hidden_size
,
hidden_size
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -158,17 +158,17 @@ class BloomBlock(nn.Module):
def
__init__
(
self
,
config
:
BloomConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
hidden_size
=
config
.
hidden_size
self
.
input_layernorm
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
self_attention
=
BloomAttention
(
config
,
linear_method
)
self
.
self_attention
=
BloomAttention
(
config
,
quant_config
)
self
.
post_attention_layernorm
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
mlp
=
BloomMLP
(
config
,
linear_method
)
self
.
mlp
=
BloomMLP
(
config
,
quant_config
)
self
.
apply_residual_connection_post_layernorm
=
(
config
.
apply_residual_connection_post_layernorm
)
...
...
@@ -214,7 +214,7 @@ class BloomModel(nn.Module):
def
__init__
(
self
,
config
:
BloomConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
embed_dim
=
config
.
hidden_size
...
...
@@ -229,7 +229,7 @@ class BloomModel(nn.Module):
# Transformer blocks
self
.
h
=
nn
.
ModuleList
([
BloomBlock
(
config
,
linear_method
)
BloomBlock
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
...
...
@@ -262,12 +262,12 @@ class BloomForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
BloomConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
transformer
=
BloomModel
(
config
,
linear_method
)
self
.
quant_config
=
quant_config
self
.
transformer
=
BloomModel
(
config
,
quant_config
)
self
.
lm_head_weight
=
self
.
transformer
.
word_embeddings
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/chatglm.py
View file @
1591c68f
...
...
@@ -13,11 +13,12 @@ from vllm.config import LoRAConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -33,7 +34,7 @@ class GLMAttention(nn.Module):
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -65,13 +66,13 @@ class GLMAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
config
.
add_bias_linear
or
config
.
add_qkv_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
dense
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
config
.
hidden_size
,
bias
=
config
.
add_bias_linear
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
...
...
@@ -123,7 +124,7 @@ class GLMMLP(nn.Module):
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
...
...
@@ -134,7 +135,7 @@ class GLMMLP(nn.Module):
config
.
hidden_size
,
[
config
.
ffn_hidden_size
]
*
2
,
bias
=
config
.
add_bias_linear
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
activation_func
=
SiluAndMul
()
...
...
@@ -144,7 +145,7 @@ class GLMMLP(nn.Module):
config
.
ffn_hidden_size
,
config
.
hidden_size
,
bias
=
config
.
add_bias_linear
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
def
forward
(
self
,
hidden_states
):
...
...
@@ -166,7 +167,7 @@ class GLMBlock(nn.Module):
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
apply_residual_connection_post_layernorm
=
(
...
...
@@ -180,7 +181,7 @@ class GLMBlock(nn.Module):
eps
=
config
.
layernorm_epsilon
)
# Self attention.
self
.
self_attention
=
GLMAttention
(
config
,
linear_method
)
self
.
self_attention
=
GLMAttention
(
config
,
quant_config
)
self
.
hidden_dropout
=
config
.
hidden_dropout
# Layernorm on the attention output
...
...
@@ -188,7 +189,7 @@ class GLMBlock(nn.Module):
config
.
hidden_size
,
eps
=
config
.
layernorm_epsilon
)
# MLP
self
.
mlp
=
GLMMLP
(
config
,
linear_method
)
self
.
mlp
=
GLMMLP
(
config
,
quant_config
)
def
forward
(
self
,
...
...
@@ -236,7 +237,7 @@ class GLMTransformer(nn.Module):
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
post_layer_norm
=
config
.
post_layer_norm
...
...
@@ -246,7 +247,7 @@ class GLMTransformer(nn.Module):
# Transformer layers.
self
.
layers
=
nn
.
ModuleList
(
[
GLMBlock
(
config
,
linear_method
)
for
i
in
range
(
self
.
num_layers
)])
[
GLMBlock
(
config
,
quant_config
)
for
i
in
range
(
self
.
num_layers
)])
if
self
.
post_layer_norm
:
layer_norm_func
=
RMSNorm
if
config
.
rmsnorm
else
LayerNorm
...
...
@@ -281,7 +282,7 @@ class ChatGLMModel(nn.Module):
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
...
...
@@ -291,7 +292,7 @@ class ChatGLMModel(nn.Module):
self
.
num_layers
=
config
.
num_layers
self
.
multi_query_group_num
=
config
.
multi_query_group_num
self
.
kv_channels
=
config
.
kv_channels
self
.
encoder
=
GLMTransformer
(
config
,
linear_method
)
self
.
encoder
=
GLMTransformer
(
config
,
quant_config
)
self
.
output_layer
=
ParallelLMHead
(
config
.
padded_vocab_size
,
config
.
hidden_size
)
...
...
@@ -333,13 +334,13 @@ class ChatGLMForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
ChatGLMConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
:
ChatGLMConfig
=
config
self
.
linear_method
=
linear_method
self
.
transformer
=
ChatGLMModel
(
config
,
linear_method
)
self
.
quant_config
=
quant_config
self
.
transformer
=
ChatGLMModel
(
config
,
quant_config
)
self
.
lm_head_weight
=
self
.
transformer
.
output_layer
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
padded_vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/commandr.py
View file @
1591c68f
...
...
@@ -32,11 +32,12 @@ from vllm.attention import Attention, AttentionMetadata
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -91,7 +92,7 @@ class CohereMLP(nn.Module):
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -101,13 +102,13 @@ class CohereMLP(nn.Module):
self
.
hidden_size
,
[
self
.
intermediate_size
]
*
2
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
down_proj
=
RowParallelLinear
(
self
.
intermediate_size
,
self
.
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
act_fn
=
SiluAndMul
()
...
...
@@ -123,7 +124,7 @@ class CohereAttention(nn.Module):
def
__init__
(
self
,
config
:
CohereConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
tp_size
=
get_tensor_model_parallel_world_size
()
...
...
@@ -158,13 +159,13 @@ class CohereAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
...
...
@@ -218,13 +219,13 @@ class CohereDecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
CohereConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
self_attn
=
CohereAttention
(
config
,
linear_method
=
linear_method
)
self
.
self_attn
=
CohereAttention
(
config
,
quant_config
=
quant_config
)
self
.
mlp
=
CohereMLP
(
config
,
linear_method
=
linear_method
)
self
.
mlp
=
CohereMLP
(
config
,
quant_config
=
quant_config
)
self
.
input_layernorm
=
LayerNorm
(
param_shape
=
(
config
.
hidden_size
),
eps
=
config
.
layer_norm_eps
)
...
...
@@ -257,7 +258,7 @@ class CohereModel(nn.Module):
def
__init__
(
self
,
config
:
CohereConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -265,7 +266,7 @@ class CohereModel(nn.Module):
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
layers
=
nn
.
ModuleList
([
CohereDecoderLayer
(
config
,
linear_method
=
linear_method
)
CohereDecoderLayer
(
config
,
quant_config
=
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
LayerNorm
(
param_shape
=
(
config
.
hidden_size
),
...
...
@@ -298,14 +299,14 @@ class CohereForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
CohereConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
,
scale
=
config
.
logit_scale
)
self
.
model
=
CohereModel
(
config
,
linear_method
)
self
.
model
=
CohereModel
(
config
,
quant_config
)
self
.
sampler
=
Sampler
()
@
torch
.
no_grad
()
...
...
vllm/model_executor/models/dbrx.py
View file @
1591c68f
...
...
@@ -9,11 +9,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
QKVParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -44,7 +45,7 @@ class DbrxRouter(nn.Module):
self
.
num_total_experts
,
bias
=
False
,
params_dtype
=
params_dtype
,
linear_method
=
None
,
quant_config
=
None
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -63,7 +64,7 @@ class DbrxExperts(nn.Module):
def
__init__
(
self
,
config
:
DbrxConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
):
super
().
__init__
()
...
...
@@ -165,7 +166,7 @@ class DbrxAttention(nn.Module):
def
__init__
(
self
,
config
:
DbrxConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
d_model
=
config
.
d_model
...
...
@@ -183,13 +184,13 @@ class DbrxAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
out_proj
=
RowParallelLinear
(
self
.
d_model
,
self
.
d_model
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
...
...
@@ -244,11 +245,11 @@ class DbrxFusedNormAttention(nn.Module):
def
__init__
(
self
,
config
:
DbrxConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
d_model
=
config
.
d_model
self
.
attn
=
DbrxAttention
(
config
,
linear_method
)
self
.
attn
=
DbrxAttention
(
config
,
quant_config
)
self
.
norm_1
=
nn
.
LayerNorm
(
self
.
d_model
)
self
.
norm_2
=
nn
.
LayerNorm
(
self
.
d_model
)
...
...
@@ -278,11 +279,11 @@ class DbrxBlock(nn.Module):
def
__init__
(
self
,
config
:
DbrxConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
norm_attn_norm
=
DbrxFusedNormAttention
(
config
,
linear_method
)
self
.
ffn
=
DbrxExperts
(
config
,
linear_method
)
self
.
norm_attn_norm
=
DbrxFusedNormAttention
(
config
,
quant_config
)
self
.
ffn
=
DbrxExperts
(
config
,
quant_config
)
def
forward
(
self
,
...
...
@@ -307,7 +308,7 @@ class DbrxModel(nn.Module):
def
__init__
(
self
,
config
:
DbrxConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
wte
=
VocabParallelEmbedding
(
...
...
@@ -315,7 +316,7 @@ class DbrxModel(nn.Module):
config
.
d_model
,
)
self
.
blocks
=
nn
.
ModuleList
(
[
DbrxBlock
(
config
,
linear_method
)
for
_
in
range
(
config
.
n_layers
)])
[
DbrxBlock
(
config
,
quant_config
)
for
_
in
range
(
config
.
n_layers
)])
self
.
norm_f
=
nn
.
LayerNorm
(
config
.
d_model
,
eps
=
1e-5
)
for
module
in
self
.
modules
():
if
hasattr
(
module
,
"bias"
)
and
isinstance
(
module
.
bias
,
...
...
@@ -348,13 +349,13 @@ class DbrxForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
DbrxConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
transformer
=
DbrxModel
(
config
,
linear_method
)
self
.
transformer
=
DbrxModel
(
config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
d_model
,
...
...
vllm/model_executor/models/decilm.py
View file @
1591c68f
...
...
@@ -29,7 +29,8 @@ import torch
from
transformers
import
PretrainedConfig
from
vllm.config
import
LoRAConfig
from
vllm.model_executor.layers.linear
import
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
...
...
@@ -55,13 +56,13 @@ class DeciLMForCausalLM(LlamaForCausalLM):
def
__init__
(
self
,
config
:
Optional
[
PretrainedConfig
]
=
None
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
config
.
num_key_value_heads
=
max
(
config
.
num_key_value_heads_per_layer
)
delattr
(
config
,
"num_key_value_heads_per_layer"
)
super
().
__init__
(
config
=
config
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
lora_config
=
lora_config
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
...
...
vllm/model_executor/models/deepseek.py
View file @
1591c68f
...
...
@@ -34,12 +34,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -56,18 +57,18 @@ class DeepseekMLP(nn.Module):
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
reduce_results
:
bool
=
True
,
)
->
None
:
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
reduce_results
=
reduce_results
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
...
...
@@ -86,7 +87,7 @@ class DeepseekMoE(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -103,7 +104,7 @@ class DeepseekMoE(nn.Module):
DeepseekMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
reduce_results
=
False
)
for
idx
in
range
(
self
.
n_routed_experts
)
])
...
...
@@ -112,7 +113,7 @@ class DeepseekMoE(nn.Module):
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
self
.
n_routed_experts
,
bias
=
False
,
linear_method
=
None
)
quant_config
=
None
)
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
...
...
@@ -121,7 +122,7 @@ class DeepseekMoE(nn.Module):
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
reduce_results
=
False
,
)
...
...
@@ -177,7 +178,7 @@ class DeepseekAttention(nn.Module):
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
...
...
@@ -208,14 +209,14 @@ class DeepseekAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
rotary_emb
=
get_rope
(
...
...
@@ -251,7 +252,7 @@ class DeepseekDecoderLayer(nn.Module):
self
,
config
:
PretrainedConfig
,
layer_idx
:
int
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -266,18 +267,18 @@ class DeepseekDecoderLayer(nn.Module):
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
if
(
config
.
n_routed_experts
is
not
None
and
layer_idx
>=
config
.
first_k_dense_replace
and
layer_idx
%
config
.
moe_layer_freq
==
0
):
self
.
mlp
=
DeepseekMoE
(
config
=
config
,
linear_method
=
linear_method
)
self
.
mlp
=
DeepseekMoE
(
config
=
config
,
quant_config
=
quant_config
)
else
:
self
.
mlp
=
DeepseekMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -320,7 +321,7 @@ class DeepseekModel(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
padding_idx
=
config
.
pad_token_id
...
...
@@ -331,9 +332,7 @@ class DeepseekModel(nn.Module):
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
([
DeepseekDecoderLayer
(
config
,
layer_idx
,
linear_method
=
linear_method
)
DeepseekDecoderLayer
(
config
,
layer_idx
,
quant_config
=
quant_config
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -361,12 +360,12 @@ class DeepseekForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
model
=
DeepseekModel
(
config
,
linear_method
)
self
.
quant_config
=
quant_config
self
.
model
=
DeepseekModel
(
config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/falcon.py
View file @
1591c68f
...
...
@@ -32,10 +32,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -76,7 +77,7 @@ class FalconAttention(nn.Module):
def
__init__
(
self
,
config
:
FalconConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
...
...
@@ -115,7 +116,7 @@ class FalconAttention(nn.Module):
self
.
total_num_kv_heads
,
bias
=
config
.
bias
,
skip_bias_add
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
...
...
@@ -129,7 +130,7 @@ class FalconAttention(nn.Module):
self
.
hidden_size
,
bias
=
config
.
bias
,
skip_bias_add
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
reduce_results
=
self
.
reduce_row_parallel_results
)
self
.
use_rotary
=
config
.
rotary
...
...
@@ -192,7 +193,7 @@ class FalconMLP(nn.Module):
def
__init__
(
self
,
config
:
FalconConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
hidden_size
=
config
.
hidden_size
...
...
@@ -201,8 +202,7 @@ class FalconMLP(nn.Module):
4
*
hidden_size
,
bias
=
config
.
bias
,
skip_bias_add
=
True
,
linear_method
=
linear_method
)
quant_config
=
getattr
(
linear_method
,
"quant_config"
,
None
)
quant_config
=
quant_config
)
self
.
act
=
get_act_fn
(
"gelu"
,
quant_config
,
4
*
hidden_size
)
self
.
reduce_row_parallel_results
=
not
(
config
.
new_decoder_architecture
or
config
.
parallel_attn
)
...
...
@@ -212,7 +212,7 @@ class FalconMLP(nn.Module):
bias
=
config
.
bias
,
skip_bias_add
=
True
,
reduce_results
=
self
.
reduce_row_parallel_results
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
...
...
@@ -229,13 +229,13 @@ class FalconDecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
FalconConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
hidden_size
=
config
.
hidden_size
self
.
num_heads
=
config
.
num_attention_heads
self
.
self_attention
=
FalconAttention
(
config
,
linear_method
)
self
.
mlp
=
FalconMLP
(
config
,
linear_method
)
self
.
self_attention
=
FalconAttention
(
config
,
quant_config
)
self
.
mlp
=
FalconMLP
(
config
,
quant_config
)
self
.
config
=
config
if
config
.
new_decoder_architecture
:
...
...
@@ -311,7 +311,7 @@ class FalconModel(nn.Module):
def
__init__
(
self
,
config
:
FalconConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -327,7 +327,7 @@ class FalconModel(nn.Module):
# Transformer blocks
self
.
h
=
nn
.
ModuleList
([
FalconDecoderLayer
(
config
,
linear_method
)
FalconDecoderLayer
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
...
...
@@ -359,12 +359,12 @@ class FalconForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
FalconConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
transformer
=
FalconModel
(
config
,
linear_method
)
self
.
quant_config
=
quant_config
self
.
transformer
=
FalconModel
(
config
,
quant_config
)
self
.
lm_head_weight
=
self
.
transformer
.
word_embeddings
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
Prev
1
…
7
8
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment