Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
583034f1
Commit
583034f1
authored
Oct 20, 2025
by
zhuwenwen
Browse files
[models] support step3v
parent
0adf9cda
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
5936 additions
and
9 deletions
+5936
-9
vllm/config.py
vllm/config.py
+11
-1
vllm/entrypoints/openai/tool_parsers/__init__.py
vllm/entrypoints/openai/tool_parsers/__init__.py
+3
-0
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+9
-0
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+76
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+269
-3
vllm/model_executor/layers/logits_processor.py
vllm/model_executor/layers/logits_processor.py
+10
-4
vllm/model_executor/layers/quantization/groupwise_quant.py
vllm/model_executor/layers/quantization/groupwise_quant.py
+726
-0
vllm/model_executor/layers/quantization/quant_utils.py
vllm/model_executor/layers/quantization/quant_utils.py
+135
-0
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+7
-0
vllm/model_executor/models/mm_step1o.py
vllm/model_executor/models/mm_step1o.py
+1080
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+27
-0
vllm/model_executor/models/step1.py
vllm/model_executor/models/step1.py
+1442
-0
vllm/model_executor/models/step2_mini.py
vllm/model_executor/models/step2_mini.py
+807
-0
vllm/model_executor/models/step_encoder.py
vllm/model_executor/models/step_encoder.py
+447
-0
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+25
-1
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+29
-0
vllm/transformers_utils/configs/mmgpt.py
vllm/transformers_utils/configs/mmgpt.py
+558
-0
vllm/transformers_utils/configs/step.py
vllm/transformers_utils/configs/step.py
+102
-0
vllm/transformers_utils/configs/step1f.py
vllm/transformers_utils/configs/step1f.py
+164
-0
vllm/transformers_utils/detokenizer_utils.py
vllm/transformers_utils/detokenizer_utils.py
+9
-0
No files found.
vllm/config.py
View file @
583034f1
...
...
@@ -3418,6 +3418,8 @@ def _get_and_verify_max_len(
possible_keys
=
[
# OPT
"max_position_embeddings"
,
# step3
"max_position_embedding"
,
# GPT-2
"n_positions"
,
# MPT
...
...
@@ -3490,8 +3492,14 @@ def _get_and_verify_max_len(
# No need to consider "type" key because of patch_rope_scaling when
# loading HF config
rope_type
=
rope_scaling
[
"rope_type"
]
if
rope_type
==
"ntk_bypart"
:
derived_max_model_len
=
min
(
derived_max_model_len
,
rope_scaling
[
"real_length"
]
*
rope_scaling
[
"scaling_factor"
]
)
if
"real_length"
in
rope_scaling
and
"scaling_factor"
in
rope_scaling
else
derived_max_model_len
if
rope_type
not
in
(
"su"
,
"longrope"
,
"llama3"
):
el
if
rope_type
not
in
(
"su"
,
"longrope"
,
"llama3"
):
if
disable_sliding_window
:
# TODO(robertgshaw): Find a model that supports rope_scaling
# with sliding window to see if this case should be allowed.
...
...
@@ -3548,6 +3556,8 @@ def _get_and_verify_max_len(
logger
.
warning
(
"%s Make sure the value is correct and within the "
"model context size."
,
msg
)
if
getattr
(
hf_config
,
"max_position_embedding"
,
None
)
is
not
None
:
# step3/3v
hf_config
.
max_position_embedding
=
max_model_len
else
:
raise
ValueError
(
f
"
{
msg
}
To allow overriding this maximum, set "
...
...
vllm/entrypoints/openai/tool_parsers/__init__.py
View file @
583034f1
...
...
@@ -36,4 +36,7 @@ __all__ = [
"xLAMToolParser"
,
"MinimaxToolParser"
,
"Glm4MoeModelToolParser"
,
"Step1p5vMini2ToolParser"
,
"Step1p5vMini2MsToolParser"
,
"Step3ToolParser"
,
]
vllm/model_executor/layers/activation.py
View file @
583034f1
...
...
@@ -3,6 +3,7 @@
"""Custom activation functions."""
import
math
from
typing
import
Optional
import
optimus
import
torch
import
torch.nn
as
nn
...
...
@@ -53,6 +54,14 @@ class FatreluAndMul(CustomOp):
return
out
class OptimusSiluAndMul(nn.Module):
    """Activation module that delegates to the Optimus ``SiluDot_forward`` op.

    NOTE(review): the op is looked up under ``torch.ops.Optimus``; presumably
    it is registered as a side effect of ``import optimus`` -- confirm.
    """

    def forward(
        self,
        x: torch.Tensor,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run ``SiluDot_forward`` on ``x``.

        Args:
            x: Input tensor handed to the op unchanged.
            output: Optional tensor forwarded to the op as ``out``.

        Returns:
            Whatever the custom op returns.
        """
        silu_dot = torch.ops.Optimus.SiluDot_forward
        result = silu_dot(x, out=output)
        return result
@
CustomOp
.
register
(
"silu_and_mul"
)
class
SiluAndMul
(
CustomOp
):
"""An activation function for SwiGLU.
...
...
vllm/model_executor/layers/layernorm.py
View file @
583034f1
...
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom normalization layers."""
from
typing
import
Optional
,
Union
,
Tuple
import
optimus
# noqa F401
import
torch
import
torch.nn
as
nn
...
...
@@ -298,6 +299,49 @@ class RMSNorm(CustomOp):
return
s
class
OptimusRMSNorm
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
eps
:
float
=
1e-6
,
)
->
None
:
super
().
__init__
()
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
hidden_size
))
self
.
variance_epsilon
=
eps
def
forward
(
self
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
fp16_out
:
bool
=
False
)
->
torch
.
Tensor
:
if
residual
is
not
None
:
assert
output
is
None
from
vllm
import
_custom_ops
as
ops
assert
not
fp16_out
ops
.
fused_add_rms_norm
(
x
,
residual
,
self
.
weight
.
data
,
self
.
variance_epsilon
,
)
return
x
,
residual
else
:
if
fp16_out
:
if
output
is
None
:
output
=
torch
.
empty_like
(
x
).
half
()
else
:
output
=
output
.
half
()
# return torch.ops.Optimus.rms_norm(x,
# self.weight,
# self.variance_epsilon,
# out=output)
return
torch
.
nn
.
functional
.
rms_norm
(
x
,
self
.
weight
,
self
.
variance_epsilon
,
out
=
output
)
@
CustomOp
.
register
(
"gemma_rms_norm"
)
class
GemmaRMSNorm
(
CustomOp
):
"""RMS normalization for Gemma.
...
...
@@ -363,3 +407,35 @@ class GemmaRMSNorm(CustomOp):
self
.
forward_static
)
self
.
_is_compiled
=
True
return
self
.
forward_native
(
x
,
residual
)
class OptimusLayerNorm(nn.Module):
    """Layer normalization with learnable scale and bias.

    Both parameters have shape ``(hidden_size,)`` and that shape is used as
    the ``normalized_shape`` passed to ``torch.nn.functional.layer_norm``.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        """Create the layer.

        Args:
            hidden_size: Length of the scale (ones) and bias (zeros) vectors.
            eps: Epsilon forwarded to ``layer_norm``.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Normalize ``x``.

        Args:
            x: Input tensor.
            residual: Must be ``None``; a fused residual add is not supported.
            output: Optional destination tensor. When given, the result is
                copied into it and it is returned.

        Returns:
            The normalized tensor (``output`` itself when supplied).
        """
        assert residual is None
        # normalized_shape must be the shape of the weight.
        normed = torch.nn.functional.layer_norm(
            x,
            self.weight.shape,
            self.weight,
            self.bias,
            eps=self.variance_epsilon,
        )
        if output is None:
            return normed
        # F.layer_norm has no ``out=`` argument, so fill the caller's buffer
        # explicitly instead of silently ignoring it.
        output.copy_(normed)
        return output
vllm/model_executor/layers/linear.py
View file @
583034f1
...
...
@@ -3,7 +3,7 @@
import
itertools
from
abc
import
abstractmethod
from
typing
import
Any
,
Literal
,
Optional
,
Union
from
typing
import
Any
,
Literal
,
Optional
,
Union
,
List
import
vllm.envs
as
envs
import
torch
import
torch.nn
as
nn
...
...
@@ -269,6 +269,40 @@ class UnquantizedLinearMethod(LinearMethodBase):
return
dispatch_unquantized_gemm
()(
x
,
layer
.
weight
,
bias
)
class UnquantizedMoELinearMethod(LinearMethodBase):
    """Linear method for unquantized mixture-of-experts weights.

    Only weight allocation is provided; ``apply`` is deliberately not
    implemented.
    """

    def __init__(self):
        self.quant_config = None

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       num_experts: Optional[int] = None,
                       **extra_weight_attrs):
        """Allocate and register the stacked expert weight on ``layer``.

        The parameter is an uninitialised tensor of shape
        ``(num_experts, sum(output_partition_sizes), input_size_per_partition)``
        on the current CUDA device, registered under the name ``"weight"``.
        """
        merged_output_size = sum(output_partition_sizes)
        storage = torch.empty(num_experts,
                              merged_output_size,
                              input_size_per_partition,
                              device=torch.cuda.current_device(),
                              dtype=params_dtype)
        expert_weight = Parameter(storage, requires_grad=False)

        # Dimension bookkeeping first, then the caller-supplied attributes,
        # with registration in between.
        dim_attrs = {"input_dim": 2, "output_dim": 1}
        set_weight_attrs(expert_weight, dim_attrs)
        layer.register_parameter("weight", expert_weight)
        set_weight_attrs(expert_weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Not supported: always raises ``NotImplementedError``."""
        raise NotImplementedError
class
LinearBase
(
torch
.
nn
.
Module
):
"""Base linear layer.
...
...
@@ -783,6 +817,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param
.
shard_id
.
append
(
loaded_shard_id
)
param
.
shard_id_map
[
loaded_shard_id
]
=
len
(
param
.
data_container
)
param
.
data_container
.
append
(
loaded_weight
)
if
len
(
param
.
data_container
)
==
2
:
self
.
qweight
=
param
.
materialize_nested
()
return
param_data
=
param
.
data
...
...
@@ -986,6 +1022,175 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset
=
shard_offset
,
shard_size
=
shard_size
)
class MergedColumnParallelMoELinear(MergedColumnParallelLinear):
    """Merged column-parallel linear layer holding one weight per expert.

    The constructor initialises ``torch.nn.Module`` directly instead of the
    parent classes and lets the quantization method allocate the weights.
    """

    def __init__(self,
                 num_experts: int,
                 input_size: int,
                 output_sizes: List[int],
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Create the layer.

        Args:
            num_experts: Number of experts, forwarded to ``create_weights``.
            input_size: Input dimension (not partitioned here).
            output_sizes: Output dimensions of the merged projections; each
                must be divisible by the tensor-parallel world size.
            params_dtype: Parameter dtype; defaults to torch's default dtype.
            quant_config: Optional quantization config used to pick the
                quantization method.
            prefix: Layer name passed to ``get_quant_method``.
        """
        torch.nn.Module.__init__(self)
        self.num_experts = num_experts
        self.output_sizes = output_sizes
        self.input_size = input_size
        self.output_size = sum(output_sizes)

        tp_size = get_tensor_model_parallel_world_size()
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        self.output_size_per_partition = divide(self.output_size, tp_size)
        self.output_partition_sizes = [
            divide(output_size, tp_size) for output_size in self.output_sizes
        ]
        self.gather_output = False

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        if quant_config is None:
            self.quant_method = UnquantizedMoELinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)
        # FIXME(ys): hack for moe
        if isinstance(self.quant_method, UnquantizedLinearMethod):
            self.quant_method = UnquantizedMoELinearMethod()

        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size,
                                         self.output_partition_sizes,
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         self.num_experts,
                                         weight_loader=self.weight_loader)
        self.register_parameter("bias", None)

    def forward(self,
                input_,
                output: Optional[torch.Tensor] = None,
                expert_idx: int = -1):
        """Apply the selected expert's projection.

        Returns ``None`` when the layer is unquantized (the computation is
        then expected to happen outside this module); otherwise returns the
        result of ``quant_method.apply``.
        """
        if isinstance(self.quant_method, UnquantizedMoELinearMethod):
            # use optimus moe_ffn outside
            return
        bias = None
        assert self.quant_method is not None
        output = self.quant_method.apply(self,
                                         input_,
                                         bias,
                                         expert_idx=expert_idx,
                                         output=output)
        return output
class QKVReplicatedLinear(ReplicatedLinear):
    """Fused QKV projection whose parameters are not tensor-parallel sharded.

    The output dimension is laid out as ``[q | k | v]`` with sizes
    ``num_heads * head_size``, ``num_kv_heads * head_size`` and
    ``num_kv_heads * head_size``. The constructor initialises ``nn.Module``
    directly rather than calling the parent constructor.
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 return_bias: bool = True):
        """Create the layer.

        Args:
            hidden_size: Input dimension.
            head_size: Size of one attention head.
            total_num_heads: Number of query heads.
            total_num_kv_heads: Number of key/value heads; falls back to
                ``total_num_heads`` when falsy (``None`` or ``0``).
            bias: Whether to allocate a bias of length ``output_size``.
            skip_bias_add: Stored on the instance as ``skip_bias_add``.
            params_dtype: Parameter dtype; defaults to torch's default dtype.
            quant_config: Optional quantization config used to pick the
                quantization method.
            prefix: Layer name passed to ``get_quant_method``.
            return_bias: Stored on the instance as ``return_bias``.
        """
        nn.Module.__init__(self)
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.num_heads = total_num_heads
        self.num_kv_heads = total_num_kv_heads if total_num_kv_heads else total_num_heads
        self.input_size = self.hidden_size
        self.output_size = (self.num_heads +
                            2 * self.num_kv_heads) * self.head_size
        self.skip_bias_add = skip_bias_add
        self.return_bias = return_bias
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if quant_config is None:
            self.quant_method: Optional[
                QuantizeMethodBase] = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)
        assert self.quant_method is not None
        # The whole (unsharded) input and output sizes are used for both the
        # per-partition and the global arguments.
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Copy a checkpoint tensor into ``param``.

        Args:
            param: Destination parameter. Its optional ``output_dim``,
                ``packed_dim``, ``pack_factor`` and ``ignore_warning``
                attributes steer the copy.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``None`` when ``loaded_weight`` already holds the
                fused QKV tensor, otherwise one of ``"q"``, ``"k"``, ``"v"``.
        """
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)
        if loaded_shard_id is None:
            # Loaded weight is already packed.
            assert param_data.shape == loaded_weight.shape
            param_data.copy_(loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]
        if output_dim is not None:
            # Locate this shard inside the fused [q | k | v] output dimension.
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads +
                                self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
            if not envs.VLLM_USE_NN or is_quantization:
                param_data = param_data.narrow(output_dim, shard_offset,
                                               shard_size)
            else:
                # Unquantized weights with VLLM_USE_NN set are narrowed along
                # the other axis (0 <-> 1).
                # NOTE(review): presumably because the weight is stored
                # transposed in that mode -- confirm. For a 1-D bias
                # (output_dim == 0) this narrows dim 1; verify that path.
                param_data = param_data.narrow(int(not (output_dim)),
                                               shard_offset, shard_size)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVReplicatedLinear, assume the weight is the same "
                    "for all partitions.")
        # NOTE(review): the nesting of this block was reconstructed from a
        # flattened diff; it is assumed to run for every sharded load.
        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
class
QKVParallelLinear
(
ColumnParallelLinear
):
"""Linear layers for the attention's QKV transformation.
...
...
@@ -1185,6 +1390,8 @@ class QKVParallelLinear(ColumnParallelLinear):
param
.
shard_id
.
append
(
loaded_shard_id
)
param
.
shard_id_map
[
loaded_shard_id
]
=
len
(
param
.
data_container
)
param
.
data_container
.
append
(
loaded_weight
)
if
len
(
param
.
data_container
)
==
3
:
self
.
qweight
=
param
.
materialize_nested
()
return
param_data
=
param
.
data
...
...
@@ -1495,7 +1702,7 @@ class RowParallelLinear(LinearBase):
def
forward
(
self
,
input_
,
use_fused_silu_mul_quant
:
Optional
[
bool
]
=
False
use_fused_silu_mul_quant
:
Optional
[
bool
]
=
False
,
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
if
self
.
input_is_parallel
:
input_parallel
=
input_
...
...
@@ -1757,4 +1964,63 @@ class QKVCrossParallelLinear(LinearBase):
s
+=
f
", bias=
{
self
.
bias
is
not
None
}
"
s
+=
f
", tp_size=
{
get_tensor_model_parallel_world_size
()
}
"
s
+=
", gather_output=False"
return
s
\ No newline at end of file
return
s
class RowParallelMoELinear(RowParallelLinear):
    """Row-parallel linear layer holding one weight per expert.

    ``torch.nn.Module`` is initialised directly instead of the parent
    constructor, and ``reduce_results`` is fixed to ``False``.
    """

    def __init__(self,
                 num_experts: int,
                 input_size: int,
                 output_size: int,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Create the layer and let the quantization method allocate weights.

        Args:
            num_experts: Number of experts, forwarded to ``create_weights``.
            input_size: Full input dimension; divided by the tensor-parallel
                world size to obtain the per-partition size.
            output_size: Output dimension.
            params_dtype: Parameter dtype; defaults to torch's default dtype.
            quant_config: Optional quantization config used to pick the
                quantization method.
            prefix: Layer name passed to ``get_quant_method``.
        """
        torch.nn.Module.__init__(self)

        self.num_experts = num_experts
        self.input_size = input_size
        self.output_size = output_size
        self.reduce_results = False
        self.params_dtype = (torch.get_default_dtype()
                             if params_dtype is None else params_dtype)

        # Select the method on a local first, then publish it on the instance.
        method: Optional[QuantizeMethodBase]
        if quant_config is not None:
            method = quant_config.get_quant_method(self, prefix=prefix)
        else:
            method = UnquantizedMoELinearMethod()
        # FIXME(ys): hack for moe
        if isinstance(method, UnquantizedLinearMethod):
            method = UnquantizedMoELinearMethod()
        self.quant_method = method

        world_size = get_tensor_model_parallel_world_size()
        self.tp_size = world_size
        self.input_size_per_partition = divide(input_size, world_size)

        assert self.quant_method is not None
        partition_outputs = [self.output_size]
        self.quant_method.create_weights(
            self,
            self.input_size_per_partition,
            partition_outputs,
            self.input_size,
            self.output_size,
            self.params_dtype,
            self.num_experts,
            weight_loader=self.weight_loader,
        )
        self.register_parameter("bias", None)

    def forward(  # type: ignore[override]
            self,
            input_,
            residual=None,
            expert_idx: int = -1,
            output: Optional[torch.Tensor] = None):
        """Apply the selected expert's projection.

        ``residual`` is accepted but not used by this method. Returns ``None``
        for the unquantized method, otherwise the result of
        ``quant_method.apply``.
        """
        if isinstance(self.quant_method, UnquantizedMoELinearMethod):
            # use optimus moe_ffn outside
            return None

        assert self.quant_method is not None
        return self.quant_method.apply(self,
                                       input_,
                                       None,
                                       expert_idx=expert_idx,
                                       output=output)
\ No newline at end of file
vllm/model_executor/layers/logits_processor.py
View file @
583034f1
...
...
@@ -36,7 +36,8 @@ class LogitsProcessor(nn.Module):
org_vocab_size
:
Optional
[
int
]
=
None
,
scale
:
float
=
1.0
,
logits_as_input
:
bool
=
False
,
soft_cap
:
Optional
[
float
]
=
None
)
->
None
:
soft_cap
:
Optional
[
float
]
=
None
,
need_fp32_logits
:
bool
=
False
)
->
None
:
"""
Args:
scale: A scaling factor to apply to the logits.
...
...
@@ -52,6 +53,7 @@ class LogitsProcessor(nn.Module):
self
.
soft_cap
=
soft_cap
# Whether to use gather or all-gather to gather the logits.
self
.
use_all_gather
=
current_platform
.
use_all_gather
()
self
.
need_fp32_logits
=
need_fp32_logits
def
forward
(
self
,
...
...
@@ -106,9 +108,13 @@ class LogitsProcessor(nn.Module):
embedding_bias
:
Optional
[
torch
.
Tensor
],
)
->
Optional
[
torch
.
Tensor
]:
# Get the logits for the next tokens.
logits
=
lm_head
.
quant_method
.
apply
(
lm_head
,
hidden_states
,
bias
=
embedding_bias
)
if
self
.
need_fp32_logits
:
logits
=
torch
.
ops
.
OptimusMoe
.
matmul_fp32
(
hidden_states
,
lm_head
.
weight
.
t
())
else
:
logits
=
lm_head
.
quant_method
.
apply
(
lm_head
,
hidden_states
,
bias
=
embedding_bias
)
# Gather logits for TP
logits
=
self
.
_gather_logits
(
logits
)
...
...
vllm/model_executor/layers/quantization/groupwise_quant.py
0 → 100644
View file @
583034f1
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.fused_moe.optimus_moe
import
(
# noqa: F401
optimus_moe_int8
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
direct_register_custom_op
class GroupwiseQuantConfig(QuantizationConfig):
    """Config class for Groupwise Quantization.

    Holds the bit width, group size and per-block overrides, and selects the
    quantization method for a given layer in ``get_quant_method``.
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        symmetric: bool = False,
        bf16_blocks: Optional[list] = None,
        int8_blocks: Optional[list] = None,
        weight_dtype: Optional[
            str] = None,  # Literal["int8", "fp8_e4m3", "fp6", "int4"] = "int8",
        # FIXME: hack for mixed precision quantization
        extra_quant_configs: Optional[dict] = None,
    ) -> None:
        """Store the settings and derive ``pack_factor``/``weight_dtype``.

        Args:
            weight_bits: Bit width of the quantized weights.
            group_size: Quantization group size.
            symmetric: Stored on the instance as ``symmetric``.
            bf16_blocks: Block indices that stay unquantized (see
                ``get_quant_method``); ``None`` becomes ``[]``.
            int8_blocks: Stored on the instance; ``None`` becomes ``[]``.
            weight_dtype: Explicit weight dtype name. When falsy it is derived
                from ``weight_bits`` (8 -> "int8", 6 -> "fp6", 4 -> "int4").
            extra_quant_configs: Per-layer overrides; iterated in
                ``get_quant_method`` as a sequence of mappings with the keys
                "block_indices", "target_modules", "weight_bit", "group_size"
                and "weight_type".

        Raises:
            ValueError: If ``weight_dtype`` is falsy and ``weight_bits`` is
                not 4, 6 or 8.
        """
        # NOTE(review): the base-class constructor is not invoked here --
        # confirm QuantizationConfig does not rely on its own __init__.
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.symmetric = symmetric
        self.bf16_blocks = bf16_blocks if bf16_blocks else []
        self.int8_blocks = int8_blocks if int8_blocks else []
        self.extra_quant_configs = extra_quant_configs
        # 4-bit values are packed eight to a 32-bit word; every other width
        # uses a pack factor of 1.
        if self.weight_bits == 4:
            self.pack_factor = 32 // self.weight_bits
        else:
            self.pack_factor = 1

        if not weight_dtype:
            if weight_bits == 8:
                self.weight_dtype = "int8"  # fp8e4m3 must explicitly set
            elif weight_bits == 6:
                self.weight_dtype = "fp6"
            elif weight_bits == 4:
                self.weight_dtype = "int4"
            else:
                raise ValueError(f"Unsupported weight bits: {weight_bits}")
        else:
            self.weight_dtype = weight_dtype

    def __repr__(self) -> str:
        """Return a short description with the bit width and group size."""
        return (f"GroupwiseQuantConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size})")

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this quantization scheme."""
        return "groupwise_quant"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes listed as supported."""
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum required GPU compute capability."""
        # The Groupwise Quant kernel only supports Ampere or newer GPUs.
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the file names searched for the quantization config."""
        return [
            "quant_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GroupwiseQuantConfig":
        """Build a config from a parsed config dictionary.

        Reads the bit width from "w_bit"/"bits", the group size from
        "q_group_size"/"group_size", and the optional "bf16_blocks",
        "int8_blocks", "weight_type" and "extra_quant_configs" entries.
        """
        # NOTE(review): ``symmetric`` is not read from ``config``, so it
        # always takes its default here -- confirm this is intended.
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        bf16_blocks = config.get("bf16_blocks", [])
        int8_blocks = config.get("int8_blocks", [])
        weight_dtype = config.get("weight_type")
        extra_quant_configs = config.get("extra_quant_configs", {})
        return cls(weight_bits,
                   group_size,
                   bf16_blocks=bf16_blocks,
                   int8_blocks=int8_blocks,
                   weight_dtype=weight_dtype,
                   extra_quant_configs=extra_quant_configs)

    def get_quant_method(self,
                         layer: torch.nn.Module,
                         prefix: str = "") -> Optional["QuantizeMethodBase"]:
        """Select the quantization method for ``layer``.

        Args:
            layer: The layer being constructed.
            prefix: Dotted layer name; for linear layers it must contain
                ``"layers.<index>."`` so the block index can be parsed.

        Returns:
            A linear or MoE method for supported layer types, ``None`` for
            every other layer type.

        Raises:
            ValueError: For a ``FusedMoE`` layer when ``weight_bits`` is not 8.
        """
        if isinstance(layer, LinearBase):
            # NOTE(review): raises IndexError/ValueError when ``prefix`` has
            # no "layers.<int>" component -- verify all callers provide one.
            block_index = int(prefix.split("layers.")[1].split(".")[0])
            layer_name = prefix.split(".")[-1]
            if block_index in self.bf16_blocks:
                return UnquantizedLinearMethod()
            if self.extra_quant_configs:
                for config in self.extra_quant_configs:
                    if block_index in config[
                            "block_indices"] and layer_name in config[
                                "target_modules"]:
                        # A matching override gets its own config; a falsy
                        # override group size falls back to this config's.
                        return GroupwiseQuantLinearMethod(
                            GroupwiseQuantConfig(
                                weight_bits=config["weight_bit"],
                                group_size=config["group_size"]
                                or self.group_size,
                                weight_dtype=config["weight_type"],
                                extra_quant_configs=self.extra_quant_configs))
                # no specific config for this layer means no quantization
                return UnquantizedLinearMethod()
            else:
                # Compatible with old config
                return GroupwiseQuantLinearMethod(self)
        elif isinstance(layer, FusedMoE):
            # For MoE layers, only support 8-bit quantization
            if self.weight_bits == 8:
                return GroupwiseInt8MoeMethod(self)
            else:
                raise ValueError(
                    f"Unsupported weight bits for MoE: {self.weight_bits}. Only 8-bit is supported."
                )
        return None
class
GroupwiseQuantLinearMethod
(
LinearMethodBase
):
"""Linear method for GroupwiseQuant.
Args:
quant_config: The groupwise_quant quantization config.
"""
def __init__(self, quant_config: GroupwiseQuantConfig) -> None:
    """Remember the config and the current CUDA device's capability.

    Args:
        quant_config: The groupwise quantization config for this method.
    """
    self.quant_config = quant_config
    # Result of torch.cuda.get_device_capability(), kept as ``sm``.
    device_capability = torch.cuda.get_device_capability()
    self.sm = device_capability
def create_weights(self,
                   layer: torch.nn.Module,
                   input_size_per_partition: int,
                   output_partition_sizes: List[int],
                   input_size: int,
                   output_size: int,
                   params_dtype: torch.dtype,
                   num_experts: Optional[int] = None,
                   **extra_weight_attrs):
    """Allocate and register the quantized parameters on ``layer``.

    Depending on ``quant_config.weight_bits`` this registers:

    * 4 bits: ``qweight`` (int32, output dim divided by ``pack_factor``),
      ``scales`` and ``zeros`` (both ``params_dtype``).
    * 8 bits: ``qweight`` (int8) and ``scales`` (``params_dtype``), both
      created on "cuda".
    * 6 bits: ``qweight`` (float16 zeros on "cpu", output-major) and
      ``scales`` (float16 on "cuda").

    When ``num_experts`` is truthy every tensor gets a leading expert
    dimension and the ``input_dim``/``output_dim`` attributes shift by one.

    Args:
        layer: Module that receives the parameters.
        input_size_per_partition: Input size; must be a multiple of the
            group size.
        output_partition_sizes: Output sizes of the merged projections;
            their sum must be a multiple of ``pack_factor``.
        input_size: Not used by this method.
        output_size: Not used by this method.
        params_dtype: Dtype of the 4-bit/8-bit scales (and 4-bit zeros).
        num_experts: Optional number of experts.
        **extra_weight_attrs: Extra attributes set on every parameter.

    Raises:
        NotImplementedError: If ``weight_bits`` is not 4, 6 or 8.
    """
    assert input_size_per_partition % self.quant_config.group_size == 0
    output_size_per_partition = sum(output_partition_sizes)
    assert output_size_per_partition % self.quant_config.pack_factor == 0

    # Make ``layer.num_experts`` resolvable (as None) on layers that expose
    # no attribute whose name contains "num_experts".
    layer_keys = dir(layer)
    has_num_experts = any("num_experts" in name for name in layer_keys)
    if not has_num_experts:
        layer.register_parameter("num_experts", None)

    if self.quant_config.weight_bits == 4:
        # These two checks repeat the ones at the top of the method.
        assert input_size_per_partition % self.quant_config.group_size == 0
        assert output_size_per_partition % self.quant_config.pack_factor == 0
        if num_experts:
            weight_shape = [
                num_experts, input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor
            ]
            scale_shape = [
                num_experts,
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition,
            ]
        else:
            weight_shape = [
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor
            ]
            scale_shape = [
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition,
            ]
        # No explicit device here, unlike the 8-bit branch below.
        qweight = Parameter(
            torch.empty(
                *weight_shape,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        scales = Parameter(
            torch.empty(
                *scale_shape,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        zeros = Parameter(
            torch.empty(
                *scale_shape,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        if num_experts:
            set_weight_attrs(
                qweight, {
                    "input_dim": 1,
                    "output_dim": 2,
                    "packed_dim": 2,
                    "pack_factor": self.quant_config.pack_factor,
                })
            set_weight_attrs(scales, {
                "input_dim": 1,
                "output_dim": 2,
            })
            set_weight_attrs(zeros, {
                "input_dim": 1,
                "output_dim": 2,
            })
        else:
            set_weight_attrs(
                qweight, {
                    "input_dim": 0,
                    "output_dim": 1,
                    "packed_dim": 1,
                    "pack_factor": self.quant_config.pack_factor,
                })
            set_weight_attrs(scales, {
                "input_dim": 0,
                "output_dim": 1,
            })
            set_weight_attrs(zeros, {
                "input_dim": 0,
                "output_dim": 1,
            })
        layer.register_parameter("qweight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("scales", scales)
        set_weight_attrs(scales, extra_weight_attrs)
        layer.register_parameter("zeros", zeros)
        set_weight_attrs(zeros, extra_weight_attrs)
    elif self.quant_config.weight_bits == 8:
        assert input_size_per_partition % self.quant_config.group_size == 0
        if num_experts:
            weight_shape = [
                num_experts, input_size_per_partition,
                output_size_per_partition
            ]
            scale_shape = [
                num_experts,
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition,
            ]
        else:
            weight_shape = [
                input_size_per_partition, output_size_per_partition
            ]
            scale_shape = [
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition,
            ]
        qweight = Parameter(
            torch.empty(
                *weight_shape,
                device="cuda",
                dtype=torch.int8,
            ),
            requires_grad=False,
        )
        scales = Parameter(
            torch.empty(
                *scale_shape,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        if num_experts:
            set_weight_attrs(qweight, {
                "input_dim": 1,
                "output_dim": 2,
            })
            set_weight_attrs(scales, {
                "input_dim": 1,
                "output_dim": 2,
            })
        else:
            set_weight_attrs(qweight, {
                "input_dim": 0,
                "output_dim": 1,
            })
            set_weight_attrs(scales, {
                "input_dim": 0,
                "output_dim": 1,
            })
        layer.register_parameter("qweight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("scales", scales)
        set_weight_attrs(scales, extra_weight_attrs)
    elif self.quant_config.weight_bits == 6:
        assert input_size_per_partition % self.quant_config.group_size == 0
        # Unlike the 4/8-bit branches the weight is laid out output-major
        # (output dim before input dim); the scales keep the input-major
        # layout.
        if num_experts:
            weight_shape = [
                num_experts,
                output_size_per_partition,
                input_size_per_partition,
            ]
            scale_shape = [
                num_experts,
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition,
            ]
        else:
            weight_shape = [
                output_size_per_partition, input_size_per_partition
            ]
            scale_shape = [
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition,
            ]
        qweight = Parameter(
            torch.zeros(
                *weight_shape,
                device="cpu",
                # hack for fp6 weight is stored in float16, to avoid cuda oom
                dtype=torch.float16,
            ),
            requires_grad=False,
        )
        scales = Parameter(
            torch.empty(
                *scale_shape,
                device="cuda",
                dtype=torch.float16,
            ),
            requires_grad=False,
        )
        if num_experts:
            set_weight_attrs(qweight, {
                "input_dim": 2,
                "output_dim": 1,
            })
            set_weight_attrs(scales, {
                "input_dim": 1,
                "output_dim": 2,
            })
        else:
            set_weight_attrs(qweight, {
                "input_dim": 1,
                "output_dim": 0,
            })
            set_weight_attrs(scales, {
                "input_dim": 0,
                "output_dim": 1,
            })
        layer.register_parameter("qweight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("scales", scales)
        set_weight_attrs(scales, extra_weight_attrs)
    else:
        raise NotImplementedError
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Convert the loaded tensors into the layout the Optimus kernels use.

    Does nothing for layers without a ``qweight`` attribute. Otherwise, per
    ``quant_config.weight_bits``:

    * 4 bits: runs ``GemmInt4GroupQuantWeight`` (per expert when
      ``layer.num_experts`` is truthy), overwrites ``qweight`` in place,
      registers the returned ``qscales`` and drops ``zeros``/``scales``.
    * 8 bits: overwrites ``qweight`` in place with the output of
      ``FpAIntBPreprocessWeightGPU`` applied to its contiguous transpose.
    * 6 bits: replaces ``qweight`` with a uint8 CUDA tensor produced by
      ``fp6_preprocess_weight`` whose last dimension is 6/8 of the original.

    Raises:
        NotImplementedError: If ``weight_bits`` is not 4, 6 or 8.
    """
    if not hasattr(layer, "qweight"):
        return
    if self.quant_config.weight_bits == 4:
        num_experts = layer.num_experts
        qweight = layer.qweight
        zeros = layer.zeros
        scales = layer.scales
        if num_experts:
            qscales_list = []
            for i in range(num_experts):
                # The third argument is the element-wise sum scales + zeros.
                qweight_processed, qscales = torch.ops.Optimus.GemmInt4GroupQuantWeight(
                    qweight[i], zeros[i], scales[i] + zeros[i],
                    self.quant_config.group_size)
                qweight[i].copy_(qweight_processed)
                qscales_list.append(qscales)
            qscales = Parameter(torch.stack(qscales_list),
                                requires_grad=False)
            layer.register_parameter("qscales", qscales)
        else:
            qweight_processed, qscales = torch.ops.Optimus.GemmInt4GroupQuantWeight(
                qweight, zeros, scales + zeros,
                self.quant_config.group_size)
            qweight.copy_(qweight_processed)
            qscales = Parameter(qscales, requires_grad=False)
            layer.register_parameter("qscales", qscales)
        # NOTE(review): nesting reconstructed from a flattened diff; the two
        # pops are assumed to run for both the expert and non-expert paths.
        layer._parameters.pop("zeros")
        layer._parameters.pop("scales")
    elif self.quant_config.weight_bits == 8:
        num_experts = layer.num_experts
        qweight = layer.qweight
        if num_experts:
            for i in range(num_experts):
                qweight[i].copy_(
                    torch.ops.Optimus.FpAIntBPreprocessWeightGPU(
                        qweight[i].t().contiguous(), torch.int8))
        else:
            qweight.copy_(
                torch.ops.Optimus.FpAIntBPreprocessWeightGPU(
                    qweight.t().contiguous(), torch.int8))
    elif self.quant_config.weight_bits == 6:
        num_experts = layer.num_experts
        qweight = layer.qweight
        # Unregister the old parameter; a new "qweight" is registered below.
        layer._parameters.pop("qweight")
        assert qweight.shape[-1] % 8 == 0
        # The packed buffer keeps 6/8 of the last dimension, as uint8.
        if num_experts:
            qweight_processed = torch.empty(qweight.shape[0],
                                            qweight.shape[1],
                                            qweight.shape[2] * 6 // 8,
                                            dtype=torch.uint8,
                                            device="cuda")
        else:
            qweight_processed = torch.empty(qweight.shape[0],
                                            qweight.shape[1] * 6 // 8,
                                            dtype=torch.uint8,
                                            device="cuda")
        if num_experts:
            for i in range(num_experts):
                # The op is fed a CPU tensor and its result is moved to CUDA.
                qweight_processed[
                    i] = torch.ops.Optimus.fp6_preprocess_weight(
                        qweight[i].cpu()).cuda()
            qweight_processed = Parameter(qweight_processed,
                                          requires_grad=False)
            layer.register_parameter("qweight", qweight_processed)
        else:
            # The buffer allocated above is not used on this path; the op's
            # result becomes the parameter directly.
            qweight_processed = Parameter(
                torch.ops.Optimus.fp6_preprocess_weight(
                    qweight.cpu()).cuda(),
                requires_grad=False)
            layer.register_parameter("qweight", qweight_processed)
    else:
        raise NotImplementedError
def apply(self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None,
          residual: Optional[torch.Tensor] = None,
          output: Optional[torch.Tensor] = None,
          expert_idx: Optional[int] = None) -> torch.Tensor:
    """Run the quantized GEMM that matches ``quant_config.weight_bits``.

    Args:
        layer: Module holding ``qweight`` plus ``qscales`` (4-bit) or
            ``scales`` (6/8-bit), and a ``num_experts`` attribute.
        x: Activations.
        bias: Optional bias added to the result.
        residual: Optional residual added to the result. In the 8-bit path
            it is also used as the kernel's output buffer.
        output: Optional preallocated output tensor.
        expert_idx: Expert to select; required when ``layer.num_experts`` is
            truthy.

    Returns:
        The GEMM result with bias and residual applied.

    Raises:
        NotImplementedError: If ``weight_bits`` is not 4, 6 or 8.
    """
    weight_bits = self.quant_config.weight_bits
    if weight_bits == 4:
        qweight = layer.qweight
        qscales = layer.qscales
        num_experts = layer.num_experts
        if num_experts:
            assert expert_idx is not None, "expert_idx is None"
            qweight = qweight[expert_idx]
            qscales = qscales[expert_idx]
        # The registered wrapper takes (x, qweight, qscales, bias, out) and
        # inserts the kernel's placeholder `None` itself. Passing an extra
        # positional `None` here would bind `out` twice.
        out = torch.ops.vllm.optimus_gemm_int4_group(x,
                                                     qweight,
                                                     qscales,
                                                     bias,
                                                     out=output)
        if residual is not None:
            out += residual
        return out
    elif weight_bits == 8:
        qweight = layer.qweight
        scales = layer.scales
        num_experts = layer.num_experts
        if num_experts:
            assert expert_idx is not None, "expert_idx is None"
            qweight = qweight[expert_idx]
            scales = scales[expert_idx]
        if residual is not None:
            # The residual is both the addend and the output buffer, so a
            # separate `output` cannot be honoured.
            assert output is None or output is residual
            out = torch.ops.vllm.optimus_fp_aintb_gemm(
                x,
                qweight,
                torch.int8,  # Placeholder for dtype argument
                scales,
                residual,
                out=residual)
            if bias is not None:
                out += bias
        else:
            out = torch.ops.vllm.optimus_fp_aintb_gemm(
                x,
                qweight,
                torch.int8,  # Placeholder for dtype argument
                scales,
                bias,
                out=output)
        return out
    elif weight_bits == 6:
        qweight = layer.qweight
        scales = layer.scales
        num_experts = layer.num_experts
        if num_experts:
            assert expert_idx is not None, "expert_idx is None"
            qweight = qweight[expert_idx]
            scales = scales[expert_idx]
        if x.dtype != torch.bfloat16:
            # The output buffer is forced to bfloat16 for non-bf16 input.
            if output is None:
                output = torch.empty(x.shape[0],
                                     qweight.shape[0],
                                     device=x.device,
                                     dtype=torch.bfloat16)
            else:
                output = output.to(torch.bfloat16)
        out = torch.ops.vllm.optimus_fp6_linear(
            x,
            qweight,
            scales,
            4,  # Placeholder for fp6_format_code
            out=output)
        if bias is not None:
            out += bias
        if residual is not None:
            out += residual
        return out
    else:
        raise NotImplementedError(
            f"Unsupported weight_bits for groupwise quantization: "
            f"{weight_bits}")
class GroupwiseInt8MoeMethod(FusedMoEMethodBase):
    """MoE method for Groupwise INT8 quantization.

    Args:
        quant_config: The groupwise quantization config.
    """

    def __init__(self, quant_config: GroupwiseQuantConfig):
        self.quant_config = quant_config
        assert self.quant_config.weight_bits == 8, \
            "Only 8-bit quantization is supported for GroupwiseInt8MoeMethod"

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register the int8 expert weights and their group-wise scales.

        Four parameters are registered on ``layer`` in this order:
        ``w13_weight``, ``w2_weight``, ``w13_weight_scale`` and
        ``w2_weight_scale``. Each is tagged with ``input_dim=1`` and
        ``output_dim=2`` and then with ``extra_weight_attrs``.
        """
        extra_weight_attrs.update({
            "is_transposed":
            True,
            "quant_method":
            FusedMoeWeightScaleSupported.GROUP.value,
        })

        group_size = self.quant_config.group_size
        fused_width = 2 * intermediate_size_per_partition
        # (name, shape, dtype): int8 weights first, then the scales.
        param_specs = (
            ("w13_weight", (num_experts, hidden_size, fused_width),
             torch.int8),
            ("w2_weight",
             (num_experts, intermediate_size_per_partition, hidden_size),
             torch.int8),
            ("w13_weight_scale",
             (num_experts, hidden_size // group_size, fused_width),
             params_dtype),
            ("w2_weight_scale",
             (num_experts, intermediate_size_per_partition // group_size,
              hidden_size), params_dtype),
        )
        for param_name, shape, dtype in param_specs:
            param = torch.nn.Parameter(torch.empty(*shape, dtype=dtype),
                                       requires_grad=False)
            layer.register_parameter(param_name, param)
            set_weight_attrs(param, {
                "input_dim": 1,
                "output_dim": 2,
            })
            set_weight_attrs(param, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Preprocess every expert's int8 weights in place.

        Mirrors the 8-bit case of the linear method: each expert slice is
        transposed, made contiguous, passed through
        ``Optimus.FpAIntBPreprocessWeightGPU`` and copied back.
        """
        preprocess = torch.ops.Optimus.FpAIntBPreprocessWeightGPU
        expert_count = layer.w13_weight.shape[0]
        for expert in range(expert_count):
            # w13 (gate and up combined) first, then w2 (down).
            for weight in (layer.w13_weight, layer.w2_weight):
                weight[expert].copy_(
                    preprocess(weight[expert].t().contiguous(), torch.int8))

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:
        """Run the fused int8 MoE kernel.

        Only ``x``, ``router_logits``, ``top_k``, ``renormalize``,
        ``global_num_experts`` and ``activation`` are forwarded to
        ``torch.ops.vllm.optimus_moe_int8``; the remaining routing options
        are accepted but not used here.
        """
        moe_kernel = torch.ops.vllm.optimus_moe_int8
        return moe_kernel(
            hidden_states=x,
            router_logits=router_logits,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            top_k=top_k,
            global_num_experts=global_num_experts,
            norm_expert_weight=renormalize,
            activation=activation,
        )
# Wrapper and Fake Functions for Optimus::GemmInt4Group
def optimus_gemm_int4_group(x: torch.Tensor, qweight: torch.Tensor,
                            qscales: torch.Tensor,
                            bias: Optional[torch.Tensor],
                            out: Optional[torch.Tensor]) -> torch.Tensor:
    """Forward to ``Optimus.GemmInt4Group``.

    The kernel's fifth positional argument is always passed as ``None``;
    ``out`` is forwarded by keyword.
    """
    kernel = torch.ops.Optimus.GemmInt4Group
    result = kernel(x, qweight, qscales, bias, None, out=out)
    return result
def optimus_gemm_int4_group_fake(x: torch.Tensor, qweight: torch.Tensor,
                                 qscales: torch.Tensor,
                                 bias: Optional[torch.Tensor],
                                 out: Optional[torch.Tensor]) -> torch.Tensor:
    """Fake (shape-only) implementation of ``optimus_gemm_int4_group``.

    Returns an uninitialised tensor with ``x``'s leading dimensions, a last
    dimension equal to ``qscales.shape[-1]``, and ``x``'s dtype and device.
    The result is the same whether or not ``out`` is given, so the former
    duplicated branch on ``out`` has been removed.
    """
    output_shape = list(x.shape[:-1]) + [qscales.shape[-1]]
    return torch.empty(output_shape, dtype=x.dtype, device=x.device)
# Wrapper and Fake Functions for Optimus::FpAIntBGemm
def optimus_fp_aintb_gemm(x: torch.Tensor, qweight: torch.Tensor,
                          dtype_arg: torch.dtype, scales: torch.Tensor,
                          bias_or_residual: Optional[torch.Tensor],
                          out: Optional[torch.Tensor]) -> torch.Tensor:
    """Forward to ``Optimus.FpAIntBGemm`` with the ``"identity"`` option.

    ``bias_or_residual`` is passed through unchanged as the kernel's fifth
    argument; ``out`` is forwarded by keyword.
    """
    kernel = torch.ops.Optimus.FpAIntBGemm
    result = kernel(x,
                    qweight,
                    dtype_arg,
                    scales,
                    bias_or_residual,
                    "identity",
                    out=out)
    return result
def optimus_fp_aintb_gemm_fake(x: torch.Tensor, qweight: torch.Tensor,
                               dtype_arg: torch.dtype, scales: torch.Tensor,
                               bias_or_residual: Optional[torch.Tensor],
                               out: Optional[torch.Tensor]) -> torch.Tensor:
    """Fake (shape-only) implementation of ``optimus_fp_aintb_gemm``.

    Returns an uninitialised tensor with ``x``'s leading dimensions, a last
    dimension equal to ``qweight.shape[-1]``, and ``x``'s dtype and device.
    The result is the same whether or not ``out`` is given, so the former
    duplicated branch on ``out`` has been removed.
    """
    output_shape = list(x.shape[:-1]) + [qweight.shape[-1]]
    return torch.empty(output_shape, dtype=x.dtype, device=x.device)
# Wrapper and Fake Functions for Optimus::fp6_linear
def optimus_fp6_linear(x: torch.Tensor, qweight: torch.Tensor,
                       scales: torch.Tensor, fp6_format_code: int,
                       out: Optional[torch.Tensor]) -> torch.Tensor:
    """Forward to ``Optimus.fp6_linear``; ``out`` is forwarded by keyword."""
    kernel = torch.ops.Optimus.fp6_linear
    result = kernel(x, qweight, scales, fp6_format_code, out=out)
    return result
def optimus_fp6_linear_fake(x: torch.Tensor, qweight: torch.Tensor,
                            scales: torch.Tensor, fp6_format_code: int,
                            out: Optional[torch.Tensor]) -> torch.Tensor:
    """Fake (shape-only) implementation of ``optimus_fp6_linear``.

    Returns an uninitialised bfloat16 tensor on ``x``'s device with ``x``'s
    leading dimensions and a last dimension equal to ``scales.shape[-1]``.

    The original chose ``x.dtype`` and then overrode it with bfloat16
    whenever ``x`` was not bfloat16, which is always bfloat16. It also had
    two identical branches on ``out``. Both are collapsed here without
    changing the result.
    """
    output_shape = list(x.shape[:-1]) + [scales.shape[-1]]
    return torch.empty(output_shape, dtype=torch.bfloat16, device=x.device)
# Register each Optimus wrapper with its fake implementation. Entries are
# (op name, real implementation, mutated argument names, fake implementation)
# and are registered in the listed order.
for _op_name, _op_func, _mutated_args, _fake_impl in (
    ("optimus_gemm_int4_group", optimus_gemm_int4_group, ["out"],
     optimus_gemm_int4_group_fake),
    ("optimus_fp_aintb_gemm", optimus_fp_aintb_gemm,
     ["out", "bias_or_residual"], optimus_fp_aintb_gemm_fake),
    ("optimus_fp6_linear", optimus_fp6_linear, ["out"],
     optimus_fp6_linear_fake),
):
    direct_register_custom_op(
        op_name=_op_name,
        op_func=_op_func,
        mutates_args=_mutated_args,
        fake_impl=_fake_impl,
    )
\ No newline at end of file
vllm/model_executor/layers/quantization/quant_utils.py
0 → 100644
View file @
583034f1
import
torch
@torch.jit.script
def cal_scale(amax, fp_max, scale):
    """Compute a power-of-two quantization scale and its reciprocal.

    The scale is ``2 ** floor(log2(fp_max / amax))``. Where ``amax`` is not
    strictly positive or not finite, the incoming ``scale`` is used in place
    of the power-of-two factor.

    Returns:
        ``(scale, scale_inv)`` with ``scale_inv = 1 / scale``.
    """
    margin = 0
    exponent = torch.floor(torch.log2(fp_max / amax)) - margin
    factor = torch.round(torch.pow(2, torch.abs(exponent)))
    # One combined mask replaces the two successive `where` fallbacks.
    usable = torch.logical_and(amax > 0.0, torch.isfinite(amax))
    factor = torch.where(usable, factor, scale)
    # A negative exponent means the factor must shrink values, not grow them.
    new_scale = torch.where(exponent < 0, 1 / factor, factor)
    return new_scale, torch.reciprocal(new_scale)
# Module-wide cache: decorated class -> its single instance.
instances = {}


def singleton(cls):
    """Class decorator that caches the first instance of ``cls``.

    The returned callable builds ``cls(*args, **kwargs)`` on the first call
    and returns that same object on every later call. Arguments passed on
    later calls are ignored.
    """

    def get_instance(*args, **kwargs):
        try:
            return instances[cls]
        except KeyError:
            created = cls(*args, **kwargs)
            instances[cls] = created
            return created

    return get_instance
def reset_singleton():
    """Drop all cached singleton instances by rebinding ``instances``."""
    global instances
    instances = dict()
@singleton
class QuantFp8:
    """FP8 (e4m3) per-tensor quantization helper.

    The class is wrapped by ``@singleton``, so only the first constructor
    call creates an instance and later ``device`` arguments are ignored. The
    cached ``fp_max`` / ``scale`` buffers can therefore sit on a different
    device than the tensor being quantized. The instance methods allocate
    their scratch tensors on the input tensor's device and move the cached
    constants there (a no-op when the devices already match).
    """

    def __init__(self, device):
        # 448.0 is the value used as the FP8 maximum throughout this file.
        self.fp_max = torch.tensor([448.0], device=device)
        self.device = device
        self.scale = torch.tensor([1.0], device=self.device)

    @staticmethod
    def quantize_v1(weight, bits):
        """Quantize ``weight`` to ``float8_e4m3fn`` using plain torch ops.

        Returns:
            ``(qweight, scale_inv)`` where ``scale_inv`` is the reciprocal of
            the power-of-two scale that was applied.

        Raises:
            ValueError: If ``bits`` is not 8.
        """
        if bits != 8:
            raise ValueError(f"Unsupported bit width: {bits}")

        amax = weight.abs().max()
        fp_max = torch.tensor([448.0]).to(weight.device)
        margin = 0
        scale = torch.tensor([1.0]).to(weight.device)
        exp = torch.floor(torch.log2(fp_max / amax)) - margin
        sf = torch.round(torch.pow(2, torch.abs(exp)))
        # Fall back to the neutral scale when amax is zero or not finite.
        sf = torch.where(amax > 0.0, sf, scale)
        sf = torch.where(torch.isfinite(amax), sf, scale)
        scale = torch.where(exp < 0, 1 / sf, sf)
        qweight = (weight.to(torch.float32) * scale).to(torch.float8_e4m3fn)
        scale = torch.reciprocal(scale)
        return qweight, scale

    def quantize(self, weight, bits, weight_scale, use_offline_input_scales):
        """Quantize ``weight`` with the ``OptimusFp8`` kernels.

        Args:
            weight: Tensor to quantize.
            bits: Must be 8.
            weight_scale: Precomputed scale. It is used only when it is not
                ``None`` and ``use_offline_input_scales`` is true.
            use_offline_input_scales: Whether to trust ``weight_scale``.

        Returns:
            ``(qweight, scale_inv)``.

        Raises:
            ValueError: If ``bits`` is not 8.
        """
        if bits != 8:
            raise ValueError(f"Unsupported bit width: {bits}")

        # Scratch tensors follow the input's device, not the device cached
        # by whichever call first created the singleton.
        device = weight.device
        amax = torch.empty(1, dtype=torch.float32, device=device)
        scale = torch.tensor([1.0], device=device)
        torch.ops.OptimusFp8.abs_max_nan_to_inf(weight, amax)
        if weight_scale is None or not use_offline_input_scales:
            scale, scale_inv = cal_scale(amax, self.fp_max.to(device), scale)
        else:
            scale, scale_inv = weight_scale, torch.reciprocal(weight_scale)
        qweight = torch.ops.OptimusFp8.quantize(weight, scale, None,
                                                torch.float8_e4m3fn)
        return qweight, scale_inv

    def get_quant_scale(self, tensor):
        """Return the dynamic power-of-two scale for ``tensor``."""
        device = tensor.device
        amax = torch.empty(1, dtype=torch.float32, device=device)
        torch.ops.OptimusFp8.abs_max_nan_to_inf(tensor, amax)
        scale, _ = cal_scale(amax, self.fp_max.to(device),
                             self.scale.to(device))
        return scale
def quantize(weight, bits, weight_scale=None, use_offline_input_scales=True):
    """Quantize ``weight`` through the shared ``QuantFp8`` helper.

    Returns whatever ``QuantFp8.quantize`` returns for the same arguments.
    """
    quantizer = QuantFp8(weight.device)
    result = quantizer.quantize(weight, bits, weight_scale,
                                use_offline_input_scales)
    return result
def dequant(weight, weight_scales):
    """Dequantize ``weight`` to bfloat16 via ``OptimusFp8.dequantize``."""
    dequantize_op = torch.ops.OptimusFp8.dequantize
    target_dtype = torch.bfloat16
    return dequantize_op(weight, weight_scales, target_dtype)
def experts_dequant(weights, weight_scales):
    """Dequantize stacked expert weights one expert at a time.

    Returns a bfloat16 tensor with the same shape and device as ``weights``.
    Slice ``i`` is ``dequant(weights[i], weight_scales[i])``.
    """
    num_experts = weights.shape[0]
    dequantized = torch.empty(*weights.shape,
                              device=weights.device,
                              dtype=torch.bfloat16)
    expert = 0
    while expert < num_experts:
        dequantized[expert] = dequant(weights[expert], weight_scales[expert])
        expert += 1
    return dequantized
def experts_quantize(weight, bits):
    """Quantize stacked expert weights to FP8 one expert at a time.

    Returns:
        ``(qweight_experts, scales)``. ``qweight_experts`` has ``weight``'s
        shape in ``float8_e4m3fn``. ``scales`` is a float32 vector with one
        entry per expert, holding the second value returned by ``quantize``.

    Raises:
        ValueError: If ``bits`` is not 8.
    """
    if bits != 8:
        raise ValueError(f"Unsupported bit width: {bits}")

    num_experts = weight.shape[0]
    qweight_experts = torch.empty(*weight.shape,
                                  dtype=torch.float8_e4m3fn,
                                  device=weight.device)
    scales = torch.empty(num_experts,
                         dtype=torch.float32,
                         device=weight.device)
    for expert in range(num_experts):
        quantized, expert_scale = quantize(weight[expert], bits)
        qweight_experts[expert] = quantized
        scales[expert] = expert_scale
    return qweight_experts, scales
def dynamic_fp8_pertensor_quantize(tensor):
    """Return the dynamic per-tensor FP8 scale for ``tensor``.

    Delegates to ``QuantFp8.get_quant_scale`` on the shared helper instance.
    """
    return QuantFp8(tensor.device).get_quant_scale(tensor)
\ No newline at end of file
vllm/model_executor/model_loader/weight_utils.py
View file @
583034f1
...
...
@@ -797,3 +797,10 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
# If there were no matches, return the untouched param name
return
name
def fp8_input_scales_loader(path: str):
    """Yield ``(name, slice)`` for every tensor in the safetensors file.

    Slices come from ``get_slice`` and are yielded while the file is still
    open, so the file stays open until the generator is exhausted or closed.
    """
    with safe_open(path, framework="pt") as handle:
        for tensor_name in handle.keys():  # noqa: SIM118
            yield tensor_name, handle.get_slice(tensor_name)
vllm/model_executor/models/mm_step1o.py
0 → 100644
View file @
583034f1
# SPDX-License-Identifier: Apache-2.0
import
os
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
functools
import
cached_property
from
itertools
import
product
from
math
import
ceil
,
sqrt
from
typing
import
Any
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
,
Union
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
PIL
import
Image
from
transformers
import
BatchFeature
,
PretrainedConfig
,
TensorType
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
)
from
vllm.model_executor.layers.linear
import
RowParallelLinear
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolerOutput
,
PoolingType
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.step_encoder
import
StepCLIPVisionModel
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
,
NestedTensors
)
from
vllm.multimodal.parse
import
ImageSize
,
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.step_image_preprocessor
import
StepPreprocessor
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
SentencePieceTokenizer
)
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.interfaces_base
import
VllmModelForPooling
from
.utils
import
(
flatten_bn
,
init_vllm_registered_model
,
is_pp_missing_parameter
,
maybe_prefix
,
merge_multimodal_embeddings
)
def _env_flag(name: str) -> bool:
    """Return True when env var ``name`` is "true" or "1" (case-insensitive)."""
    raw_value = os.getenv(name, "false")
    return raw_value.lower() in ["true", "1"]


# Whether images without an explicit "detail" hint are split into
# high-resolution patches.
DEFAULT_HIGH_RESOLUTION = _env_flag("VLLM_DEFAULT_HIGH_RESOLUTION")
# Passed to the vision tower as its `need_dp` argument.
VISION_MODEL_USE_DP = _env_flag("VLLM_VISION_MODEL_USE_DP")

print(f"DEFAULT_HIGH_RESOLUTION: {DEFAULT_HIGH_RESOLUTION}")
print(f"VISION_MODEL_USE_DP: {VISION_MODEL_USE_DP}")
class MMStep1oImagePixelInputs(TypedDict):
    """Image inputs supplied as raw pixel tensors."""

    # Discriminator for the MMStep1oImageInputs union.
    type: Literal["pixel_values"]
    # Base (whole-image) views.
    # Shape: (batch_size * num_images, num_channels, height, width)
    pixel_values: torch.Tensor
    # High-resolution patch crops; Optional because an image may have none.
    # Shape: (batch_size * num_patches, num_channels, patch_size, patch_size)
    patch_pixel_values: Optional[torch.Tensor]
    # Step1oProcessor.__call__ appends one patch count per image here.
    # NOTE(review): the original comment gave the shape as
    # (batch_size * num_patches) -- confirm against the model code.
    num_patches: List[int]
class MMStep1oImageEmbeddingInputs(TypedDict):
    """Image inputs supplied as precomputed embeddings."""

    # Discriminator for the MMStep1oImageInputs union.
    type: Literal["image_embeds"]
    # Shape: (batch_size * num_images * image_feature_size, hidden_size)
    image_embeds: torch.Tensor
# Image inputs arrive either as pixel tensors or as precomputed embeddings.
MMStep1oImageInputs = Union[MMStep1oImagePixelInputs,
                            MMStep1oImageEmbeddingInputs]

# (base image, high-resolution patches, per-patch newline flags or None).
# The third element is the boolean mask that ImagePatcher.__call__ returns
# and Step1oProcessor._get_patch_repl consumes, so it is typed list[bool]
# rather than the original list[int].
ImageWithPatches = Tuple[Image.Image, list[Image.Image], list[bool] | None]
class ImagePatcher:
    """Split an image into a base view plus square high-resolution patches.

    Calling an instance returns ``(image, patches, newline_mask)``. The mask
    is True for every patch that ends a row of the patch grid, except the
    very last patch.
    """

    def determine_window_size(self, long: int, short: int) -> int:
        """Return the square patch size for an image, or 0 for "no patches".

        Args:
            long: Longer image side in pixels.
            short: Shorter image side in pixels.
        """
        if long <= 728:
            # Small images are patched only when clearly elongated.
            return short if long / short > 1.5 else 0
        return min(short, 504) if long / short > 4 else 504

    def slide_window(
        self,
        width: int,
        height: int,
        sizes: list[tuple[int, int]],
        steps: list[tuple[int, int]],
        img_rate_thr: float = 0.6,
    ) -> Tuple[List[Tuple[int, int, int, int]], Tuple[int, int]]:
        """Enumerate sliding-window boxes covering a ``width`` x ``height`` area.

        Args:
            width: Area width.
            height: Area height.
            sizes: ``(window_w, window_h)`` pairs, one per pass.
            steps: ``(step_w, step_h)`` pairs, matched with ``sizes``.
            img_rate_thr: Only range-checked; it does not affect the result.

        Returns:
            ``(boxes, (x_num, y_num))``. Each box is ``(x, y, w, h)`` in
            row-major order. ``x_num`` / ``y_num`` are the window counts per
            axis for the last ``sizes`` / ``steps`` pair.
        """
        assert 1 >= img_rate_thr >= 0, \
            "The `img_rate_thr` should lie in 0~1"
        windows = []
        # Sliding windows.
        for size, step in zip(sizes, steps):
            size_w, size_h = size
            step_w, step_h = step

            x_num = 1 if width <= size_w else ceil((width - size_w) / step_w +
                                                   1)
            x_start = [step_w * i for i in range(x_num)]
            # Pull the last window back so it does not run past the edge.
            if len(x_start) > 1 and x_start[-1] + size_w > width:
                x_start[-1] = width - size_w

            y_num = 1 if height <= size_h else ceil((height - size_h) /
                                                    step_h + 1)
            y_start = [step_h * i for i in range(y_num)]
            if len(y_start) > 1 and y_start[-1] + size_h > height:
                y_start[-1] = height - size_h

            # product() yields (y, x) pairs row by row; swap the columns so
            # each row reads (x, y).
            start = np.array(list(product(y_start, x_start)), dtype=int)
            start[:, [0, 1]] = start[:, [1, 0]]
            windows.append(np.concatenate([start, start + size], axis=1))
        windows = np.concatenate(windows, axis=0)

        return [(int(box[0]), int(box[1]), int(box[2] - box[0]),
                 int(box[3] - box[1])) for box in windows], (x_num, y_num)

    def square_pad(self, img: Image.Image) -> Image.Image:
        """Pad ``img`` with zeros on the right/bottom to make it square."""
        w, h = img.size
        if w == h:
            return img
        size = max(w, h)
        padded = Image.new(img.mode, (size, size), 0)
        padded.paste(img, (0, 0))
        return padded

    def get_image_size_for_padding(self, img_width: int,
                                   img_height: int) -> Tuple[int, int]:
        """Return the size after optional square padding.

        Only images with a side shorter than 32 px and an aspect ratio
        beyond 4:1 (either way) are padded to a square.
        """
        ratio = img_width / img_height
        if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4):
            new_size = max(img_height, img_width)
            return new_size, new_size
        return img_width, img_height

    def get_image_size_for_preprocess(self, img_width: int,
                                      img_height: int) -> Tuple[int, int]:
        """Return the size after capping the longer side at 3024 px."""
        longest = max(img_height, img_width)
        if longest > 3024:
            scale_factor = 3024 / longest
            img_width = int(img_width * scale_factor)
            img_height = int(img_height * scale_factor)
        return img_width, img_height

    @staticmethod
    def _snap_to_window(length: int, window_size: int):
        """Round ``length`` to a multiple of ``window_size``.

        Lengths shorter than one window are returned unchanged. Otherwise
        the window count is rounded up only when the fractional part of
        ``length / window_size`` exceeds 0.2.
        """
        ratio = length / window_size
        if ratio < 1:
            return length
        fraction = ratio - length // window_size
        count = int(ratio) + 1 if fraction > 0.2 else int(ratio)
        return window_size * count

    def get_image_size_for_crop(self, img_width: int, img_height: int,
                                window_size: int):
        """Return the size the image is resized to before cropping patches."""
        width_new = self._snap_to_window(img_width, window_size)
        height_new = self._snap_to_window(img_height, window_size)
        return int(width_new), int(height_new)

    def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int):
        """Crop the ``tw`` x ``th`` box whose top-left is row ``i``, col ``j``."""
        target = img.crop((j, i, j + tw, i + th))
        return target

    def get_num_patches(self, img_width: int,
                        img_height: int) -> Tuple[int, int]:
        """Predict ``(num_patches, num_newlines)`` without touching pixels.

        Follows the same sizing steps as ``__call__``.
        """
        img_width, img_height = self.get_image_size_for_padding(
            img_width, img_height)
        img_width, img_height = self.get_image_size_for_preprocess(
            img_width, img_height)
        window_size = self.determine_window_size(max(img_height, img_width),
                                                 min(img_height, img_width))
        if window_size == 0:
            return 0, 0

        img_width, img_height = self.get_image_size_for_crop(
            img_width, img_height, window_size)
        center_list, (x_num, y_num) = self.slide_window(
            img_width, img_height, [(window_size, window_size)],
            [(window_size, window_size)])
        full_rows = (len(center_list) - 1) // x_num + 1
        # No newline is emitted after the final row.
        if len(center_list) > 0 and len(center_list) % x_num == 0:
            full_rows -= 1
        return len(center_list), full_rows

    def __call__(
        self, img: Image.Image
    ) -> Tuple[Image.Image, List[Image.Image], List[bool] | None]:
        """Return the preprocessed image, its patches and the newline mask.

        The mask is ``None`` when no patches are produced.
        """
        img_width, img_height = img.size
        new_img_width, new_img_height = self.get_image_size_for_padding(
            img_width, img_height)
        if new_img_width != img_width or new_img_height != img_height:
            img = self.square_pad(img)
            img_width, img_height = img.size

        new_img_width, new_img_height = self.get_image_size_for_preprocess(
            img_width, img_height)
        img = img.resize((new_img_width, new_img_height),
                         Image.Resampling.BILINEAR)
        window_size = self.determine_window_size(
            max(new_img_height, new_img_width),
            min(new_img_height, new_img_width))
        if window_size == 0:
            return img, [], None

        new_img_width, new_img_height = self.get_image_size_for_crop(
            new_img_width, new_img_height, window_size)
        # Compare against the image's current size. The pre-resize size is
        # stale once the 3024 px cap has been applied and would force a
        # redundant resample.
        if (new_img_width, new_img_height) != img.size:
            img_for_crop = img.resize((new_img_width, new_img_height),
                                      Image.Resampling.BILINEAR)
        else:
            img_for_crop = img

        patches = []
        newlines = []
        center_list, (x_num, y_num) = self.slide_window(
            new_img_width, new_img_height, [(window_size, window_size)],
            [(window_size, window_size)])
        for patch_id, center_lf_point in enumerate(center_list):
            x, y, patch_w, patch_h = center_lf_point
            big_patch = self.patch_crop(img_for_crop, y, x, patch_h, patch_w)
            patches.append(big_patch)
            # The last patch of each grid row is followed by a newline.
            if (patch_id + 1) % x_num == 0:
                newlines.append(patch_id)

        # No newline after the very last patch.
        if newlines and newlines[-1] == len(patches) - 1:
            newlines.pop()

        if len(patches) == 0:
            return img, patches, None
        newline_ids = set(newlines)
        return img, patches, [i in newline_ids for i in range(len(patches))]
class Step1oProcessor:
    """Builds model inputs (token ids and pixel tensors) for Step1o prompts.

    Each image expands to
    ``[patch blocks...]<im_start><im_patch>*169<im_end>``. Every patch block
    is ``<patch_start><im_patch>*81<patch_end>``, optionally followed by
    ``<patch_newline>``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: AnyTokenizer,
    ) -> None:
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        self.image_size = 728
        self.patch_size = 504
        self.image_preprocessor = StepPreprocessor(self.image_size,
                                                   "bilinear",
                                                   self.patch_size)

        # Placeholder tokens emitted per base image / per patch.
        self.num_image_feature_size = 169
        self.num_patch_feature_size = 81
        self.image_token = "<im_patch>"
        self.image_feature_placeholder = (self.image_token *
                                          self.num_image_feature_size)
        self.patch_feature_placeholder = (self.image_token *
                                          self.num_patch_feature_size)

        self.patcher = ImagePatcher()

    @property
    def image_token_id(self) -> int:
        """Vocabulary id of the image placeholder token."""
        return self.tokenizer.get_vocab()[self.image_token]

    @staticmethod
    def _use_high_resolution(detail: Optional[str]) -> bool:
        """Map a "detail" hint to whether patches should be produced.

        "high" and "low" are explicit; anything else uses the
        ``DEFAULT_HIGH_RESOLUTION`` module flag.
        """
        if detail == "high":
            return True
        if detail == "low":
            return False
        return DEFAULT_HIGH_RESOLUTION

    def get_num_image_tokens(self,
                             img_width: int,
                             img_height: int,
                             detail: str = "auto") -> int:
        """Return how many prompt tokens an image of this size expands to."""
        if self._use_high_resolution(detail):
            num_patches, num_newlines = self.patcher.get_num_patches(
                img_width, img_height)
        else:
            num_patches = 0
            num_newlines = 0

        # "+ 2" accounts for the start/end marker around each block.
        return num_patches * (
            self.num_patch_feature_size +
            2) + self.num_image_feature_size + 2 + num_newlines

    def _split_images(self,
                      images: list[Image.Image]) -> list[ImageWithPatches]:
        """Return ``(image, patches, newline_mask)`` for every input image.

        The per-image "detail" hint is read from ``img.info``.
        """
        result = []
        for img in images:
            if self._use_high_resolution(img.info.get("detail", None)):
                result.append(self.patcher(img))
            else:
                result.append((img, [], None))
        return result

    def _convert_images_to_pixel_values(
        self,
        images: list[Image.Image],
        is_patch: bool = False,
    ) -> list[torch.Tensor]:
        """Run the image preprocessor and collect its ``pixel_values``."""
        return [
            self.image_preprocessor.preprocess(
                img, is_patch=is_patch)["pixel_values"] for img in images
        ]

    def _get_patch_repl(
        self,
        num_patches: int,
        patch_newline_mask: list[bool] | None,
    ) -> Tuple[str, list[int]]:
        """Return the text and token ids that stand for ``num_patches`` patches.

        Token ids are looked up once per call instead of once per patch. The
        ``image_token_id`` property calls ``get_vocab()`` on every access.
        """
        if num_patches <= 0:
            return "", []

        assert len(patch_newline_mask) == num_patches
        # A None/empty mask means "no newlines", as in the per-patch check
        # of the original loop.
        mask = patch_newline_mask if patch_newline_mask else [False
                                                              ] * num_patches

        to_id = self.tokenizer.convert_tokens_to_ids
        patch_text = (f"<patch_start>{self.patch_feature_placeholder}"
                      f"<patch_end>")
        patch_ids = ([to_id("<patch_start>")] +
                     [self.image_token_id] * self.num_patch_feature_size +
                     [to_id("<patch_end>")])
        # Only resolve the newline token when it is actually needed.
        newline_id = to_id("<patch_newline>") if any(mask) else None

        text_parts = []
        token_ids = []
        for i in range(num_patches):
            text_parts.append(patch_text)
            token_ids.extend(patch_ids)
            if mask[i]:
                text_parts.append("<patch_newline>")
                token_ids.append(newline_id)
        return "".join(text_parts), token_ids

    def _get_image_repl(
        self,
        num_images: int,
    ) -> Tuple[str, list[int]]:
        """Return the text and token ids for ``num_images`` base images."""
        text = f"<im_start>{self.image_feature_placeholder}<im_end>"
        token_ids = [
            self.tokenizer.convert_tokens_to_ids("<im_start>")
        ] + [self.image_token_id] * self.num_image_feature_size + [
            self.tokenizer.convert_tokens_to_ids("<im_end>")
        ]
        return text * num_images, token_ids * num_images

    def _get_image_repl_features(
        self,
        num_images: int,
        num_patches: int,
        patch_new_line_idx: Optional[list[bool]],
    ) -> Tuple[str, list[int]]:
        """Return the full replacement: patch blocks, then the base image."""
        if num_patches > 0:
            patch_repl, patch_repl_ids = self._get_patch_repl(
                num_patches, patch_new_line_idx)
        else:
            patch_repl = ""
            patch_repl_ids = []
        image_repl, image_repl_ids = self._get_image_repl(num_images)
        return patch_repl + image_repl, patch_repl_ids + image_repl_ids

    def replace_placeholder(self, text: str, placeholder: str,
                            repls: list[str]) -> str:
        """Replace the i-th ``placeholder`` in ``text`` with ``repls[i]``.

        Raises:
            ValueError: If the placeholder count differs from ``len(repls)``.
        """
        parts = text.split(placeholder)

        if len(parts) - 1 != len(repls):
            raise ValueError(
                "The number of placeholders does not match the number of replacements."
            )

        result = [parts[0]]
        for i, repl in enumerate(repls):
            result.append(repl)
            result.append(parts[i + 1])

        return "".join(result)

    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = None,
        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        """Tokenize ``text`` and preprocess ``images`` into a BatchFeature.

        With a ``SentencePieceTokenizer`` and at least one image, the
        returned ``input_ids`` hold only the image replacement ids; the text
        is not tokenized.
        """
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

        if len(images) == 0:
            image_inputs = {}
            if isinstance(self.tokenizer, SentencePieceTokenizer):
                assert len(text) == 1
                # step-tokenizer does not support text input for special
                # tokens
                text_inputs = {
                    "input_ids":
                    torch.tensor([
                        self.tokenizer.encode(text[0],
                                              add_special_tokens=True)
                    ],
                                 dtype=torch.long)
                }
            else:
                text_inputs = self.tokenizer(text)
        else:
            splitted_images_data = self._split_images(images)
            pixel_values_lst = []
            patch_pixel_values_lst = []
            patch_newline_mask_lst = []
            image_repl_str_lst = []
            image_repl_ids_lst = []
            num_patches = []
            for raw_img, img_patches, patch_newline_mask in splitted_images_data:  # noqa: E501
                pixel_values_lst.extend(
                    self._convert_images_to_pixel_values([raw_img]))

                if len(img_patches) > 0:
                    patch_pixel_values_lst.extend(
                        self._convert_images_to_pixel_values(img_patches,
                                                             is_patch=True))
                num_patches.append(len(img_patches))

                image_repl_str, image_repl_ids = self._get_image_repl_features(
                    1, len(img_patches), patch_newline_mask)
                image_repl_str_lst.append(image_repl_str)
                image_repl_ids_lst.extend(image_repl_ids)

                if patch_newline_mask is not None:
                    patch_newline_mask_lst.extend(patch_newline_mask)

            image_inputs = {
                "pixel_values": torch.cat(pixel_values_lst),
                "num_patches": num_patches,
            }
            # Patch tensors are only present when some image produced
            # patches.
            if patch_pixel_values_lst:
                image_inputs["patch_pixel_values"] = torch.cat(
                    patch_pixel_values_lst)
            if patch_newline_mask_lst:
                image_inputs["patch_newline_mask"] = torch.tensor(
                    patch_newline_mask_lst, dtype=torch.bool)

            if isinstance(self.tokenizer, SentencePieceTokenizer):
                # step-tokenizer does not support text input for special
                # tokens
                text_inputs = {
                    "input_ids":
                    torch.tensor(image_repl_ids_lst,
                                 dtype=torch.long).unsqueeze(0)
                }
            else:
                text = [
                    self.replace_placeholder(t, self.image_token,
                                             image_repl_str_lst) for t in text
                ]
                text_inputs = self.tokenizer(text)

        return BatchFeature(
            {
                **text_inputs,
                **image_inputs,
            },
            tensor_type=return_tensors,
        )
class Step1oProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Step1o: processor factory and token budgets."""

    def get_hf_processor(self, **kwargs: object) -> Step1oProcessor:
        """Build a ``Step1oProcessor`` from the model config and tokenizer.

        Extra keyword arguments are accepted and ignored, so callers that
        forward processor kwargs (as ``_get_prompt_updates`` does) do not
        fail with a ``TypeError``.
        """
        return Step1oProcessor(
            self.get_hf_config(),
            self.get_tokenizer(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Images are supported with no fixed per-prompt limit."""
        return {"image": None}

    def get_max_image_tokens(self) -> int:
        """Return the token count for the reference "largest" image.

        NOTE(review): the reference size is 728x728, which produces no
        patches. If high-resolution patching is enabled, larger images
        expand to more tokens than this -- confirm this budget is intended.
        """
        hf_processor = self.get_hf_processor()
        largest = self.get_image_size_with_most_features()
        return hf_processor.get_num_image_tokens(largest.width,
                                                 largest.height)

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the per-item token budget for each modality."""
        return {"image": self.get_max_image_tokens()}

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image size used for dummy inputs and token budgets."""
        return ImageSize(728, 728)

    def get_num_mm_tokens(self, mm_data: MultiModalDataDict) -> int:
        """Return the total prompt tokens the images in ``mm_data`` expand to.

        Raises:
            ValueError: If ``mm_data`` holds anything other than the single
                key ``"image"``.
        """
        if len(mm_data) != 1 or "image" not in mm_data:
            raise ValueError(
                "mm_data could only contain one key 'image' for step1o")

        image_data = mm_data["image"]
        if not isinstance(image_data, (list, tuple)):
            image_data = [image_data]

        # Build the processor once instead of once per image.
        hf_processor = self.get_hf_processor()
        return sum(
            hf_processor.get_num_image_tokens(img.width,
                                              img.height,
                                              detail=img.info.get(
                                                  "detail", None))
            for img in image_data)
class Step1oDummyInputsBuilder(BaseDummyInputsBuilder[Step1oProcessingInfo]):
    """Produces placeholder prompts and images for Step1o."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<im_patch>`` placeholder per requested image."""
        return "<im_patch>" * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images at the size reported by the processing info."""
        width, height = self.info.get_image_size_with_most_features()
        image_count = mm_counts.get("image", 0)
        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=image_count)
        return {"image": dummy_images}
class Step1oMultiModalProcessor(BaseMultiModalProcessor[Step1oProcessingInfo]):
    """Multi-modal processor that expands image placeholders for Step1o."""

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Describe how each image placeholder token is expanded.

        Every placeholder is replaced by that image's patch blocks followed
        by its base-image block. Only positions holding the placeholder id
        are marked as embedding slots.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        placeholder_id = processor.image_token_id
        patches_per_image = out_mm_kwargs["num_patches"].tolist()

        def get_replacement_step1o(item_idx: int):
            item = out_mm_kwargs.get_item("image", item_idx)
            patch_count = patches_per_image[item_idx]
            if patch_count > 0:
                newline_mask = item["patch_newline_mask"].data.tolist()
                features = processor._get_image_repl_features(
                    1, patch_count, newline_mask)
            else:
                features = processor._get_image_repl_features(1, 0, None)
            # features is (text, token_ids); only the ids are needed here.
            replacement_ids = features[1]
            return PromptUpdateDetails.select_token_id(
                seq=replacement_ids,
                embed_token_id=placeholder_id,
            )

        image_update = PromptReplacement(
            modality="image",
            target=[placeholder_id],
            replacement=get_replacement_step1o,
        )
        return [image_update]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare how each processor output is split across images.

        ``pixel_values`` and ``num_patches`` have one entry per image. The
        patch tensors are flat and are sliced by ``num_patches``.
        """
        patch_counts = hf_inputs.get("num_patches", torch.empty(0))

        fields = {}
        fields["pixel_values"] = MultiModalFieldConfig.batched("image")
        fields["patch_pixel_values"] = MultiModalFieldConfig.flat_from_sizes(
            "image", patch_counts)
        fields["num_patches"] = MultiModalFieldConfig.batched("image")
        fields["patch_newline_mask"] = MultiModalFieldConfig.flat_from_sizes(
            "image", patch_counts)
        return fields
@
MULTIMODAL_REGISTRY
.
register_processor
(
Step1oMultiModalProcessor
,
info
=
Step1oProcessingInfo
,
dummy_inputs
=
Step1oDummyInputsBuilder
)
class
MMGPTStep1oForCausalLM
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
    """Build the vision tower, the projector stack and the language model.

    Args:
        vllm_config: Engine configuration; the HF config and multimodal
            config are read from ``vllm_config.model_config``.
        prefix: Parameter-name prefix for the submodules.
    """
    super().__init__()

    model_config = vllm_config.model_config
    config = model_config.hf_config
    self.config = config
    self.multimodal_config = model_config.multimodal_config

    vision_cfg = config.vision_tower_config
    vit_width = vision_cfg.hidden_size
    vit_out_width = vision_cfg.output_hidden_size

    self.vision_model = StepCLIPVisionModel(
        vision_cfg,
        None,
        prefix=maybe_prefix(prefix, "vision_model"),
        need_dp=VISION_MODEL_USE_DP,
    )

    # Two strided convolutions shrink the patch grid before projection.
    self.vit_downsampler = nn.Conv2d(
        vit_width,
        vit_out_width,
        kernel_size=2,
        stride=config.understand_projector_stride,
    )
    self.vit_downsampler2 = nn.Conv2d(
        vit_out_width,
        vit_out_width * 2,
        kernel_size=3,
        stride=2,
        padding=1,
    )
    # Projects the downsampled vision features to the LM hidden size.
    self.vit_large_projector = nn.Linear(
        vit_out_width * 2,
        config.hidden_size,
        bias=config.projector_bias,
    )

    self.language_model = init_vllm_registered_model(
        vllm_config=vllm_config,
        hf_config=config.text_config,
        prefix=maybe_prefix(prefix, "language_model"),
    )

    # Pipeline-parallel intermediate tensors come from the language model.
    self.make_empty_intermediate_tensors = (
        self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
    """Return the language model's sampler, or a fresh default sampler."""
    language_model = self.language_model
    return (language_model.sampler
            if hasattr(language_model, "sampler") else get_sampler())
@property
def device(self):
    """Device of the first parameter yielded by ``self.parameters()``."""
    first_param = next(self.parameters())
    return first_param.device
@property
def dtype(self):
    """Dtype of the first parameter yielded by ``self.parameters()``."""
    first_param = next(self.parameters())
    return first_param.dtype
def _parse_and_validate_image_input(
        self, **kwargs: object) -> Optional[MMStep1oImageInputs]:
    """Extract image inputs from the forward kwargs and normalise them.

    Returns:
        ``None`` when neither ``pixel_values`` nor ``image_embeds`` is given;
        a pixel-input record when ``pixel_values`` is present (it takes
        precedence); otherwise an embedding-input record.

    Raises:
        ValueError: If ``image_embeds`` has fewer than two dimensions.
    """
    pixel_values = kwargs.pop("pixel_values", None)
    patch_pixel_values = kwargs.pop("patch_pixel_values", None)
    num_patches = kwargs.pop("num_patches", None)
    image_embeds = kwargs.pop("image_embeds", None)

    if pixel_values is None:
        if image_embeds is None:
            return None

        # Pre-computed embeddings: collapse every leading dim into one.
        if image_embeds.dim() < 2:
            raise ValueError(
                f"Unexpected shape for image_embeds: {image_embeds.shape}")
        flat_embeds = image_embeds.view(-1, image_embeds.shape[-1])
        return MMStep1oImageEmbeddingInputs(
            type="image_embeds",
            image_embeds=flat_embeds.to(self.dtype).to(self.device),
        )

    # Raw pixels: merge the batch dims, keeping the trailing three dims.
    pixel_values = flatten_bn(pixel_values, concat=True)
    if pixel_values.dim() >= 3:
        pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])

    if patch_pixel_values is not None:
        patch_pixel_values = flatten_bn(patch_pixel_values, concat=True)
        patch_pixel_values = patch_pixel_values.view(
            -1, *patch_pixel_values.shape[-3:])
        # An empty patch tensor is treated the same as "no patches".
        if patch_pixel_values.shape[0] == 0:
            patch_pixel_values = None

    num_patches = flatten_bn(num_patches, concat=True).tolist()

    moved_pixels = pixel_values.to(self.dtype).to(self.device)
    if patch_pixel_values is not None:
        moved_patches = patch_pixel_values.to(self.dtype).to(self.device)
    else:
        moved_patches = None

    return MMStep1oImagePixelInputs(
        type="pixel_values",
        pixel_values=moved_pixels,
        patch_pixel_values=moved_patches,
        num_patches=num_patches,
    )
def _process_image_features(self,
                            image_features: torch.Tensor) -> torch.Tensor:
    """Downsample vision-tower features and project them for the LM.

    The second input dim is folded into a square grid, passed through the
    two downsampling convolutions, flattened again and fed to
    ``vit_large_projector``.
    """
    batch, num_tokens = image_features.shape[:2]
    side = int(sqrt(num_tokens))

    # Channels-first square grid for the convolutions.
    grid = image_features.permute(0, 2, 1).view(batch, -1, side, side)
    grid = self.vit_downsampler2(self.vit_downsampler(grid))

    channels = grid.size(1)
    # Back to one row per remaining grid cell, channels last.
    tokens = grid.view(batch, channels, -1).permute(0, 2, 1)
    return self.vit_large_projector(tokens)
def _get_vision_model_output(self,
                             input_tensor: torch.Tensor) -> torch.Tensor:
    """Run the vision tower, optionally sharding the batch across TP ranks.

    In both paths the first output of the vision model is taken and its
    first four positions along dim 1 are dropped.
    """
    use_dp = (VISION_MODEL_USE_DP
              and get_tensor_model_parallel_world_size() > 1)
    if not use_dp:
        return self.vision_model(input_tensor)[0][:, 4:]

    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    total = input_tensor.shape[0]

    # Ceil-divide so every rank processes an equally sized chunk.
    per_rank = (total + world_size - 1) // world_size
    lo = rank * per_rank
    hi = min(lo + per_rank, total)

    # Fixed-size local buffer; rows past the real data stay uninitialised
    # and are trimmed off after the gather.
    local_batch = torch.empty(per_rank,
                              *input_tensor.shape[1:],
                              dtype=input_tensor.dtype,
                              device=input_tensor.device)
    if hi > lo:
        local_batch[:hi - lo].copy_(input_tensor[lo:hi])

    local_out = self.vision_model(local_batch)[0][:, 4:]
    gathered = tensor_model_parallel_all_gather(local_out.contiguous(),
                                                dim=0)
    return gathered[:total]
def _process_image_input(
        self, image_input: MMStep1oImageInputs) -> tuple[torch.Tensor, ...]:
    """Compute the per-image embedding tensors for parsed image inputs.

    For each image the patch features (if any) are concatenated in front of
    the whole-image features. The result is built as a Python list with one
    2-D tensor per entry of ``num_patches``.
    """
    if image_input["type"] == "image_embeds":
        # NOTE(review): as reconstructed, this branch never assigns
        # `patch_image_features` or `num_patches`, which are read below, and
        # it sends pre-computed embeddings through `_process_image_features`.
        # That looks like it would fail for "image_embeds" inputs — confirm
        # against the original indentation and the intended behaviour.
        image_features = image_input["image_embeds"]
    else:
        image_features = self._get_vision_model_output(
            image_input["pixel_values"])
        # Patch crops are optional; skip the vision tower when absent.
        patch_image_features = self._get_vision_model_output(
            image_input["patch_pixel_values"]
        ) if image_input["patch_pixel_values"] is not None else None
        num_patches = image_input["num_patches"]

    image_features = self._process_image_features(image_features)
    patch_image_features = self._process_image_features(
        patch_image_features) if patch_image_features is not None else None

    merged_image_features = []
    # Running offset into the flattened patch features across all images.
    cur_patch_idx = 0
    for i, num_patch in enumerate(num_patches):
        cur_feature = []
        if num_patch > 0:
            patch_slice = patch_image_features[
                cur_patch_idx:cur_patch_idx + num_patch]
            # Flatten this image's patches into rows of the last feature dim.
            cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
        # The whole-image features always come after the patch features.
        cur_feature.append(image_features[i].view(
            -1, image_features.shape[-1]))
        cur_patch_idx += num_patch
        merged_image_features.append(
            torch.cat(cur_feature) if len(cur_feature) > 1 else cur_feature[0])
    return merged_image_features
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
    """Return the vision embeddings for the image kwargs, or ``None``.

    ``None`` is returned when the kwargs carry no image input.
    """
    parsed = self._parse_and_validate_image_input(**kwargs)
    if parsed is None:
        return None
    return self._process_image_input(parsed)
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    vision_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
    """Embed ``input_ids``, merging in vision embeddings when provided.

    Without vision embeddings every id goes through the language model's
    embedding table. Otherwise only the non-image ids are embedded and the
    image-token positions are filled by ``merge_multimodal_embeddings``.
    """
    embed_tokens = self.language_model.model.get_input_embeddings
    if vision_embeddings is None:
        return embed_tokens(input_ids)

    image_token_id = self.config.image_token_id
    text_mask = input_ids != image_token_id
    # Image placeholder ids are not sent through the embedding table.
    text_embeds = embed_tokens(input_ids[text_mask])

    merged = torch.empty(input_ids.shape[0],
                         text_embeds.shape[-1],
                         dtype=text_embeds.dtype,
                         device=text_embeds.device)
    merged[text_mask] = text_embeds
    return merge_multimodal_embeddings(input_ids, merged, vision_embeddings,
                                       self.config.image_token_id)
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run the language model on text plus optional image inputs.

    When ``intermediate_tensors`` is given, ``inputs_embeds`` is discarded.
    Otherwise, if no ``inputs_embeds`` were supplied, they are computed from
    ``input_ids`` and the image kwargs, and ``input_ids`` is cleared.
    """
    if intermediate_tensors is not None:
        inputs_embeds = None
    elif inputs_embeds is None:
        # Always feed the language model through `inputs_embeds` so the
        # computation graph is the same with and without images.
        inputs_embeds = self.get_input_embeddings(
            input_ids, self.get_multimodal_embeddings(**kwargs))
        input_ids = None

    return self.language_model(input_ids,
                               positions,
                               intermediate_tensors,
                               inputs_embeds=inputs_embeds)
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
    """Delegate logit computation to the wrapped language model."""
    logits = self.language_model.compute_logits(hidden_states,
                                                sampling_metadata)
    return logits
def sample(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    """Delegate token sampling to the wrapped language model."""
    sampled = self.language_model.sample(logits, sampling_metadata)
    return sampled
def maybe_remap_params(self, name):
    """Map a checkpoint parameter name onto this module's parameter tree.

    Only the leading prefix is rewritten:

    * ``model.``        -> ``language_model.model.``
    * ``lm_head``       -> ``language_model.lm_head``
    * ``vision_model.`` -> ``vision_model.vision_model.``

    Names with any other prefix are returned unchanged.

    Args:
        name: Parameter name as stored in the checkpoint.

    Returns:
        The remapped parameter name.
    """
    prefix_remap = (
        ("model.", "language_model.model."),
        ("lm_head", "language_model.lm_head"),
        ("vision_model.", "vision_model.vision_model."),
    )
    for old_prefix, new_prefix in prefix_remap:
        if name.startswith(old_prefix):
            # Rewrite the prefix only. `str.replace` would also rewrite later
            # occurrences (e.g. the "model." inside "sub_model.") and corrupt
            # the rest of the name.
            return new_prefix + name[len(old_prefix):]
    return name
def load_weights_1o(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load a Step1o-style checkpoint into this model.

    Separate q/k/v and gate/up checkpoint tensors are routed into the fused
    ``qkv_proj`` / ``gate_up_proj`` parameters; every other tensor is loaded
    into the parameter of the same (remapped) name.

    Raises:
        RuntimeError: If some model parameter received no checkpoint weight.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params = set()

    for name, loaded_weight in weights:
        name = self.maybe_remap_params(name)

        for fused_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name in name:
                # Shard of a fused parameter: load it into its slot.
                name = name.replace(shard_name, fused_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
        else:
            # Ordinary parameter: use its own loader when it defines one.
            param = params_dict[name]
            loader = getattr(param, "weight_loader", default_weight_loader)
            loader(param, loaded_weight)
            loaded_params.add(name)

    params_need_to_load = set(params_dict)
    if params_need_to_load != loaded_params:
        param_name_example = list(params_need_to_load - loaded_params)[0]
        raise RuntimeError(
            f"Some parameters like {param_name_example} are not in the checkpoint and will falsely use random initialization"
        )
def load_weights_3v(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load a Step3v-style checkpoint into this model.

    Each checkpoint tensor is matched, in this order, against:

    1. the dense ``gate_proj`` / ``up_proj`` shards of ``gate_up_proj``
       (skipped for names that belong to the fused-MoE expert mapping);
    2. the fused-MoE expert tensors, loaded one expert at a time along the
       first dimension of the checkpoint tensor;
    3. the ``q_proj`` / ``k_proj`` / ``v_proj`` tensors, copied into their
       slice of the fused ``qkv_proj`` parameter;
    4. any other parameter, loaded under its (remapped) name.

    Parameters that live on another pipeline-parallel rank are skipped.

    Raises:
        RuntimeError: If some model parameter received no checkpoint weight.
    """
    # The fused qkv output is laid out as [q | k | v] with relative widths
    # share_q_dim : head_dim : head_dim. Offsets are kept in those units and
    # scaled to the real parameter width at copy time.
    q_width = self.config.share_q_dim
    kv_width = self.config.head_dim
    qkv_width = q_width + kv_width * 2
    qkv_params_mapping = [
        # (param_name, shard_name, relative_start, relative_end)
        (".qkv_proj", ".q_proj", 0, q_width),
        (".qkv_proj", ".k_proj", q_width, q_width + kv_width),
        (".qkv_proj", ".v_proj", q_width + kv_width, qkv_width),
    ]
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params = set()

    # The expert mapping is fully determined by the branches below, so no
    # FusedMoE-generated default mapping is built first.
    if self.language_model.model.use_fused_moe:
        quant_config = self.language_model.model.vllm_config.quant_config
        if (quant_config is not None
                and quant_config.get_name() == "groupwise_quant"):
            expert_params_mapping = [
                (".moe.experts.w13_weight", ".moe.gate_proj.qweight", "w1"),
                (".moe.experts.w13_weight", ".moe.up_proj.qweight", "w3"),
                (".moe.experts.w2_weight", ".moe.down_proj.qweight", "w2"),
                (".moe.experts.w13_weight_scale", ".moe.gate_proj.scales",
                 "w1"),
                (".moe.experts.w13_weight_scale", ".moe.up_proj.scales",
                 "w3"),
                (".moe.experts.w2_weight_scale", ".moe.down_proj.scales",
                 "w2"),
            ]
        else:
            expert_params_mapping = [
                (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
                (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
                (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"),
            ]
    else:
        expert_params_mapping = []

    # Checkpoint names handled by the expert mapping must not be captured by
    # the dense gate/up stacking rule, which matches the same substrings.
    disable_moe_stacked_params = [data[1] for data in expert_params_mapping]

    for name, loaded_weight in weights:
        name = self.maybe_remap_params(name)
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            if any(disable_moe_stacked_param in name
                   for disable_moe_stacked_param in
                   disable_moe_stacked_params):
                continue
            name = name.replace(weight_name, param_name)
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
            break
        else:
            for mapping in expert_params_mapping:
                param_name, weight_name, shard_id = mapping
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                # Skip loading extra bias for GPTQ models.
                if ((name.endswith(".bias") or name.endswith("_bias"))
                        and name not in params_dict):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                # The first dim of the checkpoint tensor indexes the expert.
                for expert_id in range(loaded_weight.shape[0]):
                    loaded_weight_expert = loaded_weight[expert_id]
                    weight_loader(param,
                                  loaded_weight_expert,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                loaded_params.add(name)
                break
            else:
                for (param_name, weight_name, rel_start,
                     rel_end) in qkv_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    dim = param.shape[param.output_dim]
                    # Multiply before dividing and round to the nearest row.
                    # Truncating a float fraction times `dim` can land one
                    # row short when the product is a hair below the exact
                    # integer boundary.
                    begin_idx = round(rel_start * dim / qkv_width)
                    stop_idx = round(rel_end * dim / qkv_width)
                    param_slice = param.narrow(param.output_dim, begin_idx,
                                               stop_idx - begin_idx)
                    param_slice.copy_(loaded_weight)
                    loaded_params.add(name)
                    break
                else:
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
                    loaded_params.add(name)

    params_need_to_load = set(params_dict)
    if params_need_to_load != loaded_params:
        param_name_example = list(params_need_to_load - loaded_params)[0]
        raise RuntimeError(
            f"Some parameters like {param_name_example} are not in the checkpoint and will falsely use random initialization"
        )
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Dispatch checkpoint loading on ``self.config.model_type``.

    Args:
        weights: Iterable of ``(name, tensor)`` checkpoint entries.

    Raises:
        ValueError: If the model type has no matching loader.
    """
    model_type = self.config.model_type
    if model_type in ["step1o", "mmgpt_qwen2_v2"]:
        self.load_weights_1o(weights)
    elif model_type == "step3v":
        self.load_weights_3v(weights)
    else:
        # Report the same value the dispatch used. The message previously
        # read `model_type` from `self.multimodal_config`, a different
        # object from the one the checks above inspect.
        raise ValueError(f"Unsupported model type: {model_type}")
class MMGPTStep1oRewardModel(MMGPTStep1oForCausalLM, VllmModelForPooling):
    """Reward/classification variant of the Step1o multimodal model.

    Reuses the multimodal backbone of ``MMGPTStep1oForCausalLM``, removes the
    language-model head and sampler, and applies a ``score`` projection to
    the hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None, "Pooler config must be provided for classification models"

        # Drop the causal-LM-only attributes from the wrapped language model.
        for attr in ("sampler", "lm_head"):
            if hasattr(self.language_model, attr):
                delattr(self.language_model, attr)

        # Classification score head over the text hidden size.
        self.score = RowParallelLinear(
            config.text_config.hidden_size,
            config.num_labels,  # Assumes num_labels is in the main config
            quant_config=quant_config,
            input_is_parallel=False,
            bias=False,
            prefix=maybe_prefix(prefix, "score"))

        # Pooler defaults: ALL pooling, no normalization, no softmax.
        self._pooler = Pooler.from_config_with_defaults(
            pooler_config,
            pooling_type=PoolingType.ALL,
            normalize=False,
            softmax=False,
        )

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Apply the configured pooler to the model outputs."""
        return self._pooler(hidden_states, pooling_metadata)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Return the score-head output for the backbone hidden states."""
        # Get hidden states from the base model (without the LM head)
        hidden_states = super().forward(input_ids,
                                        positions,
                                        intermediate_tensors,
                                        inputs_embeds=inputs_embeds,
                                        **kwargs)
        # Apply the classification head
        logits, _ = self.score(hidden_states)
        return logits

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, ignoring any language-model head tensors.

        The LM head was removed in ``__init__``, so its checkpoint tensors
        have no destination parameter. Names are filtered here before the
        base loader remaps them, so both the raw ``lm_head.`` prefix (which
        the base loader would remap to ``language_model.lm_head.``) and the
        already-prefixed form are dropped.
        """

        def is_lm_head(name: str) -> bool:
            return (name.startswith("lm_head.")
                    or "language_model.lm_head." in name)

        weights_iterator = ((name, data) for name, data in weights
                            if not is_lm_head(name))
        # The base loader handles everything else; `score.*` names carry no
        # remapped prefix and are loaded under their own name.
        super().load_weights(weights_iterator)
\ No newline at end of file
vllm/model_executor/models/registry.py
View file @
583034f1
...
...
@@ -134,6 +134,11 @@ _TEXT_GENERATION_MODELS = {
# [Encoder-decoder]
"BartModel"
:
(
"bart"
,
"BartForConditionalGeneration"
),
"BartForConditionalGeneration"
:
(
"bart"
,
"BartForConditionalGeneration"
),
# step model
"Step1ForCausalLM"
:
(
"step1"
,
"Step1ForCausalLM"
),
"Step2ForCausalLM"
:
(
"step1"
,
"Step1ForCausalLM"
),
"Step1MoEForCausalLM"
:
(
"step1"
,
"Step1ForCausalLM"
),
"Step2MiniForCausalLM"
:
(
"step2_mini"
,
"Step2MiniForCausalLM"
),
}
_EMBEDDING_MODELS
=
{
...
...
@@ -174,6 +179,19 @@ _EMBEDDING_MODELS = {
# input and output. I am adding it here because it piggy-backs on embedding
# models for the time being.
"PrithviGeoSpatialMAE"
:
(
"prithvi_geospatial_mae"
,
"PrithviGeoSpatialMAE"
),
# step model
"Step1ForSequenceClassification"
:
(
"step1"
,
"Step1ForSequenceClassification"
),
"Step2ForClassification"
:
(
"step1"
,
"Step1ForSequenceClassification"
),
"Step2ForSequenceClassification"
:
(
"step2"
,
"Step2ForSequenceClassification"
),
"Step2MiniForClassification"
:
(
"step2_mini"
,
"Step2MiniForSequenceClassification"
),
"MMGPTQwen2RewardModel"
:
(
"mm_step1o"
,
"MMGPTStep1oRewardModel"
),
# Technically PrithviGeoSpatialMAE is a model that works on images, both in
# input and output. I am adding it here because it piggy-backs on embedding
# models for the time being.
"PrithviGeoSpatialMAE"
:
(
"prithvi_geospatial_mae"
,
"PrithviGeoSpatialMAE"
),
}
_CROSS_ENCODER_MODELS
=
{
...
...
@@ -251,6 +269,15 @@ _SPECULATIVE_DECODING_MODELS = {
"Glm4MoeMTPModel"
:
(
"glm4_moe_mtp"
,
"Glm4MoeMTP"
),
"MedusaModel"
:
(
"medusa"
,
"Medusa"
),
"MLPSpeculatorPreTrainedModel"
:
(
"mlp_speculator"
,
"MLPSpeculator"
),
# step model
"MMGPTStep1ForCausalLMV2"
:
(
"mm_step1p5c_1u"
,
"MMGPTStep1ForCausalLMV2"
),
"MMGPTStep1ForCausalLMV3"
:
(
"mm_step1p5c_1u"
,
"MMGPTStep1ForCausalLMV3"
),
"MMGPTStep1ForCausalLMV4"
:
(
"mm_step1o"
,
"MMGPTStep1oForCausalLM"
),
"MMGPTQwen2ForCausalLM"
:
(
"mm_step1p5c_1u"
,
"MMGPTStep1ForCausalLMV3"
),
"MMGPTQwen2ForCausalLMV2"
:
(
"mm_step1o"
,
"MMGPTStep1oForCausalLM"
),
"MMGPTStep3vForCausalLM"
:
(
"mm_step1o"
,
"MMGPTStep1oForCausalLM"
),
"Step1AudioForCausalLM"
:
(
"mm_step_audio"
,
"MMGPTStep1fForCausalLM"
),
"StepAudioForCausalLMV2"
:
(
"mm_step_audio"
,
"MMGPTStep1fForCausalLM"
),
}
_TRANSFORMERS_MODELS
=
{
...
...
vllm/model_executor/models/step1.py
0 → 100644
View file @
583034f1
# SPDX-License-Identifier: Apache-2.0
import
math
import
os
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
# from optimus import moe_expert_histogram as optimus_moe_expert_histogram
# from optimus import moe_gather as optimus_moe_gather
# from optimus import moe_scatter as optimus_moe_scatter
from
torch
import
nn
from
vllm.attention
import
Attention
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.activation
import
OptimusSiluAndMul
,
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
OptimusRMSNorm
,
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
MergedColumnParallelMoELinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
,
RowParallelMoELinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.quant_utils
import
(
dynamic_fp8_pertensor_quantize
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
fp8_input_scales_loader
)
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
.interfaces
import
SupportsPP
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
# Module-level tuning knobs.
# Sequence parallelism is hard-disabled for now; the intended environment
# lookup is preserved in the FIXME.
DISABLE_SEQUENCE_PARALLEL = True  # FIXME: os.getenv("DISABLE_SEQUENCE_PARALLEL", "0") == "1"
# Falls back to 512 when SEQUENCE_PARALLEL_THRESHOLD is unset or set to "0";
# otherwise the environment value is parsed as an int.
SEQUENCE_PARALLEL_THRESHOLD = 512 if os.getenv(
    "SEQUENCE_PARALLEL_THRESHOLD",
    "0") == "0" else int(os.getenv("SEQUENCE_PARALLEL_THRESHOLD"))
GEMM_COMM_OVERLAP_RATIO = 0.5
MLP_BATCH_SIZE = 8192
def _get_alibi_slopes(n_heads):
    """Return the ALiBi slope for each of ``n_heads`` attention heads.

    Slopes form a geometric sequence over the largest power of two that does
    not exceed ``n_heads``; any remaining heads take the odd powers of a
    second, coarser base.
    """
    pow2_heads = 2**math.floor(math.log2(n_heads))
    base = 2.0**(-8.0 / pow2_heads)
    main_slopes = np.power(base, np.arange(1, pow2_heads + 1))
    if pow2_heads >= n_heads:
        return main_slopes

    # Heads beyond the power of two: odd exponents of the coarser base.
    extra_base = 2.0**(-4.0 / pow2_heads)
    extra_exponents = np.arange(1, 1 + 2 * (n_heads - pow2_heads), 2)
    extra_slopes = np.power(extra_base, extra_exponents)
    return np.concatenate([main_slopes, extra_slopes])
def _get_ntk_alibi_slopes(max_pos_interp_ratio, slopes):
    """Rescale ALiBi slopes for position interpolation.

    Each slope is divided by ``sqrt(max_pos_interp_ratio ** w)``, where ``w``
    runs from 0 for the largest slope to 1 for the smallest, linearly in
    ``log2(slope)``.

    Args:
        max_pos_interp_ratio: Interpolation ratio. ``1.0`` disables scaling
            and returns ``slopes`` unchanged.
        slopes: Slopes as a NumPy array or any sequence of numbers.

    Returns:
        ``slopes`` itself when the ratio is ``1.0``; otherwise a NumPy array
        of the rescaled slopes.
    """
    if max_pos_interp_ratio == 1.0:
        return slopes

    # Accept plain sequences as well as arrays; lists have no .max()/.min().
    slopes = np.asarray(slopes)
    smax, smin = slopes.max(), slopes.min()
    D0 = np.log2(smax) - np.log2(smin)
    if D0 == 0:
        # All slopes are equal (e.g. a single head), so every slope is "the
        # largest" and gets weight 0, i.e. no rescaling. Returning early also
        # avoids the 0/0 division that would turn the result into NaN.
        return slopes
    W1 = (np.log2(smax) - np.log2(slopes)) / D0
    ratios = np.power(max_pos_interp_ratio, W1)
    return slopes / (ratios**0.5)
class
Step1MoEMLP
(
nn
.
Module
):
def __init__(self,
             num_experts: int,
             top_k: int,
             top_p: float,
             hidden_size: int,
             intermediate_size: int,
             hidden_act="",
             quant_config: Optional[QuantizationConfig] = None,
             norm_expert_weight=True,
             prefix: str = "",
             enable_cudagraph: bool = False):
    """Build the router gate and the per-expert gate/up and down projections.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """
    super().__init__()

    # The router is replicated and never quantized.
    self.gate = ReplicatedLinear(input_size=hidden_size,
                                 output_size=num_experts,
                                 bias=False,
                                 quant_config=None,
                                 prefix=f"{prefix}.gate")
    self.top_k = top_k
    self.top_p = top_p
    self.num_experts = num_experts

    tp_world_size = get_tensor_model_parallel_world_size()
    assert intermediate_size % tp_world_size == 0

    self.gate_up_proj = MergedColumnParallelMoELinear(
        num_experts,
        hidden_size,
        [intermediate_size, intermediate_size],
        quant_config=quant_config,
        prefix=f"{prefix}.gate_up_proj",
    )

    if hidden_act != "silu":
        raise ValueError(f"Unsupported activation: {hidden_act}. "
                         "Only silu is supported for now.")

    # The Optimus activation is only selected when the per-rank
    # intermediate width is a multiple of 64.
    per_rank_width = intermediate_size / tp_world_size
    self.act_fn = (OptimusSiluAndMul()
                   if per_rank_width % 64 == 0 else SiluAndMul())

    self.down_proj = RowParallelMoELinear(
        num_experts,
        intermediate_size,
        hidden_size,
        quant_config=quant_config,
        prefix=f"{prefix}.down_proj",
    )

    self.tp_rank = get_tensor_model_parallel_rank()
    self.quant_config = quant_config
    self.norm_expert_weight = norm_expert_weight
    self.enable_cudagraph = enable_cudagraph
    self.need_fp32_gate = False
def get_expert_output(self, inputs: torch.Tensor,
                      expert_token_cnt: torch.Tensor, token_nums: int):
    """Run the expert FFNs over tokens that are already grouped by expert.

    Three paths are selected from the quantization setup:

    * both projections carry a quant config, at most 1024 input rows and
      8-bit weights: fused quantized kernels;
    * both projections carry a quant config otherwise: a per-expert Python
      loop over contiguous row ranges;
    * no quantization: the fused unquantized ``moe_ffn`` kernel.

    NOTE(review): the semantics of the ``torch.ops.Optimus*`` kernels are not
    visible here; the comments describe only how they are called.
    """
    if self.quant_config and getattr(self.gate_up_proj.quant_method, "quant_config", None) and getattr(self.down_proj.quant_method, "quant_config", None):
        if inputs.size(0) <= 1024 and self.gate_up_proj.quant_method.quant_config.weight_bits == 8 and self.down_proj.quant_method.quant_config.weight_bits == 8:
            if self.enable_cudagraph:
                # Two grouped GEMM calls with the activation in between.
                tmp = torch.ops.Optimus.MoeFpAIntBGemm(
                    inputs, self.gate_up_proj.qweight,
                    self.gate_up_proj.qweight.dtype,
                    self.gate_up_proj.scales, expert_token_cnt, token_nums,
                    None)
                tmp = self.act_fn(tmp)
                tmp = torch.ops.Optimus.MoeFpAIntBGemm(
                    tmp, self.down_proj.qweight,
                    self.down_proj.qweight.dtype, self.down_proj.scales,
                    expert_token_cnt, token_nums, None)
                return tmp
            else:
                # Single fused kernel; `inputs` is passed as the `out` buffer.
                quant_output_ = torch.ops.OptimusMoe.moe_ffn_quant(
                    inputs,
                    self.gate_up_proj.qweight.dtype,
                    self.gate_up_proj.qweight,
                    self.gate_up_proj.scales,
                    self.down_proj.qweight,
                    self.down_proj.scales,
                    expert_token_cnt,
                    token_nums,
                    out=inputs)
                return quant_output_
        else:
            # The per-expert loop needs the counts as Python ints.
            expert_token_cnt = expert_token_cnt.to("cpu").tolist()
            start = 0
            end = 0
            if getattr(self.gate_up_proj.quant_method, "quant_config", None) and self.gate_up_proj.quant_method.quant_config.weight_bits == 6:
                # 6-bit weights: write results into a separate bfloat16 buffer.
                output = torch.empty_like(inputs,
                                          dtype=torch.bfloat16,
                                          device=inputs.device)
            else:
                # Otherwise results are written back into `inputs` itself.
                output = inputs
            for i in range(len(expert_token_cnt)):
                cur_token_cnt = expert_token_cnt[i]
                if (cur_token_cnt <= 0):
                    continue
                # Rows [start, end) are the tokens for expert i.
                end += cur_token_cnt
                tmp = self.gate_up_proj(inputs[start:end], expert_idx=i)
                tmp = self.act_fn(tmp)
                tmp = self.down_proj(tmp,
                                     expert_idx=i,
                                     output=output[start:end])
                start += cur_token_cnt
            return output
    else:
        moe_output = torch.ops.OptimusMoe.moe_ffn(inputs,
                                                  self.gate_up_proj.weight,
                                                  self.down_proj.weight,
                                                  expert_token_cnt,
                                                  token_nums)
        return moe_output
def forward(
    self,
    x,
    residual=None,
    layernorm=None,
    disable_allreduce=False,
    user_output=None,
):
    """Route tokens to experts, run the expert FFNs and combine the outputs.

    Args:
        x: Input hidden states; ``x.shape[0]`` is passed on as the token count.
        residual: Optional residual, added on tensor-parallel rank 0 only.
        layernorm: Optional norm applied to ``x`` before routing.
        disable_allreduce: Skip the tensor-parallel all-reduce when true.
        user_output: Optional tensor that receives a copy of the result.
    """
    quant_method = self.gate_up_proj.quant_method

    if layernorm is not None:
        if quant_method:
            fp16_out = (getattr(quant_method, "quant_config", None)
                        and quant_method.quant_config.weight_bits == 6)
        else:
            fp16_out = False
        x = layernorm(x, fp16_out=fp16_out)

    token_count = x.shape[0]

    # Router logits.
    if self.need_fp32_gate:
        six_bit = (getattr(quant_method, "quant_config", None)
                   and quant_method.quant_config.weight_bits == 6)
        gate_input = x.to(torch.bfloat16) if six_bit else x
        logits = torch.ops.OptimusMoe.matmul_fp32(gate_input,
                                                  self.gate.weight.t())
    else:
        logits = self.gate(x)[0]

    # Gating: per-token expert weights, per-expert token counts and the
    # scatter index that groups tokens by expert.
    if self.top_p < 1.0:
        top_k_index, expert_weight, scatter_index = (
            torch.ops.OptimusMoe.topk_topp_gating(logits, self.top_k,
                                                  self.top_p,
                                                  self.norm_expert_weight))
        expert_token_cnt = torch.ops.OptimusMoe.expert_histogram(
            top_k_index, self.num_experts)
        scatter_index = torch.ops.OptimusMoe.index_compute(
            top_k_index, expert_token_cnt, out=scatter_index)
    else:
        expert_weight, expert_token_cnt, scatter_index = (
            torch.ops.OptimusMoe.gating_histogram_index(
                logits, self.top_k, 1.0, self.norm_expert_weight))

    # Scatter, run the experts, then gather with the expert weights.
    grouped = torch.ops.OptimusMoe.scatter(x, scatter_index)
    expert_output = self.get_expert_output(grouped, expert_token_cnt,
                                           token_count)
    output = torch.ops.OptimusMoe.gather(expert_output, scatter_index,
                                         expert_weight)

    # Only rank 0 adds the residual.
    if self.tp_rank == 0 and residual is not None:
        output += residual
    if not disable_allreduce:
        output = tensor_model_parallel_all_reduce(output)
    if user_output is not None:
        user_output.copy_(output)
    return output
class Step1MLP(nn.Module):
    """Dense gated-SiLU feed-forward block."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        use_optimus_silu: bool = True,
        prefix: str = "",
    ) -> None:
        """Create the fused gate/up projection, down projection and activation.

        Raises:
            ValueError: If ``hidden_act`` is anything other than ``"silu"``.
        """
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = OptimusSiluAndMul() if use_optimus_silu else SiluAndMul()

    def forward(self,
                x,
                residual=None,
                layernorm=None,
                disable_allreduce=False,
                user_output=None):
        """Apply the optional norm, then gate/up, activation and down.

        ``residual``, ``user_output`` and ``disable_allreduce`` are forwarded
        to ``down_proj``, whose first output is returned.
        """
        if layernorm is not None:
            quant_method = self.gate_up_proj.quant_method
            if getattr(quant_method, "quant_config", None):
                fp16_out = quant_method.quant_config.weight_bits == 6
            else:
                fp16_out = False
            x = layernorm(x, fp16_out=fp16_out)

        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        out, _ = self.down_proj(activated,
                                residual,
                                output=user_output,
                                disable_allreduce=disable_allreduce)
        return out
class
Step1Attention
(
nn
.
Module
):
def __init__(
    self,
    hidden_size: int,
    num_heads: int,
    num_kv_heads: int,
    slopes: Optional[List[float]] = None,
    max_pos_interp_ratio: float = 1.0,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
):
    """Build the ALiBi attention block for one tensor-parallel rank.

    Args:
        hidden_size: Model hidden size.
        num_heads: Total number of query heads across all ranks.
        num_kv_heads: Total number of key/value heads across all ranks.
        slopes: Optional explicit ALiBi slopes, one per query head. When
            omitted, the default slopes for ``num_heads`` are generated.
        max_pos_interp_ratio: Interpolation ratio used to rescale the
            slopes; ``1.0`` leaves them unchanged.
        cache_config: KV-cache configuration passed to ``Attention``.
        quant_config: Quantization configuration for the projections.
        prefix: Parameter-name prefix for the submodules.
    """
    super().__init__()
    self.hidden_size = hidden_size
    tp_size = get_tensor_model_parallel_world_size()
    self.total_num_heads = num_heads
    assert self.total_num_heads % tp_size == 0
    self.num_heads = self.total_num_heads // tp_size
    self.total_num_kv_heads = num_kv_heads
    if self.total_num_kv_heads >= tp_size:
        # Number of KV heads is greater than TP size, so we partition
        # the KV heads across multiple tensor parallel GPUs.
        assert self.total_num_kv_heads % tp_size == 0
    else:
        # Number of KV heads is less than TP size, so we replicate
        # the KV heads across multiple tensor parallel GPUs.
        assert tp_size % self.total_num_kv_heads == 0
    self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
    self.head_dim = hidden_size // self.total_num_heads
    self.q_size = self.num_heads * self.head_dim
    self.kv_size = self.num_kv_heads * self.head_dim

    self.qkv_proj = QKVParallelLinear(
        hidden_size,
        self.head_dim,
        self.total_num_heads,
        self.total_num_kv_heads,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.qkv_proj",
    )
    self.o_proj = RowParallelLinear(
        self.total_num_heads * self.head_dim,
        hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.o_proj",
    )

    # Create the alibi slopes and slice them to this rank's heads.
    tp_rank = get_tensor_model_parallel_rank()
    head_start = tp_rank * self.num_heads
    head_end = (tp_rank + 1) * self.num_heads
    if slopes is None:
        alibi_slopes = _get_alibi_slopes(self.total_num_heads)
        alibi_slopes = _get_ntk_alibi_slopes(max_pos_interp_ratio,
                                             alibi_slopes)
        alibi_slopes = alibi_slopes[head_start:head_end]
    else:
        assert len(slopes) == self.total_num_heads
        # Convert to an array first: the scaling helper and `.tolist()`
        # need array methods that a plain list does not have.
        alibi_slopes = _get_ntk_alibi_slopes(max_pos_interp_ratio,
                                             np.asarray(slopes)).tolist()
        # Slice the rescaled slopes. Slicing the raw `slopes` here would
        # silently discard the interpolation scaling computed above.
        alibi_slopes = alibi_slopes[head_start:head_end]

    scaling = self.head_dim**-0.5
    self.attn = Attention(self.num_heads,
                          self.head_dim,
                          scaling,
                          self.num_kv_heads,
                          alibi_slopes,
                          alibi_sqrt=True,
                          cache_config=cache_config,
                          prefix=f"{prefix}.attn")
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
layernorm
:
Optional
[
nn
.
Module
]
=
None
,
disable_allreduce
=
False
,
user_output
=
None
)
->
torch
.
Tensor
:
del
positions
# Unused.
hidden_states
=
layernorm
(
hidden_states
,
fp16_out
=
self
.
qkv_proj
.
quant_method
.
quant_config
.
weight_bits
==
6
if
getattr
(
self
.
qkv_proj
.
quant_method
,
"quant_config"
,
None
)
else
False
)
if
layernorm
else
hidden_states
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
residual
,
_
=
self
.
o_proj
(
attn_output
,
residual
,
disable_allreduce
=
disable_allreduce
,
output
=
user_output
)
return
residual
class Step1DecoderLayer(nn.Module):
    """One decoder layer: self-attention followed by either a dense MLP or an MoE block."""

    def __init__(self,
                 model_config: ModelConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Create the attention, feed-forward and norm sub-modules.

        The layer index is parsed from ``prefix`` (the text between
        ``"layers."`` and the next ``"."``) and, together with the MoE
        settings on the HF config, decides whether this layer owns a
        ``moe`` or an ``mlp`` sub-module.
        """
        super().__init__()
        self.enable_cudagraph = not model_config.enforce_eager
        hf_cfg = model_config.hf_config
        self.hidden_size = hf_cfg.hidden_size

        self.self_attn = Step1Attention(
            hidden_size=self.hidden_size,
            num_heads=hf_cfg.num_attention_heads,
            num_kv_heads=hf_cfg.num_attention_groups,
            slopes=hf_cfg.alibi_slopes,
            max_pos_interp_ratio=hf_cfg.max_pos_interp_ratio,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )

        # "<...>layers.<idx>.<...>" -> idx
        after_layers = prefix.split("layers.")[1]
        layer_idx = int(after_layers.split(".")[0])
        self.use_moe = hf_cfg.use_moe and (
            (layer_idx + hf_cfg.moe_layer_offset) % hf_cfg.moe_every_n_layer
            == 0)

        if not self.use_moe:
            self.mlp = Step1MLP(
                hidden_size=self.hidden_size,
                intermediate_size=hf_cfg.intermediate_size,
                hidden_act="silu",
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.moe = Step1MoEMLP(
                hf_cfg.moe_num_experts,
                hf_cfg.moe_top_k,
                hf_cfg.moe_dynamic_exp_p,
                hidden_size=self.hidden_size,
                intermediate_size=hf_cfg.moe_intermediate_size,
                hidden_act="silu",
                quant_config=quant_config,
                prefix=f"{prefix}.moe",
                enable_cudagraph=self.enable_cudagraph,
            )

        # Both norms share the same class, picked from the hidden size.
        if hf_cfg.hidden_size % 64 == 0:
            norm_cls = OptimusRMSNorm
        else:
            norm_cls = RMSNorm
        self.input_layernorm = norm_cls(hf_cfg.hidden_size,
                                        eps=hf_cfg.rms_norm_eps)
        self.post_attention_layernorm = norm_cls(hf_cfg.hidden_size,
                                                 eps=hf_cfg.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention then the feed-forward block.

        ``hidden_states`` is passed to each sub-module both as the input and
        as the residual, and the matching layernorm is handed over for the
        sub-module to apply.
        """
        # Self Attention
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            residual=hidden_states,
            layernorm=self.input_layernorm,
        )

        # Fully Connected
        feed_forward = self.moe if self.use_moe else self.mlp
        return feed_forward(attn_out, attn_out, self.post_attention_layernorm)
# @support_torch_compile
class
Step1Model
(
nn
.
Module
):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
    """Build embeddings, decoder layers, final norm and runtime knobs.

    LoRA is not supported: the constructor asserts that
    ``vllm_config.lora_config`` is ``None``. Embedding and final norm are
    replaced by ``PPMissingLayer`` on pipeline ranks that do not own them.
    """
    super().__init__()
    model_config = vllm_config.model_config
    config = model_config.hf_config
    cache_config = vllm_config.cache_config
    quant_config = vllm_config.quant_config
    assert vllm_config.lora_config is None

    self.config = config
    self.allgather_dtype = None  # FIXME(ys): disable fp8 allgather
    self.padding_idx = config.pad_token_id
    self.vocab_size = config.vocab_size

    # The embedding lives on the first PP rank, and also on the last one
    # when the word embeddings are tied.
    if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                        and get_pp_group().is_last_rank):
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            quant_config=quant_config,
        )
    else:
        self.embed_tokens = PPMissingLayer()

    def build_layer(prefix):
        # Factory handed to make_layers; it receives each layer's prefix.
        return Step1DecoderLayer(model_config=vllm_config.model_config,
                                 cache_config=cache_config,
                                 quant_config=quant_config,
                                 prefix=prefix)

    self.start_layer, self.end_layer, self.layers = make_layers(
        config.num_hidden_layers,
        build_layer,
        prefix=f"{prefix}.layers",
    )

    if config.hidden_size % 64 == 0:
        norm_cls = OptimusRMSNorm
    else:
        norm_cls = RMSNorm
    if get_pp_group().is_last_rank:
        self.norm = norm_cls(config.hidden_size, eps=config.rms_norm_eps)
    else:
        self.norm = PPMissingLayer()

    self.make_empty_intermediate_tensors = (
        make_empty_intermediate_tensors_factory(["hidden_states"],
                                                config.hidden_size))

    # Runtime knobs taken from module-level constants.
    if DISABLE_SEQUENCE_PARALLEL:
        self.sequence_parallel_threshold = None
    else:
        self.sequence_parallel_threshold = SEQUENCE_PARALLEL_THRESHOLD
    self.overlap_ratio = GEMM_COMM_OVERLAP_RATIO
    self.mlp_batch_size = MLP_BATCH_SIZE
    self.tp_size = get_tensor_model_parallel_world_size()
    self.use_moe = config.use_moe
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the result of calling ``self.embed_tokens`` on ``input_ids``."""
    embeddings = self.embed_tokens(input_ids)
    return embeddings
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Dispatch to the pipeline-parallel, MoE or dense forward path.

    With more than one pipeline-parallel rank everything is delegated to
    ``forward_pp``. Otherwise ``inputs_embeds`` (or the embedding of
    ``input_ids`` when it is ``None``) is fed to
    ``forward_hidden_states_moe`` or ``forward_hidden_states`` depending on
    ``self.use_moe``.
    """
    if get_pp_group().world_size > 1:
        return self.forward_pp(input_ids, positions, intermediate_tensors,
                               inputs_embeds)

    if inputs_embeds is None:
        hidden_states = self.get_input_embeddings(input_ids)
    else:
        hidden_states = inputs_embeds

    run_layers = (self.forward_hidden_states_moe
                  if self.use_moe else self.forward_hidden_states)
    return run_layers(hidden_states, positions)
def forward_pp(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run this pipeline rank's slice of layers.

    The first rank starts from ``inputs_embeds`` (or the embedding of
    ``input_ids``); other ranks start from
    ``intermediate_tensors["hidden_states"]``. Ranks other than the last
    return an ``IntermediateTensors``; the last rank applies ``self.norm``
    and returns the tensor.
    """
    if not get_pp_group().is_first_rank:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
    elif inputs_embeds is None:
        hidden_states = self.get_input_embeddings(input_ids)
    else:
        hidden_states = inputs_embeds

    # Only the layers in [start_layer, end_layer) are run on this rank.
    for layer_idx in range(self.start_layer, self.end_layer):
        hidden_states = self.layers[layer_idx](positions, hidden_states)

    if get_pp_group().is_last_rank:
        return self.norm(hidden_states)
    return IntermediateTensors({"hidden_states": hidden_states})
def forward_hidden_states_moe(
    self,
    hidden_states: torch.Tensor,
    positions: torch.Tensor,
) -> torch.Tensor:
    """Run the decoder stack for MoE models, choosing a parallel strategy.

    When tensor parallelism is active and the token count exceeds
    ``self.sequence_parallel_threshold``, the work is delegated to
    ``forward_overlap_v2`` (``tp_size > 8``) or ``forward_split_ffn``.
    Otherwise every layer is applied in order, followed by ``self.norm``.
    """
    num_tokens = hidden_states.shape[0]
    use_parallel_path = (self.tp_size > 1
                         and self.sequence_parallel_threshold is not None
                         and self.sequence_parallel_threshold < num_tokens)

    if not use_parallel_path:
        for layer in self.layers:
            hidden_states = layer(positions, hidden_states)
        return self.norm(hidden_states)

    if self.tp_size <= 8:
        # TODO(xwx): overlap mlp layer of MoE model
        return self.forward_split_ffn(hidden_states, positions)
    return self.forward_overlap_v2(hidden_states, positions)
def forward_overlap_v2(
    self,
    hidden_states: torch.Tensor,
    positions: torch.Tensor,
) -> torch.Tensor:
    """Run all decoder layers with the token rows split across two CUDA streams.

    Rows ``[:dim_0]`` are processed on the current stream and rows
    ``[dim_0:]`` inside ``torch.cuda.stream(s1)``. Per layer, each slice does
    input layernorm and ``qkv_proj``; attention then runs once over all
    (unpadded) rows; afterwards each slice does ``o_proj`` and the FFN, each
    followed by an explicit ``all_reduce`` over the tensor-parallel group
    (the projections themselves are called with ``disable_allreduce=True``).

    Args:
        hidden_states: Activations indexed as (rows, features); the first
            dimension is used as the token count ``S``.
        positions: Unused; deleted immediately.

    Returns:
        ``hidden_states`` after ``self.norm``, with any padded rows removed.
    """
    del positions
    S = hidden_states.shape[0]
    tp_size = get_tensor_model_parallel_world_size()
    rank = get_tensor_model_parallel_rank()
    if S % tp_size != 0:
        # pad to multiple of tp_size with 0
        pad_len = tp_size - S % tp_size
        hidden_states = torch.cat([
            hidden_states,
            torch.zeros(pad_len,
                        hidden_states.shape[1],
                        dtype=hidden_states.dtype,
                        device=hidden_states.device)
        ])
        S = hidden_states.shape[0]
    else:
        pad_len = 0
    assert S % tp_size == 0
    hidden_states = hidden_states.view(S, -1)
    # Number of leading rows handled on the current stream.
    dim_0 = int((S * self.overlap_ratio + tp_size - 1) // tp_size * tp_size)  # round up to multiple of tp_size
    # Scratch buffer sized for S rows of the per-rank intermediate width;
    # the QKV output below is carved out of the same storage.
    buffer = torch.empty(S * int(self.config.intermediate_size / tp_size),
                         dtype=hidden_states.dtype,
                         device=hidden_states.device)
    # NOTE(review): mlp_buffer is never read in this method; it is only
    # deleted at the end.
    mlp_buffer = buffer.view(S, -1)
    if tp_size > 8:
        assert tp_size % 8 == 0, f"tp_size should be an integer multiple of 8,but cur tp_size={tp_size}"
        kv_repeat = tp_size // 8
    else:
        kv_repeat = 1
    # Per-rank QKV width: (q heads + 2 * kv groups * kv_repeat) // tp_size
    # heads, each of size hidden_size // num_attention_heads.
    qkv_buffer = buffer[:S * int(
        (self.config.num_attention_heads +
         self.config.num_attention_groups * kv_repeat * 2) // tp_size *
        (self.config.hidden_size // self.config.num_attention_heads))].view(S, -1)
    chunk_size = S // tp_size
    # Holds this rank's share of the residual rows for both slices.
    residual = torch.empty(chunk_size,
                           self.config.hidden_size,
                           dtype=hidden_states.dtype,
                           device=hidden_states.device)
    chunk_size_0 = dim_0 // tp_size
    chunk_size_1 = chunk_size - chunk_size_0
    residual_intersect_0 = residual[:chunk_size_0]
    residual_intersect_1 = residual[chunk_size_0:]
    # Views of the rows owned by this rank inside each slice.
    hidden_states_intersect_0 = hidden_states[rank * chunk_size_0:(rank + 1) * chunk_size_0]
    hidden_states_intersect_1 = hidden_states[dim_0 + rank * chunk_size_1:dim_0 + (rank + 1) * chunk_size_1]
    s1 = torch.cuda.Stream(device=residual.device)
    for i in range(len(self.layers)):
        layer = self.layers[i]
        ffn = layer.moe if layer.use_moe else layer.mlp
        # Attention Forward
        # Save this rank's rows before they are normalised in place.
        residual_intersect_0.copy_(hidden_states_intersect_0)
        layer.input_layernorm(hidden_states[:dim_0], output=hidden_states[:dim_0])
        layer.self_attn.qkv_proj(hidden_states[:dim_0], output=qkv_buffer[:dim_0])
        # NOTE(review): s1 never calls wait_stream on the current stream in
        # this method; confirm the ordering of work between the two streams
        # is guaranteed elsewhere.
        with torch.cuda.stream(s1):
            residual_intersect_1.copy_(hidden_states_intersect_1)
            layer.input_layernorm(hidden_states[dim_0:], output=hidden_states[dim_0:])
            layer.self_attn.qkv_proj(hidden_states[dim_0:], output=qkv_buffer[dim_0:])
        # Attention needs the QKV rows of both slices.
        torch.cuda.current_stream().wait_stream(s1)
        q, k, v = qkv_buffer.view(S, -1).split([
            layer.self_attn.q_size, layer.self_attn.kv_size,
            layer.self_attn.kv_size
        ],
                                               dim=-1)
        if pad_len > 0:
            # Attention only sees the real rows; zero rows are appended to
            # the result so it lines up with the padded hidden_states.
            attn_output = layer.self_attn.attn(q[:-pad_len], k[:-pad_len], v[:-pad_len])
            attn_output = torch.cat([
                attn_output,
                torch.zeros(pad_len,
                            attn_output.shape[1],
                            dtype=attn_output.dtype,
                            device=attn_output.device)
            ],
                                    dim=0)
        else:
            attn_output = layer.self_attn.attn(q, k, v)
        hidden_states = hidden_states.view(S, -1)
        layer.self_attn.o_proj(attn_output[:dim_0],
                               output=hidden_states[:dim_0],
                               disable_allreduce=True)
        # The residual is added only to the rows this rank owns, right
        # before the all_reduce over the whole slice.
        hidden_states_intersect_0.add_(residual_intersect_0)
        torch.distributed.all_reduce(
            hidden_states[:dim_0],
            group=get_tensor_model_parallel_group().device_group)
        with torch.cuda.stream(s1):
            layer.self_attn.o_proj(attn_output[dim_0:],
                                   output=hidden_states[dim_0:],
                                   disable_allreduce=True)
            hidden_states_intersect_1.add_(residual_intersect_1)
            torch.distributed.all_reduce(
                hidden_states[dim_0:],
                group=get_tensor_model_parallel_group().device_group)
        del attn_output
        # Feed-forward on the first slice (current stream), in row batches
        # of at most self.mlp_batch_size, written back in place.
        residual_intersect_0.copy_(hidden_states_intersect_0)
        layer.post_attention_layernorm(hidden_states[:dim_0], output=hidden_states[:dim_0])
        num_batch_size = (dim_0 + self.mlp_batch_size - 1) // self.mlp_batch_size
        for idx in range(num_batch_size):
            start = idx * self.mlp_batch_size
            end = min((idx + 1) * self.mlp_batch_size, dim_0)
            ffn(hidden_states[start:end],
                disable_allreduce=True,
                user_output=hidden_states[start:end])
        hidden_states_intersect_0.add_(residual_intersect_0)
        torch.distributed.all_reduce(
            hidden_states[:dim_0],
            group=get_tensor_model_parallel_group().device_group)
        # Same feed-forward sequence for the second slice on s1.
        with torch.cuda.stream(s1):
            residual_intersect_1.copy_(hidden_states_intersect_1)
            layer.post_attention_layernorm(hidden_states[dim_0:], output=hidden_states[dim_0:])
            num_batch_size = (S - dim_0 + self.mlp_batch_size - 1) // self.mlp_batch_size
            for idx in range(num_batch_size):
                start = dim_0 + idx * self.mlp_batch_size
                end = dim_0 + min(
                    (idx + 1) * self.mlp_batch_size, S - dim_0)
                ffn(hidden_states[start:end],
                    disable_allreduce=True,
                    user_output=hidden_states[start:end])
            hidden_states_intersect_1.add_(residual_intersect_1)
            torch.distributed.all_reduce(
                hidden_states[dim_0:],
                group=get_tensor_model_parallel_group().device_group)
        # NOTE(review): the nesting of this wait was reconstructed from
        # flattened source; confirm it belongs inside the layer loop rather
        # than once after it.
        torch.cuda.current_stream().wait_stream(s1)
    del buffer, mlp_buffer, qkv_buffer, residual
    self.norm(hidden_states, output=hidden_states)
    # Drop the rows that were added as padding.
    return hidden_states[:S - pad_len]
def forward_split_ffn(
    self,
    hidden_states: torch.Tensor,
    positions: torch.Tensor,
) -> torch.Tensor:
    """Run all decoder layers in place, feeding the FFN in row batches.

    Every projection is invoked with ``disable_allreduce=True``; the
    tensor-parallel ``all_reduce`` is issued here instead. Before each
    reduction, the saved residual is added only to this rank's column chunk
    of ``hidden_states`` (``hidden_size // tp_size`` columns starting at
    ``rank * chunk``).

    Args:
        hidden_states: Activations indexed as (rows, features); updated in
            place.
        positions: Forwarded to each layer's ``self_attn``.

    Returns:
        ``hidden_states`` after ``self.norm`` has been applied in place.
    """
    num_tokens = hidden_states.shape[0]
    world_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()
    cols_per_rank = self.config.hidden_size // world_size
    residual = torch.empty(num_tokens,
                           cols_per_rank,
                           dtype=hidden_states.dtype,
                           device=hidden_states.device)

    for layer in self.layers:
        # View of the columns owned by this rank.
        local_cols = hidden_states.narrow(1, tp_rank * cols_per_rank,
                                          cols_per_rank)

        # --- attention ---
        residual.copy_(local_cols)
        layer.input_layernorm(hidden_states, output=hidden_states)
        layer.self_attn(positions,
                        hidden_states,
                        residual=None,
                        layernorm=None,
                        disable_allreduce=True,
                        user_output=hidden_states)
        local_cols.add_(residual)
        torch.distributed.all_reduce(
            hidden_states,
            group=get_tensor_model_parallel_group().device_group)

        # --- feed-forward, in batches of at most mlp_batch_size rows ---
        residual.copy_(local_cols)
        layer.post_attention_layernorm(hidden_states, output=hidden_states)
        num_batches = ((num_tokens + self.mlp_batch_size - 1) //
                       self.mlp_batch_size)
        hidden_states = hidden_states.view(num_tokens, -1)
        for batch_idx in range(num_batches):
            lo = batch_idx * self.mlp_batch_size
            hi = min((batch_idx + 1) * self.mlp_batch_size, num_tokens)
            if not layer.use_moe:
                layer.mlp(hidden_states[lo:hi],
                          disable_allreduce=True,
                          user_output=hidden_states[lo:hi])
            else:
                hidden_states[lo:hi] = layer.moe(hidden_states[lo:hi],
                                                 disable_allreduce=True)
        local_cols.add_(residual)
        torch.distributed.all_reduce(
            hidden_states,
            group=get_tensor_model_parallel_group().device_group)

    del residual
    self.norm(hidden_states, output=hidden_states)
    return hidden_states
def
forward_hidden_states
(
self
,
hidden_states
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
S
=
hidden_states
.
shape
[
0
]
if
self
.
tp_size
>
1
and
self
.
sequence_parallel_threshold
is
not
None
and
self
.
sequence_parallel_threshold
<
S
:
tp_size
=
get_tensor_model_parallel_world_size
()
rank
=
get_tensor_model_parallel_rank
()
if
S
%
tp_size
!=
0
:
# pad to multiple of tp_size with 0
pad_len
=
tp_size
-
S
%
tp_size
hidden_states
=
torch
.
cat
([
hidden_states
,
torch
.
zeros
(
pad_len
,
hidden_states
.
shape
[
1
],
dtype
=
hidden_states
.
dtype
,
device
=
hidden_states
.
device
)
])
S
=
hidden_states
.
shape
[
0
]
else
:
pad_len
=
0
assert
S
%
tp_size
==
0
chunk_size
=
S
//
tp_size
residual
=
torch
.
empty
(
chunk_size
,
self
.
config
.
hidden_size
,
dtype
=
hidden_states
.
dtype
,
device
=
hidden_states
.
device
)
dim_0
=
int
((
S
*
self
.
overlap_ratio
+
tp_size
-
1
)
//
tp_size
*
tp_size
)
# round up to multiple of tp_size
dim_1
=
S
-
dim_0
mlp_dim
=
int
(
self
.
config
.
intermediate_size
/
tp_size
)
qkv_dim
=
int
(
(
self
.
config
.
num_attention_heads
+
self
.
config
.
num_attention_groups
*
2
)
//
tp_size
*
(
self
.
config
.
hidden_size
//
self
.
config
.
num_attention_heads
))
if
self
.
allgather_dtype
is
not
None
:
fp8_dim
=
int
(
self
.
config
.
hidden_size
/
2
)
max_buffer_dim
=
max
(
mlp_dim
,
qkv_dim
,
fp8_dim
)
else
:
max_buffer_dim
=
max
(
mlp_dim
,
qkv_dim
)
buffer
=
torch
.
empty
(
S
*
max_buffer_dim
,
dtype
=
hidden_states
.
dtype
,
device
=
hidden_states
.
device
)
buffer_0
=
buffer
[:
dim_0
*
max_buffer_dim
]
buffer_1
=
buffer
[
dim_0
*
max_buffer_dim
:]
mlp_buffer_0
=
buffer_0
[:
dim_0
*
mlp_dim
].
view
(
dim_0
,
-
1
)
mlp_buffer_1
=
buffer_1
[:
dim_1
*
mlp_dim
].
view
(
dim_1
,
-
1
)
qkv_buffer
=
buffer
[
dim_0
*
max_buffer_dim
-
dim_0
*
qkv_dim
:
dim_0
*
max_buffer_dim
+
dim_1
*
qkv_dim
].
view
(
S
,
-
1
)
chunk_size_0
=
dim_0
//
tp_size
chunk_size_1
=
chunk_size
-
chunk_size_0
residual_intersect_0
=
residual
[:
chunk_size_0
]
hidden_states_0
=
hidden_states
[:
dim_0
]
hidden_states_1
=
hidden_states
[
dim_0
:]
hidden_states_intersect_0
=
hidden_states
[
rank
*
chunk_size_0
:(
rank
+
1
)
*
chunk_size_0
]
residual_intersect_1
=
residual
[
chunk_size_0
:]
hidden_states_intersect_1
=
hidden_states
[
dim_0
+
rank
*
chunk_size_1
:
dim_0
+
(
rank
+
1
)
*
chunk_size_1
]
if
self
.
allgather_dtype
is
not
None
:
hidden_states_fp8_0
=
buffer_0
[:
dim_0
*
fp8_dim
]
hidden_states_fp8_0
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_0
,
torch
.
uint8
).
reshape
(
dim_0
,
self
.
config
.
hidden_size
)
hidden_states_fp8_1
=
buffer_1
[:
dim_1
*
fp8_dim
]
hidden_states_fp8_1
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_1
,
torch
.
uint8
).
reshape
(
dim_1
,
self
.
config
.
hidden_size
)
hidden_states_fp8_intersect_0
=
hidden_states_fp8_0
[
rank
*
chunk_size_0
:(
rank
+
1
)
*
chunk_size_0
]
hidden_states_fp8_intersect_1
=
hidden_states_fp8_1
[
rank
*
chunk_size_1
:(
rank
+
1
)
*
chunk_size_1
]
s1
=
torch
.
cuda
.
Stream
(
device
=
residual
.
device
)
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
# Attention Forward
if
i
==
0
:
residual_intersect_0
.
copy_
(
hidden_states_intersect_0
)
layer
.
input_layernorm
(
hidden_states
[:
dim_0
],
output
=
hidden_states
[:
dim_0
])
layer
.
self_attn
.
qkv_proj
(
hidden_states
[:
dim_0
],
output
=
qkv_buffer
[:
dim_0
])
with
torch
.
cuda
.
stream
(
s1
):
if
i
==
0
:
residual_intersect_1
.
copy_
(
hidden_states_intersect_1
)
layer
.
input_layernorm
(
hidden_states
[
dim_0
:],
output
=
hidden_states
[
dim_0
:])
else
:
if
self
.
allgather_dtype
is
not
None
:
if
self
.
allgather_dtype
==
"static_fp8e4m3"
:
qkv_input_scale_1
=
torch
.
full
(
[
1
],
layer
.
self_attn
.
qkv_proj
.
input_scales
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
torch
.
ops
.
OptimusFp8
.
rms_norm_quantize_infer
(
residual_intersect_1
,
layer
.
input_layernorm
.
weight
,
qkv_input_scale_1
,
out
=
hidden_states_fp8_intersect_1
)
elif
self
.
allgather_dtype
==
"dynamic_fp8e4m3"
:
layer
.
input_layernorm
(
residual_intersect_1
,
output
=
hidden_states_intersect_1
)
qkv_input_scale_1
=
dynamic_fp8_pertensor_quantize
(
hidden_states_intersect_1
)
torch
.
distributed
.
all_reduce
(
qkv_input_scale_1
,
torch
.
distributed
.
ReduceOp
.
MIN
,
group
=
get_tensor_model_parallel_group
(
).
device_group
)
hidden_states_fp8_intersect_1
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_intersect_1
,
torch
.
float8_e4m3fn
)
torch
.
ops
.
OptimusFp8
.
quantize
(
hidden_states_intersect_1
,
qkv_input_scale_1
,
out
=
hidden_states_fp8_intersect_1
)
hidden_states_fp8_intersect_1
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_intersect_1
,
torch
.
uint8
)
else
:
raise
ValueError
(
f
"Unsupported allgather_dtype:
{
self
.
allgather_dtype
}
"
)
torch
.
distributed
.
all_gather_into_tensor
(
hidden_states_fp8_1
,
hidden_states_fp8_intersect_1
,
group
=
get_tensor_model_parallel_group
(
).
device_group
)
hidden_states_fp8_1
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_1
,
torch
.
float8_e4m3fn
)
torch
.
ops
.
OptimusFp8
.
dequantize
(
hidden_states_fp8_1
,
qkv_input_scale_1
.
reciprocal
(),
torch
.
bfloat16
,
out
=
hidden_states
[
dim_0
:])
hidden_states_fp8_1
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_1
,
torch
.
uint8
)
else
:
layer
.
input_layernorm
(
residual_intersect_1
,
output
=
hidden_states_intersect_1
)
torch
.
distributed
.
all_gather_into_tensor
(
hidden_states
[
dim_0
:],
hidden_states_intersect_1
,
group
=
get_tensor_model_parallel_group
(
).
device_group
)
layer
.
self_attn
.
qkv_proj
(
hidden_states
[
dim_0
:],
output
=
qkv_buffer
[
dim_0
:])
torch
.
cuda
.
current_stream
().
wait_stream
(
s1
)
q
,
k
,
v
=
qkv_buffer
.
split
([
layer
.
self_attn
.
q_size
,
layer
.
self_attn
.
kv_size
,
layer
.
self_attn
.
kv_size
],
dim
=-
1
)
if
pad_len
>
0
:
attn_output
=
layer
.
self_attn
.
attn
(
q
[:
S
-
pad_len
],
k
[:
S
-
pad_len
],
v
[:
S
-
pad_len
])
attn_output
=
torch
.
cat
([
attn_output
,
torch
.
zeros
(
pad_len
,
attn_output
.
shape
[
1
],
dtype
=
attn_output
.
dtype
,
device
=
attn_output
.
device
)],
dim
=
0
)
else
:
attn_output
=
layer
.
self_attn
.
attn
(
q
,
k
,
v
)
layer
.
self_attn
.
o_proj
(
attn_output
[:
dim_0
],
output
=
hidden_states
[:
dim_0
],
disable_allreduce
=
True
)
hidden_states_intersect_0
.
add_
(
residual_intersect_0
)
torch
.
distributed
.
reduce_scatter_tensor
(
residual_intersect_0
,
hidden_states
[:
dim_0
],
group
=
get_tensor_model_parallel_group
().
device_group
)
with
torch
.
cuda
.
stream
(
s1
):
layer
.
self_attn
.
o_proj
(
attn_output
[
dim_0
:],
output
=
hidden_states
[
dim_0
:],
disable_allreduce
=
True
)
hidden_states_intersect_1
.
add_
(
residual_intersect_1
)
del
attn_output
if
self
.
allgather_dtype
is
not
None
:
if
self
.
allgather_dtype
==
"static_fp8e4m3"
:
gate_up_input_scale_0
=
torch
.
full
(
[
1
],
layer
.
mlp
.
gate_up_proj
.
input_scales
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
torch
.
ops
.
OptimusFp8
.
rms_norm_quantize_infer
(
residual_intersect_0
,
layer
.
post_attention_layernorm
.
weight
,
gate_up_input_scale_0
,
out
=
hidden_states_fp8_intersect_0
)
elif
self
.
allgather_dtype
==
"dynamic_fp8e4m3"
:
layer
.
post_attention_layernorm
(
residual_intersect_0
,
output
=
hidden_states_intersect_0
)
gate_up_input_scale_0
=
dynamic_fp8_pertensor_quantize
(
hidden_states_intersect_0
)
torch
.
distributed
.
all_reduce
(
gate_up_input_scale_0
,
torch
.
distributed
.
ReduceOp
.
MIN
,
group
=
get_tensor_model_parallel_group
(
).
device_group
)
hidden_states_fp8_intersect_0
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_intersect_0
,
torch
.
float8_e4m3fn
)
torch
.
ops
.
OptimusFp8
.
quantize
(
hidden_states_intersect_0
,
gate_up_input_scale_0
,
out
=
hidden_states_fp8_intersect_0
)
hidden_states_fp8_intersect_0
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_intersect_0
,
torch
.
uint8
)
else
:
raise
ValueError
(
f
"Unsupported allgather_dtype:
{
self
.
allgather_dtype
}
"
)
torch
.
distributed
.
all_gather_into_tensor
(
hidden_states_fp8_0
,
hidden_states_fp8_intersect_0
,
group
=
get_tensor_model_parallel_group
().
device_group
)
else
:
layer
.
post_attention_layernorm
(
residual_intersect_0
,
output
=
hidden_states_intersect_0
)
torch
.
distributed
.
all_gather_into_tensor
(
hidden_states
[:
dim_0
],
hidden_states_intersect_0
,
group
=
get_tensor_model_parallel_group
().
device_group
)
with
torch
.
cuda
.
stream
(
s1
):
torch
.
distributed
.
reduce_scatter_tensor
(
residual_intersect_1
,
hidden_states
[
dim_0
:],
group
=
get_tensor_model_parallel_group
().
device_group
)
if
self
.
allgather_dtype
is
not
None
:
hidden_states_fp8_0
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_0
,
torch
.
float8_e4m3fn
)
torch
.
ops
.
OptimusFp8
.
dequantize
(
hidden_states_fp8_0
,
gate_up_input_scale_0
.
reciprocal
(),
torch
.
bfloat16
,
out
=
hidden_states
[:
dim_0
])
hidden_states_fp8_0
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_0
,
torch
.
uint8
)
num_batch_size
=
(
dim_0
+
self
.
mlp_batch_size
-
1
)
//
self
.
mlp_batch_size
for
idx
in
range
(
num_batch_size
):
start
=
idx
*
self
.
mlp_batch_size
end
=
min
((
idx
+
1
)
*
self
.
mlp_batch_size
,
dim_0
)
w0_out_0
,
_
=
layer
.
mlp
.
gate_up_proj
(
hidden_states_0
[
start
:
end
])
layer
.
mlp
.
act_fn
(
w0_out_0
,
output
=
mlp_buffer_0
[
start
:
end
])
del
w0_out_0
with
torch
.
cuda
.
stream
(
s1
):
if
self
.
allgather_dtype
is
not
None
:
if
self
.
allgather_dtype
==
"static_fp8e4m3"
:
gate_up_input_scale_1
=
torch
.
full
(
[
1
],
layer
.
mlp
.
gate_up_proj
.
input_scales
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
torch
.
ops
.
OptimusFp8
.
rms_norm_quantize_infer
(
residual_intersect_1
,
layer
.
post_attention_layernorm
.
weight
,
gate_up_input_scale_1
,
out
=
hidden_states_fp8_intersect_1
)
elif
self
.
allgather_dtype
==
"dynamic_fp8e4m3"
:
layer
.
post_attention_layernorm
(
residual_intersect_1
,
output
=
hidden_states_intersect_1
)
gate_up_input_scale_1
=
dynamic_fp8_pertensor_quantize
(
hidden_states_intersect_1
)
torch
.
distributed
.
all_reduce
(
gate_up_input_scale_1
,
torch
.
distributed
.
ReduceOp
.
MIN
,
group
=
get_tensor_model_parallel_group
(
).
device_group
)
hidden_states_fp8_intersect_1
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_intersect_1
,
torch
.
float8_e4m3fn
)
torch
.
ops
.
OptimusFp8
.
quantize
(
hidden_states_intersect_1
,
gate_up_input_scale_1
,
out
=
hidden_states_fp8_intersect_1
)
hidden_states_fp8_intersect_1
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_intersect_1
,
torch
.
uint8
)
else
:
raise
ValueError
(
f
"Unsupported allgather_dtype:
{
self
.
allgather_dtype
}
"
)
torch
.
distributed
.
all_gather_into_tensor
(
hidden_states_fp8_1
,
hidden_states_fp8_intersect_1
,
group
=
get_tensor_model_parallel_group
(
).
device_group
)
else
:
layer
.
post_attention_layernorm
(
residual_intersect_1
,
output
=
hidden_states_intersect_1
)
torch
.
distributed
.
all_gather_into_tensor
(
hidden_states
[
dim_0
:],
hidden_states_intersect_1
,
group
=
get_tensor_model_parallel_group
(
).
device_group
)
layer
.
mlp
.
down_proj
(
mlp_buffer_0
,
output
=
hidden_states
[:
dim_0
],
disable_allreduce
=
True
)
hidden_states_intersect_0
.
add_
(
residual_intersect_0
)
if
i
<
len
(
self
.
layers
)
-
1
:
torch
.
distributed
.
reduce_scatter_tensor
(
residual_intersect_0
,
hidden_states
[:
dim_0
],
group
=
get_tensor_model_parallel_group
().
device_group
)
else
:
torch
.
distributed
.
all_reduce
(
hidden_states
[:
dim_0
],
group
=
get_tensor_model_parallel_group
().
device_group
)
with
torch
.
cuda
.
stream
(
s1
):
if
self
.
allgather_dtype
is
not
None
:
hidden_states_fp8_1
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_1
,
torch
.
float8_e4m3fn
)
torch
.
ops
.
OptimusFp8
.
dequantize
(
hidden_states_fp8_1
,
gate_up_input_scale_1
.
reciprocal
(),
torch
.
bfloat16
,
out
=
hidden_states
[
dim_0
:])
hidden_states_fp8_1
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_1
,
torch
.
uint8
)
num_batch_size
=
(
dim_1
+
self
.
mlp_batch_size
-
1
)
//
self
.
mlp_batch_size
for
idx
in
range
(
num_batch_size
):
start
=
idx
*
self
.
mlp_batch_size
end
=
min
((
idx
+
1
)
*
self
.
mlp_batch_size
,
dim_1
)
w0_out_1
,
_
=
layer
.
mlp
.
gate_up_proj
(
hidden_states_1
[
start
:
end
])
layer
.
mlp
.
act_fn
(
w0_out_1
,
output
=
mlp_buffer_1
[
start
:
end
])
del
w0_out_1
if
i
<
len
(
self
.
layers
)
-
1
:
next_layer
=
self
.
layers
[
i
+
1
]
if
self
.
allgather_dtype
is
not
None
:
if
self
.
allgather_dtype
==
"static_fp8e4m3"
:
qkv_input_scale_0
=
torch
.
full
(
[
1
],
next_layer
.
self_attn
.
qkv_proj
.
input_scales
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
torch
.
ops
.
OptimusFp8
.
rms_norm_quantize_infer
(
residual_intersect_0
,
next_layer
.
input_layernorm
.
weight
,
qkv_input_scale_0
,
out
=
hidden_states_fp8_intersect_0
)
elif
self
.
allgather_dtype
==
"dynamic_fp8e4m3"
:
next_layer
.
input_layernorm
(
residual_intersect_0
,
output
=
hidden_states_intersect_0
)
qkv_input_scale_0
=
dynamic_fp8_pertensor_quantize
(
hidden_states_intersect_0
)
torch
.
distributed
.
all_reduce
(
qkv_input_scale_0
,
torch
.
distributed
.
ReduceOp
.
MIN
,
group
=
get_tensor_model_parallel_group
(
).
device_group
)
hidden_states_fp8_intersect_0
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_intersect_0
,
torch
.
float8_e4m3fn
)
torch
.
ops
.
OptimusFp8
.
quantize
(
hidden_states_intersect_0
,
qkv_input_scale_0
,
out
=
hidden_states_fp8_intersect_0
)
hidden_states_fp8_intersect_0
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_intersect_0
,
torch
.
uint8
)
else
:
raise
ValueError
(
f
"Unsupported allgather_dtype:
{
self
.
allgather_dtype
}
"
)
torch
.
distributed
.
all_gather_into_tensor
(
hidden_states_fp8_0
,
hidden_states_fp8_intersect_0
,
group
=
get_tensor_model_parallel_group
(
).
device_group
)
else
:
next_layer
.
input_layernorm
(
residual_intersect_0
,
output
=
hidden_states_intersect_0
)
torch
.
distributed
.
all_gather_into_tensor
(
hidden_states
[:
dim_0
],
hidden_states_intersect_0
,
group
=
get_tensor_model_parallel_group
(
).
device_group
)
with
torch
.
cuda
.
stream
(
s1
):
layer
.
mlp
.
down_proj
(
mlp_buffer_1
,
output
=
hidden_states
[
dim_0
:],
disable_allreduce
=
True
)
hidden_states_intersect_1
.
add_
(
residual_intersect_1
)
if
i
<
len
(
self
.
layers
)
-
1
:
torch
.
distributed
.
reduce_scatter_tensor
(
residual_intersect_1
,
hidden_states
[
dim_0
:],
group
=
get_tensor_model_parallel_group
(
).
device_group
)
else
:
torch
.
distributed
.
all_reduce
(
hidden_states
[
dim_0
:],
group
=
get_tensor_model_parallel_group
(
).
device_group
)
if
i
<
len
(
self
.
layers
)
-
1
:
next_layer
=
self
.
layers
[
i
+
1
]
if
self
.
allgather_dtype
is
not
None
:
hidden_states_fp8_0
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_0
,
torch
.
float8_e4m3fn
)
torch
.
ops
.
OptimusFp8
.
dequantize
(
hidden_states_fp8_0
,
qkv_input_scale_0
.
reciprocal
(),
torch
.
bfloat16
,
out
=
hidden_states
[:
dim_0
])
hidden_states_fp8_0
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_0
,
torch
.
uint8
)
next_layer
.
self_attn
.
qkv_proj
(
hidden_states
[:
dim_0
],
output
=
qkv_buffer
[:
dim_0
])
torch
.
cuda
.
current_stream
().
wait_stream
(
s1
)
del
buffer
,
residual
self
.
norm
(
hidden_states
,
output
=
hidden_states
)
return
hidden_states
[:
S
-
pad_len
]
else
:
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
hidden_states
=
layer
(
positions
,
hidden_states
,
)
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
class Step1PretrainedModel(nn.Module, SupportsPP):
    """Checkpoint-loading base class shared by the Step1 model heads.

    ``load_fp8_input_scales`` indexes ``self.model.layers``, so subclasses
    are expected to expose the decoder stack as ``self.model``.
    """

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into the parameters of this module.

        Checkpoint entries for ``q_proj``/``k_proj``/``v_proj`` and
        ``gate_proj``/``up_proj`` are routed into the fused ``qkv_proj`` and
        ``gate_up_proj`` parameters through each parameter's
        ``weight_loader``. Every other entry is loaded under its own name.
        Entries for which ``is_pp_missing_parameter`` is true are skipped.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Raises:
            RuntimeError: If a parameter whose name contains none of
                ``vision_model``, ``latent_query_tokens`` or ``sam_model``
                was not provided by ``weights``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    # The fused parameter is not on this rank. No other
                    # mapping entry can match the renamed key, so stop here.
                    break
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                # Not a shard of a fused parameter: load it as-is.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)

        # Parameters matching these substrings are allowed to be absent from
        # the checkpoint handled here.
        optional_markers = ("vision_model", "latent_query_tokens",
                            "sam_model")
        params_need_to_load = {
            name
            for name in params_dict
            if not any(marker in name for marker in optional_markers)
        }
        # Only *missing* parameters are an error. Comparing the two sets for
        # inequality (as was done before) also triggered when optional
        # parameters had been loaded, and then crashed with IndexError while
        # picking an example from an empty difference.
        missing_params = params_need_to_load - loaded_params
        if missing_params:
            param_name_example = min(missing_params)
            raise RuntimeError(
                f"Some parameters like {param_name_example} are not in the checkpoint and will falsely use random initialization"
            )

    def load_fp8_input_scales(self, input_scales_path):
        """Attach FP8 input scales to the fused projections of each layer.

        Iterates the ``(name, tensor)`` pairs produced by
        ``fp8_input_scales_loader(input_scales_path)`` and stores each value
        (converted with ``.item()``) as ``input_scales`` on the matching
        layer's ``self_attn.qkv_proj`` or ``mlp.gate_up_proj``. Names that
        mention neither projection are ignored.

        Args:
            input_scales_path: Location handed to ``fp8_input_scales_loader``.
        """
        for name, loaded_weight in fp8_input_scales_loader(input_scales_path):
            # NOTE: the prefix spelling is kept verbatim; it has to match the
            # keys yielded by the loader.
            if name.startswith("refrence_model."):
                name = name.replace("refrence_model.", "")
            # The third dot-separated component is used as the layer index
            # (presumably "<root>.layers.<idx>...." - confirm against the
            # scale file format).
            idx = int(name.split(".")[2])
            layer = self.model.layers[idx]
            if "qkv_proj" in name:
                layer.self_attn.qkv_proj.input_scales = loaded_weight[:].item()
            elif "gate_up_proj" in name:
                layer.mlp.gate_up_proj.input_scales = loaded_weight[:].item()
class Step1ForCausalLM(Step1PretrainedModel):
    """Step1 decoder stack with a language-modelling head.

    The LM head, logits processor and sampler are only created on the last
    pipeline-parallel rank; other ranks hold a ``PPMissingLayer`` in place
    of ``lm_head``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config
        self.config = hf_config
        self.model = Step1Model(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))

        if not get_pp_group().is_last_rank:
            self.lm_head = PPMissingLayer()
        else:
            # LoRA may extend the vocabulary beyond the base model's size.
            vocab_size = hf_config.vocab_size
            if lora_config:
                vocab_size += lora_config.lora_extra_vocab_size
            self.unpadded_vocab_size = vocab_size

            # We need bigger padding if using lora for kernel
            # compatibility
            if lora_config:
                padding_size = lora_config.lora_vocab_padding_size
            else:
                padding_size = DEFAULT_VOCAB_PADDING_SIZE

            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                hf_config.hidden_size,
                org_num_embeddings=hf_config.vocab_size,
                padding_size=padding_size,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size,
                hf_config.vocab_size,
                getattr(hf_config, "logit_scale", 1.0),
                need_fp32_logits=False,
            )
            self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack and return whatever ``self.model`` returns."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to logits via the logits processor."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the configured sampler."""
        return self.sampler(logits, sampling_metadata)
class Step1ForSequenceClassification(Step1PretrainedModel):
    """Step1 Transformer with a sequence classification head.

    A bias-free ``ReplicatedLinear`` maps hidden states to ``num_labels``
    scores (created on the last pipeline rank only), which are then passed
    through a pooler configured with ``PoolingType.ALL``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.model = Step1Model(vllm_config, prefix)

        model_config = vllm_config.model_config
        hf_config = model_config.hf_config
        # The label map and the declared label count must agree.
        label_count = len(hf_config.id2label.keys())
        assert label_count == hf_config.num_labels

        if get_pp_group().is_last_rank:
            self.score = ReplicatedLinear(hf_config.hidden_size,
                                          hf_config.num_labels,
                                          bias=False)

        self._pooler = Pooler.from_config_with_defaults(
            model_config.pooler_config,
            pooling_type=PoolingType.ALL,
            normalize=False,
            softmax=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack and return whatever ``self.model`` returns."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Score the hidden states, then pool the resulting logits."""
        class_logits, _ = self.score(hidden_states)
        return self._pooler(class_logits, pooling_metadata)
\ No newline at end of file
vllm/model_executor/models/step2_mini.py
0 → 100644
View file @
583034f1
# SPDX-License-Identifier: Apache-2.0
"""Inference-only Jurassic model."""
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.distributed
import
(
get_dp_group
,
get_pp_group
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.step1
import
Step1MoEMLP
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
.interfaces
import
SupportsPP
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
logger = init_logger(__name__)

# Globally shared CUDA graph memory pool, similar to the implementation in
# model_runner.py. It stays None until the first graph has been captured.
_graph_memory_pool: Optional[Tuple[int, int]] = None
class FusedMoEBlock(nn.Module):
    """Router gate plus ``FusedMoE`` experts.

    The experts are built with ``reduce_results=False``; ``forward``
    performs the tensor-parallel all-reduce itself when ``tp_size > 1``.
    """

    def __init__(self,
                 config: ModelConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the experts and the (unquantized) router gate.

        Raises:
            ValueError: If the tensor-parallel world size exceeds
                ``config.moe_num_experts``.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()

        num_experts = config.moe_num_experts
        if self.tp_size > num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.moe_num_experts}.")
        assert config.moe_dynamic_exp_p == 1, "Only support dynamic exp p=1"

        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=config.moe_top_k,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_expert_weight,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
        )
        # The gate is created without a quantization config.
        self.gate = ReplicatedLinear(
            config.hidden_size,
            num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts; the input shape is preserved."""
        # NOTE: hidden_states can have either 1D or 2D shape.
        input_shape = hidden_states.shape
        tokens = hidden_states.view(-1, input_shape[-1])

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(tokens)
        expert_out = self.experts(hidden_states=tokens,
                                  router_logits=router_logits)
        if self.tp_size > 1:
            expert_out = tensor_model_parallel_all_reduce(expert_out)
        return expert_out.view(input_shape)
class Step2MiniMLP(nn.Module):
    """Gated feed-forward block: fused gate/up projection, SiLU-and-mul
    activation, then a row-parallel down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Create the projections.

        Raises:
            ValueError: If ``hidden_act`` is anything other than ``"silu"``.
        """
        super().__init__()
        # Gate and up projections share a single merged weight.
        merged_output_sizes = [intermediate_size] * 2
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_output_sizes,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()
        self.prefix = prefix
        self.hidden_size = hidden_size

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``down_proj(act_fn(gate_up_proj(x)))``."""
        fused, _ = self.gate_up_proj(hidden_states)
        projected, _ = self.down_proj(self.act_fn(fused))
        return projected
class Step2MiniAttention(nn.Module):
    """Attention block with a shared, normalised query bottleneck.

    A replicated ``qkv_proj`` emits a ``q_size``-wide query plus key and
    value tensors. The query is RMS-normalised and then expanded by the
    column-parallel ``wq`` to ``head_dim * total_num_heads`` before rotary
    embedding and attention.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        norm_eps: float,
        rope_theta: int,
        share_q_dim: Optional[int] = None,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embedding: int = 8192,
        head_dim: int = 256,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # Query heads are always split evenly across TP ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < tp_size:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        else:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Fall back to head_dim when no (or a zero) share_q_dim is given.
        self.q_size = share_q_dim or self.head_dim

        self.qkv_proj = ReplicatedLinear(
            hidden_size,
            self.q_size + self.kv_size * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.inter_norm = RMSNorm(self.q_size, eps=norm_eps)
        self.wq = ColumnParallelLinear(
            self.q_size,
            self.head_dim * self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wq",
        )
        self.rotary_emb = get_rope(self.head_dim,
                                   rotary_dim=self.head_dim,
                                   max_position=max_position_embedding,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling)
        softmax_scale = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              softmax_scale,
                              self.num_kv_heads,
                              cache_config=cache_config,
                              prefix=f"{prefix}.attn")
        self.prefix = prefix

    def forward(self, positions: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the output projection of attention over ``hidden_states``."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = fused_qkv.split(split_sizes, dim=-1)

        # Normalise the shared query, then expand it to all local heads.
        q = self.wq(self.inter_norm(q.contiguous()))[0]
        q, k = self.rotary_emb(positions, q, k)

        projected, _ = self.o_proj(self.attn(q, k, v))
        return projected
class Step2MiniDecoderLayer(nn.Module):
    """One Step2-Mini decoder layer (attention + dense MLP or MoE).

    When the data-parallel world size is greater than one on a CUDA-alike
    platform, the layer is executed as three CUDA-graph-captured segments
    (``forward_1``..``forward_3``) with the attention kernel and the expert
    computation run eagerly between them. Otherwise ``forward_old`` runs the
    whole layer eagerly.
    """

    def __init__(self,
                 config: ModelConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 use_fused_moe: bool = False,
                 prefix: str = "") -> None:
        """Build the layer.

        Args:
            config: Model config; only its ``hf_config`` is used.
            cache_config: Forwarded to the attention block.
            quant_config: Forwarded to every quantizable sub-module.
            use_fused_moe: Use ``FusedMoEBlock`` instead of ``Step1MoEMLP``
                for MoE layers.
            prefix: Dotted module prefix; must contain ``"layers.<idx>"``
                because the layer index is parsed from it.
        """
        super().__init__()
        config = config.hf_config
        self.hidden_size = config.hidden_size
        rope_scaling = getattr(config, "rope_scaling", None)
        self.self_attn = Step2MiniAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            norm_eps=config.rms_norm_eps,
            max_position_embedding=config.max_position_embedding,
            head_dim=config.head_dim,
            share_q_dim=config.share_q_dim,
            rope_theta=config.rope_theta,
            rope_scaling=rope_scaling,
            prefix=f"{prefix}.self_attn")

        self.use_moe = False
        layer_idx = int(prefix.split("layers.")[1].split(".")[0])
        moe_layers_enum = getattr(config, "moe_layers_enum", None)
        if moe_layers_enum is not None:
            # Comma-separated list of layer indices that use MoE.
            moe_layers_idx = [
                int(i) for i in moe_layers_enum.strip().split(',')
            ]
        else:
            # Default to 1dense: every layer except layer 0 is MoE.
            moe_layers_idx = list(range(1, config.num_hidden_layers))

        if layer_idx in moe_layers_idx:
            if not use_fused_moe:
                self.moe = Step1MoEMLP(
                    config.moe_num_experts,
                    config.moe_top_k,
                    config.moe_dynamic_exp_p,
                    hidden_size=self.hidden_size,
                    intermediate_size=config.moe_intermediate_size,
                    hidden_act="silu",
                    quant_config=quant_config,
                    norm_expert_weight=config.norm_expert_weight,
                    prefix=f"{prefix}.moe",
                    enable_cudagraph=False)  # FIXME: TODO: enable cudagraph
            else:
                self.moe = FusedMoEBlock(config=config,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.moe")
            self.share_expert = Step2MiniMLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.share_expert_dim,
                hidden_act="silu",
                quant_config=quant_config,
                prefix=f"{prefix}.share_expert")
            self.use_moe = True
        else:
            self.mlp = Step2MiniMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
                                    hidden_act="silu",
                                    quant_config=quant_config,
                                    prefix=f"{prefix}.mlp")
        self.use_fused_moe = use_fused_moe
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
        self.prefix = prefix

        # CUDA graph parameters - simplified version that uses the shared
        # module-level memory pool.
        self.should_capture_graph = get_dp_group(
        ).world_size > 1 and current_platform.is_cuda_alike()
        self.cuda_graphs_captured = False
        # Per padded token count:
        # (graph, positions_in, hidden_in, q_out, k_out, v_out)
        self.graph_runners_fwd1: dict[int, Tuple[torch.cuda.CUDAGraph,
                                                 torch.Tensor, torch.Tensor,
                                                 torch.Tensor, torch.Tensor,
                                                 torch.Tensor]] = {}
        # (graph, attn_out_in, residual_in, hidden_out, residual_out)
        self.graph_runners_fwd2: dict[int, Tuple[torch.cuda.CUDAGraph,
                                                 torch.Tensor, torch.Tensor,
                                                 torch.Tensor,
                                                 torch.Tensor]] = {}
        # (graph, hidden_in, residual_in, ffn_out, router_logits_out);
        # router_logits_out is None for dense (non-MoE) layers.
        self.graph_runners_fwd3: dict[int, Tuple[torch.cuda.CUDAGraph,
                                                 torch.Tensor, torch.Tensor,
                                                 torch.Tensor,
                                                 Optional[torch.Tensor]]] = {}
        self.max_graph_tokens = 64
        self.graph_token_step = 32
        self.decoder_captured_sizes = list(
            range(self.graph_token_step, self.max_graph_tokens + 1,
                  self.graph_token_step)) if self.should_capture_graph else []

    @torch.inference_mode()
    def _capture_cuda_graph(self, device: torch.device, hs_dtype: torch.dtype,
                            pos_dtype: torch.dtype):
        """Capture one CUDA graph per segment and per padded token count.

        Does nothing if graphs were already captured or graph capture is
        disabled. All graphs share the module-level ``_graph_memory_pool``.
        """
        global _graph_memory_pool
        if self.cuda_graphs_captured or not self.should_capture_graph:
            return

        # Capture on a side stream, using the globally shared memory pool.
        stream = torch.cuda.Stream()
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            # Sizes are visited from largest to smallest.
            for total_tokens in reversed(self.decoder_captured_sizes):
                # --- Capture forward_1 ---
                graph_fwd1 = torch.cuda.CUDAGraph()
                # Create the static input buffers.
                static_positions = torch.ones((total_tokens, ),
                                              dtype=pos_dtype,
                                              device=device)
                static_hidden_states = torch.randn(
                    (total_tokens, self.hidden_size),
                    dtype=hs_dtype,
                    device=device)
                # Warmup forward_1
                _, _, _ = self._forward_1_impl(static_positions,
                                               static_hidden_states)
                # Capture forward_1 with torch.cuda.graph() and the shared
                # memory pool.
                with torch.cuda.graph(graph_fwd1,
                                      pool=_graph_memory_pool,
                                      stream=stream):
                    static_q_fwd1, static_k_fwd1, static_v_fwd1 = (
                        self._forward_1_impl(static_positions,
                                             static_hidden_states))
                # The first captured graph establishes the global pool.
                if _graph_memory_pool is None:
                    _graph_memory_pool = graph_fwd1.pool()
                self.graph_runners_fwd1[total_tokens] = (
                    graph_fwd1, static_positions, static_hidden_states,
                    static_q_fwd1, static_k_fwd1, static_v_fwd1)

                # --- Capture forward_2 ---
                graph_fwd2 = torch.cuda.CUDAGraph()
                # Create the static input buffers.
                attn_output_size = (self.self_attn.num_heads *
                                    self.self_attn.head_dim)
                static_attn_output = torch.randn(
                    (total_tokens, attn_output_size),
                    dtype=hs_dtype,
                    device=device)
                static_residual = torch.randn(
                    (total_tokens, self.hidden_size),
                    dtype=hs_dtype,
                    device=device)
                # Warmup forward_2
                _, _ = self._forward_2_impl(static_attn_output,
                                            static_residual)
                # Capture forward_2 with torch.cuda.graph() and the shared
                # memory pool.
                with torch.cuda.graph(graph_fwd2,
                                      pool=_graph_memory_pool,
                                      stream=stream):
                    static_hs_out_fwd2, static_residual_out_fwd2 = (
                        self._forward_2_impl(static_attn_output,
                                             static_residual))
                self.graph_runners_fwd2[total_tokens] = (
                    graph_fwd2, static_attn_output, static_residual,
                    static_hs_out_fwd2, static_residual_out_fwd2)

                # --- Capture forward_3 ---
                graph_fwd3 = torch.cuda.CUDAGraph()
                # Create the static input buffers for this segment.
                static_hidden_states_fwd3 = torch.randn(
                    (total_tokens, self.hidden_size),
                    dtype=hs_dtype,
                    device=device)
                static_residual_fwd3 = torch.randn(
                    (total_tokens, self.hidden_size),
                    dtype=hs_dtype,
                    device=device)
                # Warmup forward_3
                _, _ = self._forward_3_impl(static_hidden_states_fwd3,
                                            static_residual_fwd3)
                # Capture forward_3 with torch.cuda.graph() and the shared
                # memory pool.
                with torch.cuda.graph(graph_fwd3,
                                      pool=_graph_memory_pool,
                                      stream=stream):
                    static_ffn_output_fwd3, static_router_logits_fwd3 = (
                        self._forward_3_impl(static_hidden_states_fwd3,
                                             static_residual_fwd3))
                self.graph_runners_fwd3[total_tokens] = (
                    graph_fwd3, static_hidden_states_fwd3,
                    static_residual_fwd3, static_ffn_output_fwd3,
                    static_router_logits_fwd3)
        torch.cuda.current_stream().wait_stream(stream)
        self.cuda_graphs_captured = True

    def _ensure_cuda_graphs_captured(self, device: torch.device,
                                     hs_dtype: torch.dtype,
                                     pos_dtype: torch.dtype):
        """Capture the graphs on first use when graph capture is enabled."""
        if not self.cuda_graphs_captured and self.should_capture_graph:
            self._capture_cuda_graph(device, hs_dtype, pos_dtype)

    def _lookup_graph(self, runners: dict, num_tokens: int):
        """Return the captured entry that fits ``num_tokens`` or ``None``.

        The token count is rounded up to the next multiple of
        ``graph_token_step``. ``None`` is returned when graphs have not been
        captured, when ``num_tokens`` exceeds ``max_graph_tokens``, or when
        no graph exists for the rounded size.
        """
        if not self.cuda_graphs_captured:
            return None
        if num_tokens > self.max_graph_tokens:
            return None
        step = self.graph_token_step
        graph_key = (num_tokens + step - 1) // step * step
        return runners.get(graph_key)

    # Separate implementation logic from graph handling
    def _forward_1_impl(self, positions: torch.Tensor,
                        hidden_states: torch.Tensor):
        """Input norm, QKV projection, query expansion and rotary embedding.

        Returns:
            The ``(q, k, v)`` tensors to feed to the attention kernel.
        """
        hidden_states = self.input_layernorm(hidden_states)
        qkv, _ = self.self_attn.qkv_proj(hidden_states)
        q, k, v = qkv.split([
            self.self_attn.q_size, self.self_attn.kv_size,
            self.self_attn.kv_size
        ],
                            dim=-1)
        q = self.self_attn.inter_norm(q.contiguous())
        q = self.self_attn.wq(q)[0]
        q, k = self.self_attn.rotary_emb(positions, q, k)
        return q, k, v

    def forward_1(self, positions: torch.Tensor, hidden_states: torch.Tensor):
        """Graph-aware wrapper around ``_forward_1_impl``.

        Triggers graph capture on first use. When a captured graph is used,
        the returned tensors are slices of the graph's static output buffers.
        """
        if self.should_capture_graph:
            self._ensure_cuda_graphs_captured(hidden_states.device,
                                              hidden_states.dtype,
                                              positions.dtype)
            actual_tokens = hidden_states.shape[0]
            graph_data = self._lookup_graph(self.graph_runners_fwd1,
                                            actual_tokens)
            if graph_data is not None:
                (graph, static_pos_view, static_hs_view, static_q, static_k,
                 static_v) = graph_data
                # Only the leading rows are refreshed; the padding rows keep
                # whatever they held before.
                static_pos_view[:actual_tokens].copy_(positions)
                static_hs_view[:actual_tokens].copy_(hidden_states)
                graph.replay()
                return (static_q[:actual_tokens], static_k[:actual_tokens],
                        static_v[:actual_tokens])
        # Fallback to eager execution
        return self._forward_1_impl(positions, hidden_states)

    # Separate implementation logic from graph handling
    def _forward_2_impl(self, attn_output: torch.Tensor,
                        residual: torch.Tensor):
        """Output projection, residual add and post-attention norm.

        Returns:
            ``(normed_hidden_states, new_residual)``.
        """
        hidden_states, _ = self.self_attn.o_proj(attn_output)
        hidden_states += residual
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        return hidden_states, residual

    def forward_2(self, attn_output: torch.Tensor, residual: torch.Tensor):
        """Graph-aware wrapper around ``_forward_2_impl``."""
        if self.should_capture_graph:
            actual_tokens = attn_output.shape[0]
            graph_data = self._lookup_graph(self.graph_runners_fwd2,
                                            actual_tokens)
            if graph_data is not None:
                (graph, static_attn_output_view, static_residual_view,
                 static_hs_out, static_residual_out) = graph_data
                static_attn_output_view[:actual_tokens].copy_(attn_output)
                static_residual_view[:actual_tokens].copy_(residual)
                graph.replay()
                return (static_hs_out[:actual_tokens],
                        static_residual_out[:actual_tokens])
        # Fallback to eager execution
        return self._forward_2_impl(attn_output, residual)

    # Separate implementation logic from graph handling
    def _forward_3_impl(self, hidden_states: torch.Tensor,
                        residual: torch.Tensor):
        """Shared-expert/dense MLP plus residual, and the router logits.

        Returns:
            ``(ffn_output + residual, router_logits)`` where
            ``router_logits`` is ``None`` for dense layers. The routed
            expert output is NOT included; the caller adds it.
        """
        if self.use_moe:
            ffn_output = self.share_expert(hidden_states)
            router_logits, _ = self.moe.gate(hidden_states)
        else:
            ffn_output = self.mlp(hidden_states)
            router_logits = None
        # Base output before potential MoE addition
        return ffn_output + residual, router_logits

    def forward_3(self, hidden_states: torch.Tensor, residual: torch.Tensor):
        """Graph-aware wrapper around ``_forward_3_impl``."""
        if self.should_capture_graph:
            actual_tokens = hidden_states.shape[0]
            graph_data = self._lookup_graph(self.graph_runners_fwd3,
                                            actual_tokens)
            if graph_data is not None:
                (graph, static_hs_view, static_residual_view,
                 static_ffn_output, static_router_logits) = graph_data
                static_hs_view[:actual_tokens].copy_(hidden_states)
                static_residual_view[:actual_tokens].copy_(residual)
                graph.replay()
                router_logits = (static_router_logits[:actual_tokens]
                                 if static_router_logits is not None else
                                 None)
                return static_ffn_output[:actual_tokens], router_logits
        # Fallback to eager execution
        return self._forward_3_impl(hidden_states, residual)

    def forward(self, positions: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the layer, using the segmented graph path when enabled."""
        if not self.should_capture_graph:
            return self.forward_old(positions, hidden_states)

        residual = hidden_states
        q, k, v = self.forward_1(positions, hidden_states)
        attn_output = self.self_attn.attn(q, k, v)
        hidden_states, residual = self.forward_2(attn_output, residual)
        ffn_output_plus_residual, router_logits = self.forward_3(
            hidden_states, residual)
        if not self.use_moe:
            return ffn_output_plus_residual

        moe_output = self.moe.experts(hidden_states, router_logits)
        # FusedMoEBlock builds its experts with reduce_results=False and
        # all-reduces inside its own forward(). This path bypasses that
        # forward(), so the tensor-parallel reduction must be done here to
        # match forward_old().
        if self.use_fused_moe and self.moe.tp_size > 1:
            moe_output = tensor_model_parallel_all_reduce(moe_output)
        return ffn_output_plus_residual + moe_output

    def forward_old(self, positions: torch.Tensor,
                    hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the whole layer eagerly, without CUDA-graph segments."""
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states += residual
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        if self.use_moe:
            share_output = self.share_expert(hidden_states)
            moe_output = self.moe(hidden_states)
            ffn_output = share_output + moe_output
        else:
            ffn_output = self.mlp(hidden_states)
        hidden_states = ffn_output + residual
        return hidden_states
class Step2MiniModel(nn.Module):
    """Embedding, decoder-layer stack and final norm for Step2-Mini.

    The embedding exists on the first pipeline rank (and on the last one
    when word embeddings are tied); the final norm exists on the last rank.
    Missing pieces are replaced by ``PPMissingLayer``.
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 use_fused_moe: bool = False) -> None:
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.vocab_size = hf_config.vocab_size
        self.config = hf_config
        self.use_fused_moe = use_fused_moe

        needs_embedding = get_pp_group().is_first_rank or (
            hf_config.tie_word_embeddings and get_pp_group().is_last_rank)
        if needs_embedding:
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                hf_config.hidden_size,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        def build_layer(prefix: str) -> Step2MiniDecoderLayer:
            # The parameter is deliberately named ``prefix`` to keep the
            # same call signature as the factory it replaces.
            return Step2MiniDecoderLayer(config=vllm_config.model_config,
                                         cache_config=cache_config,
                                         quant_config=quant_config,
                                         use_fused_moe=self.use_fused_moe,
                                         prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(hf_config.hidden_size,
                                eps=hf_config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    hf_config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run this rank's slice of the decoder stack.

        Returns ``IntermediateTensors`` holding ``hidden_states`` on every
        pipeline rank except the last, which returns the normed hidden
        states.
        """
        if not get_pp_group().is_first_rank:
            # Later pipeline stages continue from the previous stage.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        elif inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)

        for layer_idx in range(self.start_layer, self.end_layer):
            hidden_states = self.layers[layer_idx](positions, hidden_states)

        if get_pp_group().is_last_rank:
            return self.norm(hidden_states)
        return IntermediateTensors({
            "hidden_states": hidden_states,
        })
@support_torch_compile
class Step3FlashModelFusedMoE(Step2MiniModel):
    """``Step2MiniModel`` that is always built with ``use_fused_moe=True``."""

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        # Everything is delegated to the parent; only the fused-MoE flag is
        # pinned here.
        super().__init__(vllm_config, prefix, use_fused_moe=True)
class Step2MiniPretrainedModel(nn.Module, SupportsPP):
    """Base class holding the checkpoint-loading logic for Step2-mini heads.

    ``load_weights`` reads ``self.config`` and ``self.model`` and, when
    ``self.model.use_fused_moe`` is true, ``self.vllm_config``; subclasses
    must set these before loading.
    """

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this module's parameters.

        Each ``(name, tensor)`` pair is routed through the first matching
        rule, tried in this order:

        1. ``gate_proj`` / ``up_proj`` -> stacked ``gate_up_proj`` (skipped
           for names that belong to the fused-MoE expert mapping),
        2. fused-MoE expert weights, loaded one expert at a time along the
           tensor's first dimension,
        3. ``q_proj`` / ``k_proj`` / ``v_proj`` -> a slice of ``qkv_proj``,
        4. any other name, loaded with the parameter's own ``weight_loader``
           or ``default_weight_loader``.

        Raises:
            RuntimeError: if a parameter of this module received no weight.
        """
        # The q/k/v slices of ``qkv_proj`` are proportional to
        # share_q_dim : head_dim : head_dim.  Boundaries are kept as integer
        # numerators over ``qkv_total`` so that scaling by the parameter's
        # output size is exact; float fractions followed by ``int()`` can
        # truncate one row short when the product lands just below an
        # integer.
        q_dim = self.config.share_q_dim
        head_dim = self.config.head_dim
        qkv_total = q_dim + head_dim * 2
        qkv_params_mapping = [
            # (param_name, shard_name, start_numerator, end_numerator)
            (".qkv_proj", ".q_proj", 0, q_dim),
            (".qkv_proj", ".k_proj", q_dim, q_dim + head_dim),
            (".qkv_proj", ".v_proj", q_dim + head_dim, qkv_total),
        ]
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()

        if self.model.use_fused_moe:
            quant_config = self.vllm_config.quant_config
            if (quant_config is not None
                    and quant_config.get_name() == "groupwise_quant"):
                expert_params_mapping = [
                    (".moe.experts.w13_weight", ".moe.gate_proj.qweight",
                     "w1"),
                    (".moe.experts.w13_weight", ".moe.up_proj.qweight", "w3"),
                    (".moe.experts.w2_weight", ".moe.down_proj.qweight",
                     "w2"),
                    (".moe.experts.w13_weight_scale", ".moe.gate_proj.scales",
                     "w1"),
                    (".moe.experts.w13_weight_scale", ".moe.up_proj.scales",
                     "w3"),
                    (".moe.experts.w2_weight_scale", ".moe.down_proj.scales",
                     "w2"),
                ]
            else:
                expert_params_mapping = [
                    (".moe.experts.w13_weight", ".moe.gate_proj.weight",
                     "w1"),
                    (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
                    (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"),
                ]
        else:
            expert_params_mapping = []

        # Checkpoint names handled by the expert mapping must not be picked
        # up by the generic gate/up stacking rule.
        disable_moe_stacked_params = [
            data[1] for data in expert_params_mapping
        ]

        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if any(disable_moe_stacked_param in name
                       for disable_moe_stacked_param in
                       disable_moe_stacked_params):
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    # The checkpoint tensor is indexed per expert along its
                    # first dimension.
                    for expert_id in range(loaded_weight.shape[0]):
                        loaded_weight_expert = loaded_weight[expert_id]
                        weight_loader(param,
                                      loaded_weight_expert,
                                      name,
                                      shard_id=shard_id,
                                      expert_id=expert_id)
                    loaded_params.add(name)
                    break
                else:
                    for (param_name, weight_name, start_num,
                         end_num) in qkv_params_mapping:
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        dim = param.shape[param.output_dim]
                        begin_idx = dim * start_num // qkv_total
                        end_idx = dim * end_num // qkv_total
                        param_slice = param.narrow(param.output_dim,
                                                   begin_idx,
                                                   end_idx - begin_idx)
                        param_slice.copy_(loaded_weight)
                        loaded_params.add(name)
                        break
                    else:
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        weight_loader = getattr(param, "weight_loader",
                                                default_weight_loader)
                        weight_loader(param, loaded_weight)
                        loaded_params.add(name)

        missing_params = set(params_dict) - loaded_params
        if missing_params:
            param_name_example = list(missing_params)[0]
            raise RuntimeError(
                f"Some parameters like {param_name_example} are not in the checkpoint and will falsely use random initialization"
            )
class Step2MiniForCausalLM(Step2MiniPretrainedModel):
    """Causal-LM head around a Step2-mini (or fused-MoE Step3-flash) decoder."""

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the decoder and, on the last pipeline rank, the LM head,
        logits processor and sampler."""
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config
        self.config = hf_config
        self.vllm_config = vllm_config

        # FIXME: hack for step3 flash model
        # A 42-layer config selects the plain model; every other depth gets
        # the fused-MoE variant.
        model_cls = (Step2MiniModel if self.config.num_hidden_layers == 42
                     else Step3FlashModelFusedMoE)
        self.model = model_cls(vllm_config=vllm_config, prefix=prefix)

        if not get_pp_group().is_last_rank:
            self.lm_head = PPMissingLayer()
        else:
            self.unpadded_vocab_size = hf_config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            if lora_config:
                padding_size = lora_config.lora_vocab_padding_size
            else:
                padding_size = DEFAULT_VOCAB_PADDING_SIZE
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                hf_config.hidden_size,
                org_num_embeddings=hf_config.vocab_size,
                padding_size=padding_size,
            )
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    hf_config.vocab_size,
                                                    need_fp32_logits=False)
            self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None):
        """Delegate to the wrapped decoder and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Turn hidden states into logits via ``lm_head``."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits``."""
        return self.sampler(logits, sampling_metadata)
class Step2MiniForSequenceClassification(Step2MiniPretrainedModel):
    """Step2-mini decoder with a linear scoring head and a pooler."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder and, on the last pipeline rank, the score head
        and pooler."""
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        # The inherited load_weights() reads self.vllm_config when the
        # wrapped model uses fused MoE, so it has to be stored here too.
        self.vllm_config = vllm_config
        self.model = Step2MiniModel(vllm_config, prefix=prefix)
        if get_pp_group().is_last_rank:
            self.score = ReplicatedLinear(self.config.hidden_size,
                                          self.config.num_labels,
                                          bias=False)
            pooler_config = vllm_config.model_config.pooler_config
            self._pooler = Pooler.from_config_with_defaults(
                pooler_config,
                pooling_type=PoolingType.ALL,
                normalize=False,
                softmax=False)
        else:
            self._pooler = PPMissingLayer()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the wrapped decoder's output unchanged.

        No sampling happens here; scoring and pooling are done in
        ``pooler``.
        """
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Project hidden states to ``num_labels`` scores and pool them."""
        logits, _ = self.score(hidden_states)
        ret = self._pooler(logits, pooling_metadata)
        return ret

    def sequence_flops(self, input_length, context_length):
        """Return the parent's FLOP estimate plus the score head's share.

        The head contributes ``2 * hidden_size * num_labels`` operations,
        divided by 1e12.
        """
        # NOTE(review): relies on a base class providing sequence_flops();
        # none is visible on Step2MiniPretrainedModel - confirm that
        # SupportsPP (or another base) defines it.
        output_flops = (1 * self.config.hidden_size * self.config.num_labels *
                        2.0 / 1e12)
        return super().sequence_flops(input_length,
                                      context_length) + output_flops
\ No newline at end of file
vllm/model_executor/models/step_encoder.py
0 → 100644
View file @
583034f1
# SPDX-License-Identifier: Apache-2.0
import
math
from
typing
import
Iterable
,
Optional
,
Tuple
import
torch
import
torchvision
#from optimus import flash_attn_func
from
torch
import
nn
from
torch.nn
import
functional
as
F
from
torchvision.transforms.functional
import
InterpolationMode
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.layernorm
import
OptimusLayerNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.transformers_utils.configs
import
CLIPVisionConfig
def get_abs_pos(abs_pos, tgt_size):
    """Resize a class-token-prefixed position table to a new patch count.

    ``abs_pos`` is squeezed on dim 0 and split into its first row (the class
    token) and the remaining rows, which are treated as a square grid.
    ``tgt_size`` is the target number of patch positions; its integer square
    root is the target grid side.

    When the source and target sides match, ``abs_pos`` itself is returned.
    Otherwise the grid is resampled in float32 with antialiased bicubic
    interpolation, cast back to the input dtype, and returned with the class
    token row in front, shaped ``(1, side * side + 1, dim)``.
    """
    embed_dim = abs_pos.size(-1)
    flat_pos = abs_pos.squeeze(0)
    src_side = int(math.sqrt(flat_pos.shape[0] - 1))
    tgt_side = int(math.sqrt(tgt_size))

    if src_side == tgt_side:
        return abs_pos

    cls_token = flat_pos[:1]
    # (N, dim) -> (1, dim, side, side), the layout F.interpolate expects.
    grid = flat_pos[1:].view(1, src_side, src_side, embed_dim)
    grid = grid.permute(0, 3, 1, 2).contiguous().to(torch.float32)
    resized = F.interpolate(
        grid,
        size=(tgt_side, tgt_side),
        mode='bicubic',
        antialias=True,
        align_corners=False,
    ).to(abs_pos.dtype)
    # Back to one row per position.
    resized = resized.permute(0, 2, 3, 1).view(tgt_side * tgt_side,
                                               embed_dim)
    merged = torch.cat([cls_token, resized], dim=0)
    return merged.view(1, tgt_side * tgt_side + 1, embed_dim)
class StepCLIPVisionEmbeddings(nn.Module):
    """Patch, class-token and position embeddings for the Step CLIP encoder."""

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(1, self.embed_dim))
        # Non-overlapping patches: kernel size and stride are both
        # patch_size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=True,
        )

        self.num_patches = (self.image_size // self.patch_size)**2
        self.pad_tp_size = 4  # hard code for padding
        # To load the pretrained weights, we still use P+1 as the seqlen
        self.position_embedding = torch.nn.Embedding(self.num_patches + 1,
                                                     self.embed_dim)
        position_ids = torch.arange(self.num_patches + 1).expand((1, -1))
        self.register_buffer("position_ids", position_ids, persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed ``pixel_values`` into a padded token sequence.

        The result holds ``pad_tp_size - 1`` copies of the position-encoded
        class token, then the class token itself, then the patch tokens.
        """
        num_images = pixel_values.shape[0]
        # Conv output is [*, width, grid, grid]; flatten the grid and move
        # it to dim 1.
        patches = self.patch_embedding(pixel_values).flatten(2).transpose(
            1, 2)

        cls_tokens = self.class_embedding.expand(num_images, 1, -1)
        tokens = torch.cat([cls_tokens, patches], dim=1)
        pos_embed = get_abs_pos(self.position_embedding(self.position_ids),
                                patches.size(1))
        tokens = tokens + pos_embed

        # pad: prepend extra copies of the (already position-encoded) class
        # token.
        padding = tokens[:, 0, :].unsqueeze(1).repeat(1, self.pad_tp_size - 1,
                                                      1)
        return torch.cat([padding, tokens], dim=1)
class StepCLIPAttention(nn.Module):
    """Non-causal multi-head self-attention for the Step CLIP encoder.

    With ``need_dp=False`` the projections are tensor-parallel
    (``QKVParallelLinear`` / ``RowParallelLinear``) and each rank holds
    ``total_num_heads // tp_size`` heads.  With ``need_dp=True`` both
    projections are ``ReplicatedLinear`` and every rank holds all heads.
    """

    def __init__(self,
                 config,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 need_dp: bool = False):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.total_num_heads
        self.scale = self.head_dim**-0.5

        if need_dp:
            self.num_heads = self.total_num_heads
            self.qkv_proj = ReplicatedLinear(
                self.embed_dim,
                self.embed_dim * 3,
                bias=True,
                quant_config=quant_config,
                prefix=prefix,
            )
            self.out_proj = ReplicatedLinear(
                self.embed_dim,
                self.embed_dim,
                bias=True,
                quant_config=quant_config,
                prefix=prefix,
            )
        else:
            tp_size = get_tensor_model_parallel_world_size()
            assert self.total_num_heads % tp_size == 0
            self.num_heads = self.total_num_heads // tp_size
            self.qkv_proj = QKVParallelLinear(self.embed_dim,
                                              self.head_dim,
                                              self.total_num_heads,
                                              bias=True,
                                              quant_config=quant_config,
                                              prefix=prefix)
            self.out_proj = RowParallelLinear(self.embed_dim,
                                              self.embed_dim,
                                              bias=True,
                                              quant_config=quant_config,
                                              prefix=prefix)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape to ``(bsz, num_heads, seq_len, head_dim)``, contiguous."""
        per_head = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual=None,
        layernorm=None,
    ):
        """Attend over ``hidden_states`` (batch x time x channel).

        ``layernorm``, when given, is applied to the input first.
        ``residual`` is passed straight through to ``out_proj``.
        """
        if layernorm is not None:
            hidden_states = layernorm(hidden_states)
        bsz, tgt_len, _ = hidden_states.size()

        qkv, _ = self.qkv_proj(hidden_states)
        # Split into q/k/v and move heads in front of the sequence
        # dimension.
        q, k, v = (part.view(bsz, tgt_len, self.num_heads,
                             self.head_dim).transpose(1, 2)
                   for part in qkv.chunk(chunks=3, dim=-1))

        attn_output = F.scaled_dot_product_attention(q,
                                                     k,
                                                     v,
                                                     scale=self.scale,
                                                     is_causal=False)
        attn_output = attn_output.transpose(1, 2).reshape(
            bsz, tgt_len, self.num_heads * self.head_dim)

        attn_output, _ = self.out_proj(attn_output, residual=residual)
        return attn_output
class StepCLIPMLP(nn.Module):
    """Two-layer feed-forward block (fc1 -> activation -> fc2).

    The layers are tensor-parallel unless ``need_dp`` is set, in which case
    both are ``ReplicatedLinear``.
    """

    def __init__(self,
                 config,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 need_dp: bool = False):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        if need_dp:
            up_cls = down_cls = ReplicatedLinear
        else:
            up_cls, down_cls = ColumnParallelLinear, RowParallelLinear

        self.fc1 = up_cls(config.hidden_size,
                          config.intermediate_size,
                          bias=True,
                          quant_config=quant_config,
                          prefix=prefix)
        self.fc2 = down_cls(config.intermediate_size,
                            config.hidden_size,
                            bias=True,
                            quant_config=quant_config,
                            prefix=prefix)

    def forward(self,
                hidden_states: torch.Tensor,
                residual=None,
                layernorm=None) -> torch.Tensor:
        """Apply the MLP.

        ``layernorm``, when given, is applied to the input first.
        ``residual`` is passed straight through to ``fc2``.
        """
        if layernorm is not None:
            hidden_states = layernorm(hidden_states)
        projected, _ = self.fc1(hidden_states)
        activated = self.activation_fn(projected)
        output, _ = self.fc2(activated, residual=residual)
        return output
class StepCLIPEncoderLayer(nn.Module):
    """One encoder layer: attention and MLP, each normalised after the
    sub-block and then added back to its input."""

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 need_dp: bool = False):
        super().__init__()
        self.embed_dim = config.hidden_size
        norm_eps = config.layer_norm_eps
        self.self_attn = StepCLIPAttention(config,
                                           quant_config,
                                           prefix=f"{prefix}.self_attn",
                                           need_dp=need_dp)
        self.layer_norm1 = OptimusLayerNorm(self.embed_dim, eps=norm_eps)
        self.mlp = StepCLIPMLP(config,
                               quant_config,
                               prefix=f"{prefix}.mlp",
                               need_dp=need_dp)
        self.layer_norm2 = OptimusLayerNorm(self.embed_dim, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.FloatTensor:
        """Return ``h + layer_norm2(mlp(h))`` where
        ``h = hidden_states + layer_norm1(self_attn(hidden_states))``."""
        attn_out = self.self_attn(hidden_states=hidden_states,
                                  residual=None,
                                  layernorm=None)
        h = hidden_states + self.layer_norm1(attn_out)
        return h + self.layer_norm2(self.mlp(h))
class StepCLIPEncoder(nn.Module):
    """
    Stack of `config.num_hidden_layers` [`StepCLIPEncoderLayer`] modules
    applied in sequence.

    Args:
        config: CLIPVisionConfig
    """

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 need_dp: bool = False):
        super().__init__()
        self.config = config
        layers = []
        for i in range(config.num_hidden_layers):
            layers.append(
                StepCLIPEncoderLayer(config,
                                     quant_config,
                                     prefix=f"{prefix}.layers.{i}",
                                     need_dp=need_dp))
        self.layers = nn.ModuleList(layers)

    def forward(
        self,
        inputs_embeds,
    ):
        """Feed ``inputs_embeds`` through every layer and return the result."""
        hidden_states = inputs_embeds
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states
class StepCLIPVisionTransformer(nn.Module):
    """Embeddings followed by the encoder stack."""

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 need_dp: bool = False):
        super().__init__()
        self.config = config
        self.image_size = config.image_size
        # Bicubic, antialiased resize to (image_size, image_size).  It is
        # only constructed here; forward() below does not call it.
        target_size = (self.image_size, self.image_size)
        self.vision_model_preprocessor = torchvision.transforms.Resize(
            target_size,
            interpolation=InterpolationMode.BICUBIC,
            antialias=True)
        self.embeddings = StepCLIPVisionEmbeddings(config)
        self.transformer = StepCLIPEncoder(config,
                                           quant_config,
                                           prefix=f"{prefix}.transformer",
                                           need_dp=need_dp)

    def forward(
        self,
        pixel_values: torch.Tensor,
    ):
        """Return ``(encoder_output, None)`` for ``pixel_values``."""
        embedded = self.embeddings(pixel_values)
        encoded = self.transformer(inputs_embeds=embedded)
        return encoded, None
class StepCLIPVisionModel(nn.Module):
    """Wrapper exposing the Step CLIP vision transformer and its weight
    loader."""

    # A checkpoint entry is loaded only if its name contains one of these.
    _PARAMS_KEYS_TO_SELECT = ["vision_model"]

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 need_dp: bool = False):
        super().__init__()
        # FIXME(ys): step encoder does not support quantization
        quant_config = None
        self.vision_model = StepCLIPVisionTransformer(
            config,
            quant_config,
            prefix=f"{prefix}.vision_model",
            need_dp=need_dp)

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding convolution."""
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
    ):
        """Run the vision transformer on ``pixel_values``."""
        return self.vision_model(pixel_values=pixel_values)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load vision-tower weights from ``weights``.

        Entries whose name contains an ignored key, or none of
        ``_PARAMS_KEYS_TO_SELECT``, are skipped.  A leading
        ``model.vision_tower.vision_tower.`` or ``model.vision_tower.``
        prefix is stripped.  Names with a ``q_proj`` / ``k_proj`` /
        ``v_proj`` component are redirected to ``qkv_proj`` with the
        matching shard id.

        Raises:
            RuntimeError: if a parameter of this module received no weight.
        """
        _params_to_ignore = [
            "text_model", "logit_scale",
            "vision_model.embeddings.position_ids",
            "visual_projection.weight", "text_projection.weight"
        ]
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()

        for name, loaded_weight in weights:
            if any(key in name for key in _params_to_ignore):
                continue
            if not any(key in name for key in self._PARAMS_KEYS_TO_SELECT):
                continue

            if name.startswith("model.vision_tower.vision_tower"):
                name = name.replace("model.vision_tower.vision_tower.", "")
            elif name.startswith("model.vision_tower"):
                name = name.replace("model.vision_tower.", "")

            # The first mapping whose shard name is a whole dotted component
            # wins.
            components = name.split(".")
            match = next((entry for entry in stacked_params_mapping
                          if entry[1] in components), None)
            if match is None:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                fused_name, shard_name, shard_id = match
                name = name.replace(shard_name, fused_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)

        params_need_to_load = set(params_dict.keys())
        if params_need_to_load != loaded_params:
            param_name_example = list(params_need_to_load - loaded_params)[0]
            raise RuntimeError(
                f"Some parameters like {param_name_example} are not in the checkpoint and will falsely use random initialization"
            )
class StepCLIPVisionModelWithPostprocess(StepCLIPVisionModel):
    """Vision model followed by two strided convolutions that downsample
    the patch grid."""

    # Names are matched by substring in the parent's load_weights(), so
    # "vit_downsampler" also selects the "vit_downsampler2" weights.
    _PARAMS_KEYS_TO_SELECT = ["vision_model", "vit_downsampler"]

    def __init__(self, config: CLIPVisionConfig, need_dp: bool = True):
        super().__init__(config, need_dp=need_dp)
        self.config = config
        self.vit_downsampler = nn.Conv2d(self.config.hidden_size,
                                         self.config.output_hidden_size,
                                         kernel_size=2,
                                         stride=2)
        self.vit_downsampler2 = nn.Conv2d(
            self.config.output_hidden_size,
            self.config.output_hidden_size * 2,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x: torch.Tensor):
        """Encode ``x`` and return downsampled patch features.

        The leading non-patch tokens are dropped, the remaining tokens are
        laid out as a square grid, passed through both downsampling
        convolutions, and flattened to
        ``(batch, positions, output_hidden_size * 2)``.
        """
        # The embeddings module emits pad_tp_size non-patch tokens in front
        # of the patches (pad_tp_size - 1 padding copies plus the class
        # token).  Read the count from there instead of hard-coding it.
        num_prefix_tokens = self.vision_model.embeddings.pad_tp_size
        x = super().forward(x)[0][:, num_prefix_tokens:]

        B, P = x.shape[:2]
        HW = int(math.sqrt(P))
        x = x.permute(0, 2, 1).view(B, self.config.hidden_size, HW, HW)
        x = self.vit_downsampler(x)
        x = self.vit_downsampler2(x)
        x = x.view(B, self.config.output_hidden_size * 2, -1).permute(0, 2, 1)
        return x
\ No newline at end of file
vllm/transformers_utils/config.py
View file @
583034f1
...
...
@@ -41,6 +41,15 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
OvisConfig
,
RWConfig
,
Step3TextConfig
,
Step3VLConfig
,
SkyworkR1VChatConfig
,
SolarConfig
,
MMGPTStep1Config
,
MMGPTStep1ConfigV2
,
MPTConfig
,
NemotronConfig
,
NVLM_D_Config
,
RWConfig
,
SkyworkR1VChatConfig
,
SolarConfig
,
Step1AudioConfig
,
Step1Config
,
Step1oConfig
,
Step2Config
,
Step2MiniConfig
,
Step3vConfig
,
StepAudioQwen2Config
,
Telechat2Config
,
UltravoxConfig
)
# yapf: enable
from
vllm.transformers_utils.utils
import
check_gguf_file
...
...
@@ -75,6 +84,20 @@ _CONFIG_REGISTRY_OVERRIDE_HF: dict[str, type[PretrainedConfig]] = {
"mllama"
:
MllamaConfig
}
# Config classes for the Step model family, keyed by the string each one is
# registered under.  The mapping is unpacked into _CONFIG_REGISTRY.
_CUSTOM_CONFIG_STEP = {
    "step1": Step1Config,
    "step2": Step2Config,
    "step2_mini": Step2MiniConfig,
    "mmgpt_step1": MMGPTStep1Config,
    "mmgpt_step1_v2": MMGPTStep1ConfigV2,
    #"mmgpt_qwen2": MMGPTQwen2Config,
    #"mmgpt_qwen2_v2": MMGPTQwen2ConfigV2,
    "step1o": Step1oConfig,
    "step1_audio": Step1AudioConfig,
    "step_audio_qwen2": StepAudioQwen2Config,
    "step3v": Step3vConfig,
}
_CONFIG_REGISTRY
:
dict
[
str
,
type
[
PretrainedConfig
]]
=
{
"chatglm"
:
ChatGLMConfig
,
"cohere2"
:
Cohere2Config
,
...
...
@@ -100,7 +123,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
"ultravox"
:
UltravoxConfig
,
"step3_vl"
:
Step3VLConfig
,
"step3_text"
:
Step3TextConfig
,
**
_CONFIG_REGISTRY_OVERRIDE_HF
**
_CONFIG_REGISTRY_OVERRIDE_HF
,
**
_CUSTOM_CONFIG_STEP
}
_CONFIG_ATTRS_MAPPING
:
dict
[
str
,
str
]
=
{
...
...
vllm/transformers_utils/configs/__init__.py
View file @
583034f1
...
...
@@ -32,6 +32,18 @@ from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig,
Step3VisionEncoderConfig
,
Step3VLConfig
)
from
vllm.transformers_utils.configs.ultravox
import
UltravoxConfig
from
vllm.transformers_utils.configs.mmgpt
import
(
CLIPVisionConfig
,
MMGPTQwen2Config
,
MMGPTQwen2ConfigV2
,
MMGPTStep1Config
,
MMGPTStep1ConfigV2
,
SamViTConfig
,
Step1oConfig
,
Step3vConfig
)
from
vllm.transformers_utils.configs.step
import
(
Step1Config
,
Step2Config
,
Step2MiniConfig
)
from
vllm.transformers_utils.configs.step1f
import
(
Step1AudioConfig
,
Step1fAudioEncoderConfig
,
StepAudioQwen2Config
)
__all__
=
[
"ChatGLMConfig"
,
...
...
@@ -62,4 +74,21 @@ __all__ = [
"Step3VLConfig"
,
"Step3VisionEncoderConfig"
,
"Step3TextConfig"
,
"Step1Config"
,
"Step2Config"
,
"Step2MiniConfig"
,
"CLIPVisionConfig"
,
"MMGPTBaiChuanConfig"
,
"MMGPTLlamaConfig"
,
"MMGPTLlamaConfigV2"
,
"MMGPTQwen2Config"
,
"MMGPTQwen2ConfigV2"
,
"MMGPTStep1Config"
,
"MMGPTStep1ConfigV2"
,
"Step3vConfig"
,
"SamViTConfig"
,
"Step1oConfig"
,
"Step1AudioConfig"
,
"Step1fAudioEncoderConfig"
,
"StepAudioQwen2Config"
,
]
vllm/transformers_utils/configs/mmgpt.py
0 → 100644
View file @
583034f1
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
List
,
Optional
,
Union
from
transformers
import
Qwen2Config
from
transformers.configuration_utils
import
PretrainedConfig
from
vllm.transformers_utils.configs.step
import
Step1Config
,
Step2MiniConfig
class CLIPVisionConfig(PretrainedConfig):
    """Hyper-parameters of a CLIP-style vision encoder.

    Each keyword argument is stored unchanged as an attribute of the same
    name; any extra keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "clip_vision_model"

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        projection_dim=512,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        image_size=224,
        patch_size=32,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Encoder dimensions.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Input image geometry.
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size

        # Normalisation, activation, dropout and initialisation settings.
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
class SamViTConfig(PretrainedConfig):
    """Hyper-parameters of a SAM-style ViT encoder.

    Each keyword argument is stored unchanged as an attribute of the same
    name; any extra keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "sam_vit_model"

    def __init__(
        self,
        depth=24,
        embed_dim=1024,
        image_size=1280,
        mlp_ratio=4,
        num_heads=16,
        patch_size=16,
        qkv_bias=True,
        use_abs_pos=True,
        use_rel_pos=True,
        global_attn_indexes=(5, 11, 17, 23),
        window_size=14,
        out_channels=256,
        layer_norm_eps=1e-6,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Backbone size and input geometry.
        self.depth = depth
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.patch_size = patch_size

        # Attention options.
        self.qkv_bias = qkv_bias
        self.use_abs_pos = use_abs_pos
        self.use_rel_pos = use_rel_pos
        self.global_attn_indexes = global_attn_indexes
        self.window_size = window_size

        # Output and normalisation settings.
        self.out_channels = out_channels
        self.layer_norm_eps = layer_norm_eps
class MMGPTStep1Config(Step1Config):
    """Multimodal Step1 config: text settings plus a CLIP vision tower.

    ``bos_token_id`` defaults to ``1`` and ``eos_token_id`` to ``[2, 3]``
    when not supplied.  ``vision_tower_config`` may be ``None``, keyword
    arguments for ``CLIPVisionConfig``, or an already-built
    ``PretrainedConfig`` instance, which is stored as given.  A standalone
    ``Step1Config`` with the same text settings is exposed as
    ``text_config``.
    """

    # for step1.5
    model_type = "mmgpt_step1"

    def __init__(
        self,
        hidden_size: int = 5120,
        intermediate_size: int = 13312,
        num_attention_heads: int = 40,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 48,
        max_seq_len: int = 4096,
        vocab_size: int = 65536,
        rms_norm_eps: float = 1e-5,
        use_im_start_end=True,
        vision_select_layer=-2,
        image_token_len=None,
        projector_stride=1,
        vision_tower_config=None,
        image_token_id=13,
        image_seq_length=400,
        bos_token_id: Optional[Union[List[int], int]] = None,
        eos_token_id: Optional[Union[List[int], int]] = None,
        **kwargs,
    ) -> None:
        super().__init__(
            bos_token_id=1 if bos_token_id is None else bos_token_id,
            eos_token_id=[2, 3] if eos_token_id is None else eos_token_id,
            **kwargs)
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_hidden_layers = num_hidden_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps

        self.use_im_start_end = use_im_start_end
        self.vision_select_layer = vision_select_layer
        self.image_token_len = image_token_len
        self.projector_stride = projector_stride
        self.image_token_id = image_token_id
        self.image_seq_length = image_seq_length

        # Build the nested config from keyword arguments, but keep an
        # already-constructed config object instead of failing on ``**``.
        if (vision_tower_config is not None
                and not isinstance(vision_tower_config, PretrainedConfig)):
            vision_tower_config = CLIPVisionConfig(**vision_tower_config)
        self.vision_tower_config = vision_tower_config

        self.text_config = Step1Config(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_attention_groups=num_attention_groups,
            num_hidden_layers=num_hidden_layers,
            max_seq_len=max_seq_len,
            vocab_size=vocab_size,
            rms_norm_eps=rms_norm_eps,
            architectures=["Step1ForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )
class MMGPTStep1ConfigV2(Step1Config):
    """Multimodal Step1 config with both a CLIP tower and a SAM ViT encoder.

    ``bos_token_id`` defaults to ``1`` and ``eos_token_id`` to ``[2, 3]``
    when not supplied.  ``vision_tower_config`` and ``sam_model_config`` may
    each be ``None``, keyword arguments for the respective config class, or
    an already-built ``PretrainedConfig`` instance, which is stored as
    given.  A standalone ``Step1Config`` with the same text settings is
    exposed as ``text_config``.
    """

    # for step1.5c/step1u, models with both vit and sam encoders
    model_type = "mmgpt_step1_v2"

    def __init__(
        self,
        hidden_size: int = 5120,
        intermediate_size: int = 13312,
        num_attention_heads: int = 40,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 48,
        max_seq_len: int = 4096,
        vocab_size: int = 65536,
        rms_norm_eps: float = 1e-5,
        use_im_start_end=True,
        vision_select_layer=-1,
        image_token_len=None,
        understand_projector_stride=1,
        vit_scale=1.0,
        projector_bias=True,
        vision_tower_config=None,
        sam_model_config=None,
        image_token_id=13,
        bos_token_id: Optional[Union[List[int], int]] = None,
        eos_token_id: Optional[Union[List[int], int]] = None,
        **kwargs,
    ) -> None:
        super().__init__(
            bos_token_id=1 if bos_token_id is None else bos_token_id,
            eos_token_id=[2, 3] if eos_token_id is None else eos_token_id,
            **kwargs)
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_hidden_layers = num_hidden_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps

        self.use_im_start_end = use_im_start_end
        self.vision_select_layer = vision_select_layer
        self.image_token_len = image_token_len
        self.image_token_id = image_token_id
        self.understand_projector_stride = understand_projector_stride
        self.vit_scale = vit_scale
        self.projector_bias = projector_bias

        # Build nested configs from keyword arguments, but keep
        # already-constructed config objects instead of failing on ``**``.
        if (vision_tower_config is not None
                and not isinstance(vision_tower_config, PretrainedConfig)):
            vision_tower_config = CLIPVisionConfig(**vision_tower_config)
        self.vision_tower_config = vision_tower_config

        if (sam_model_config is not None
                and not isinstance(sam_model_config, PretrainedConfig)):
            sam_model_config = SamViTConfig(**sam_model_config)
        self.sam_model_config = sam_model_config

        self.text_config = Step1Config(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_attention_groups=num_attention_groups,
            num_hidden_layers=num_hidden_layers,
            max_seq_len=max_seq_len,
            vocab_size=vocab_size,
            rms_norm_eps=rms_norm_eps,
            architectures=["Step1ForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )
class Step1oConfig(Step1Config):
    """Step1o multimodal config: text settings plus a CLIP vision tower.

    ``bos_token_id`` defaults to ``1`` and ``eos_token_id`` to ``[2, 3]``
    when not supplied.  ``patch_token_len`` falls back to
    ``image_token_len`` when it is ``None``.  ``vision_tower_config`` may be
    ``None``, keyword arguments for ``CLIPVisionConfig``, or an
    already-built ``PretrainedConfig`` instance, which is stored as given.
    A standalone ``Step1Config`` with the same text settings is exposed as
    ``text_config``.
    """

    # for step1o
    model_type = "step1o"

    def __init__(
        self,
        hidden_size: int = 5120,
        intermediate_size: int = 13312,
        num_attention_heads: int = 40,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 48,
        max_seq_len: int = 4096,
        vocab_size: int = 65536,
        rms_norm_eps: float = 1e-5,
        use_im_start_end=True,
        vision_select_layer=-1,
        image_token_len=None,
        image_token_id=13,
        understand_projector_stride=1,
        vit_scale=1.0,
        projector_bias=True,
        patch_token_len=None,
        vision_tower_config=None,
        bos_token_id: Optional[Union[List[int], int]] = None,
        eos_token_id: Optional[Union[List[int], int]] = None,
        **kwargs,
    ) -> None:
        super().__init__(
            bos_token_id=1 if bos_token_id is None else bos_token_id,
            eos_token_id=[2, 3] if eos_token_id is None else eos_token_id,
            **kwargs)
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_hidden_layers = num_hidden_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps

        self.use_im_start_end = use_im_start_end
        self.vision_select_layer = vision_select_layer
        self.image_token_len = image_token_len
        self.image_token_id = image_token_id
        self.understand_projector_stride = understand_projector_stride
        self.vit_scale = vit_scale
        self.projector_bias = projector_bias
        self.patch_token_len = (patch_token_len if patch_token_len is not None
                                else self.image_token_len)

        # Build the nested config from keyword arguments, but keep an
        # already-constructed config object instead of failing on ``**``.
        if (vision_tower_config is not None
                and not isinstance(vision_tower_config, PretrainedConfig)):
            vision_tower_config = CLIPVisionConfig(**vision_tower_config)
        self.vision_tower_config = vision_tower_config

        self.text_config = Step1Config(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_attention_groups=num_attention_groups,
            num_hidden_layers=num_hidden_layers,
            max_seq_len=max_seq_len,
            vocab_size=vocab_size,
            rms_norm_eps=rms_norm_eps,
            architectures=["Step1ForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )
class MMGPTQwen2Config(PretrainedConfig):
    """Configuration for ``mmgpt_qwen2`` checkpoints.

    Stores the language-model and image-related options and builds three
    nested configs: ``vision_tower_config`` (``CLIPVisionConfig`` or ``None``),
    ``sam_model_config`` (``SamViTConfig`` or ``None``) and ``text_config``
    (a ``Qwen2Config`` tagged with the ``Qwen2ForCausalLM`` architecture).

    ``eos_token_id`` always contains 151643 and 151646; a user-supplied id or
    list of ids is merged into that default.
    """

    # for step1.5t
    model_type = "mmgpt_qwen2"

    def __init__(
        self,
        vocab_size=64012,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=48,
        num_attention_heads=32,
        num_attention_groups=4,
        num_key_value_heads=4,
        hidden_act="silu",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        rope_theta=1000000.0,
        rope_scaling=None,
        use_im_start_end=True,
        vision_select_layer=-1,
        image_token_len=None,
        image_token_id=151656,
        understand_projector_stride=1,
        vit_scale=1.0,
        projector_bias=True,
        pad_token_id=-1,
        vision_tower_config=None,
        sam_model_config=None,
        eos_token_id=None,
        **kwargs,
    ) -> None:
        # Merge the caller's EOS id(s) with the built-in defaults. A list is
        # de-duplicated through a set; a scalar is simply appended.
        default_eos = [151643, 151646]
        if eos_token_id is None:
            merged_eos = default_eos
        elif isinstance(eos_token_id, list):
            merged_eos = list(set(default_eos + eos_token_id))
        else:
            merged_eos = default_eos + [eos_token_id]

        super().__init__(eos_token_id=merged_eos, **kwargs)

        # Language-model hyper-parameters.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        # Image-related options.
        self.use_im_start_end = use_im_start_end
        self.vision_select_layer = vision_select_layer
        self.image_token_len = image_token_len
        self.image_token_id = image_token_id
        self.understand_projector_stride = understand_projector_stride
        self.vit_scale = vit_scale
        self.projector_bias = projector_bias
        self.pad_token_id = pad_token_id

        # Optional nested vision configs, built only when a dict is supplied.
        if vision_tower_config is None:
            self.vision_tower_config = None
        else:
            self.vision_tower_config = CLIPVisionConfig(**vision_tower_config)

        if sam_model_config is None:
            self.sam_model_config = None
        else:
            self.sam_model_config = SamViTConfig(**sam_model_config)

        # Nested language-model config; torch_dtype falls back to "bfloat16"
        # when this config does not carry one.
        text_kwargs = dict(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            architectures=["Qwen2ForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )
        self.text_config = Qwen2Config(**text_kwargs)
class MMGPTQwen2ConfigV2(MMGPTQwen2Config):
    """Configuration for ``mmgpt_qwen2_v2`` checkpoints.

    Same fields as ``MMGPTQwen2Config`` with a different default
    ``image_token_id`` (151675) and different default EOS ids
    (151643, 151645, 151665).

    NOTE(review): only ``eos_token_id`` and ``**kwargs`` are passed to the
    parent constructor, which merges its own default EOS ids (151643, 151646)
    into the list, so 151646 also ends up in ``eos_token_id``. Confirm that
    this is intended.
    """

    model_type = "mmgpt_qwen2_v2"

    def __init__(
        self,
        vocab_size=64012,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=48,
        num_attention_heads=32,
        num_attention_groups=4,
        num_key_value_heads=4,
        hidden_act="silu",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        rope_theta=1000000.0,
        rope_scaling=None,
        use_im_start_end=True,
        vision_select_layer=-1,
        image_token_len=None,
        image_token_id=151675,
        understand_projector_stride=1,
        vit_scale=1.0,
        projector_bias=True,
        pad_token_id=-1,
        vision_tower_config=None,
        sam_model_config=None,
        eos_token_id=None,
        **kwargs,
    ) -> None:
        # Merge the caller's EOS id(s) with this version's defaults. A list is
        # de-duplicated through a set; a scalar is simply appended.
        default_eos = [151643, 151645, 151665]
        if eos_token_id is None:
            merged_eos = default_eos
        elif isinstance(eos_token_id, list):
            merged_eos = list(set(default_eos + eos_token_id))
        else:
            merged_eos = default_eos + [eos_token_id]

        # The parent runs with its own defaults for every other field; all of
        # them are overwritten below.
        super().__init__(eos_token_id=merged_eos, **kwargs)

        # Language-model hyper-parameters.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        # Image-related options.
        self.use_im_start_end = use_im_start_end
        self.vision_select_layer = vision_select_layer
        self.image_token_len = image_token_len
        self.image_token_id = image_token_id
        self.understand_projector_stride = understand_projector_stride
        self.vit_scale = vit_scale
        self.projector_bias = projector_bias
        self.pad_token_id = pad_token_id

        # Optional nested vision configs, built only when a dict is supplied.
        if vision_tower_config is None:
            self.vision_tower_config = None
        else:
            self.vision_tower_config = CLIPVisionConfig(**vision_tower_config)

        if sam_model_config is None:
            self.sam_model_config = None
        else:
            self.sam_model_config = SamViTConfig(**sam_model_config)

        # Nested language-model config; torch_dtype falls back to "bfloat16"
        # when this config does not carry one.
        text_kwargs = dict(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            architectures=["Qwen2ForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )
        self.text_config = Qwen2Config(**text_kwargs)
class Step3vConfig(Step1Config):
    """Configuration for ``step3v`` checkpoints.

    Stores the language-model, MoE, rotary and image-related options on the
    config and builds two nested configs: ``vision_tower_config``
    (``CLIPVisionConfig`` or ``None``) and ``text_config`` (a
    ``Step2MiniConfig`` tagged with the ``Step2MiniForCausalLM``
    architecture).

    ``bos_token_id`` defaults to ``0`` and ``eos_token_id`` to
    ``[1, 128805]`` when not supplied. ``patch_token_len`` falls back to
    ``image_token_len`` when ``None``.

    NOTE(review): ``share_expert_dim`` is stored here exactly as passed
    (possibly ``None``), and ``moe_layers_enum`` is only forwarded to
    ``text_config`` and never stored on this object. Confirm that both are
    intended.
    """

    model_type = "step3v"

    def __init__(
        self,
        hidden_size: int = 5120,
        intermediate_size: int = 13312,
        num_attention_heads: int = 40,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 48,
        max_seq_len: int = 4096,
        vocab_size: int = 65536,
        rms_norm_eps: float = 1e-5,
        # 2 means 50% layers use MoE, interleaved with normal non-MoE layers.
        moe_every_n_layer: int = 2,
        use_moe: bool = False,
        moe_intermediate_size: int = 10240,
        moe_num_experts: int = 16,
        moe_top_k: int = 4,
        max_pos_interp_ratio: float = 1,
        alibi_slopes: Optional[List[float]] = None,
        moe_layer_offset: int = 0,
        moe_dynamic_exp_p: float = 1.0,
        rope_theta: float = 500000,
        rope_scaling: Optional[dict[str, Any]] = None,
        head_dim: Optional[int] = None,
        max_position_embedding: int = 16384,
        share_expert_dim: Optional[int] = None,
        allgather_dtype: Optional[str] = None,
        share_q_dim: Optional[int] = None,
        norm_expert_weight: bool = True,
        moe_layers_enum: Optional[str] = None,
        use_im_start_end: bool = True,
        vision_select_layer: int = -1,
        image_token_len: Optional[int] = None,
        image_token_id: int = 128001,
        understand_projector_stride: int = 1,
        vit_scale: float = 1.0,
        projector_bias: bool = True,
        patch_token_len: Optional[int] = None,
        vision_tower_config: Optional[dict[str, Any]] = None,
        bos_token_id: Optional[Union[List[int], int]] = None,
        eos_token_id: Optional[Union[List[int], int]] = None,
        **kwargs,
    ) -> None:
        resolved_bos = 0 if bos_token_id is None else bos_token_id
        resolved_eos = [1, 128805] if eos_token_id is None else eos_token_id
        super().__init__(bos_token_id=resolved_bos,
                         eos_token_id=resolved_eos,
                         **kwargs)

        # Core transformer dimensions.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_hidden_layers = num_hidden_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps

        # Mixture-of-experts and positional options.
        self.moe_every_n_layer = moe_every_n_layer
        self.use_moe = use_moe
        self.moe_intermediate_size = moe_intermediate_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.max_pos_interp_ratio = max_pos_interp_ratio
        self.alibi_slopes = alibi_slopes
        self.moe_layer_offset = moe_layer_offset
        self.moe_dynamic_exp_p = moe_dynamic_exp_p
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.head_dim = head_dim
        self.max_position_embedding = max_position_embedding
        self.share_expert_dim = share_expert_dim
        self.allgather_dtype = allgather_dtype
        self.share_q_dim = share_q_dim
        self.norm_expert_weight = norm_expert_weight

        # Image-related options.
        self.use_im_start_end = use_im_start_end
        self.vision_select_layer = vision_select_layer
        self.image_token_len = image_token_len
        self.image_token_id = image_token_id
        self.understand_projector_stride = understand_projector_stride
        self.vit_scale = vit_scale
        self.projector_bias = projector_bias

        if patch_token_len is None:
            patch_token_len = self.image_token_len
        self.patch_token_len = patch_token_len

        if vision_tower_config is None:
            self.vision_tower_config = None
        else:
            self.vision_tower_config = CLIPVisionConfig(**vision_tower_config)

        # Nested language-model config; torch_dtype falls back to "bfloat16"
        # when this config does not carry one.
        text_kwargs = dict(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_attention_groups=num_attention_groups,
            num_hidden_layers=num_hidden_layers,
            max_seq_len=max_seq_len,
            vocab_size=vocab_size,
            rms_norm_eps=rms_norm_eps,
            moe_every_n_layer=moe_every_n_layer,
            use_moe=use_moe,
            moe_intermediate_size=moe_intermediate_size,
            moe_num_experts=moe_num_experts,
            moe_top_k=moe_top_k,
            max_pos_interp_ratio=max_pos_interp_ratio,
            alibi_slopes=alibi_slopes,
            moe_layer_offset=moe_layer_offset,
            moe_dynamic_exp_p=moe_dynamic_exp_p,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            head_dim=head_dim,
            max_position_embedding=max_position_embedding,
            share_expert_dim=share_expert_dim,
            allgather_dtype=allgather_dtype,
            share_q_dim=share_q_dim,
            norm_expert_weight=norm_expert_weight,
            moe_layers_enum=moe_layers_enum,
            architectures=["Step2MiniForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )
        self.text_config = Step2MiniConfig(**text_kwargs)
\ No newline at end of file
vllm/transformers_utils/configs/step.py
0 → 100644
View file @
583034f1
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
from
transformers
import
PretrainedConfig
class StepConfig(PretrainedConfig):
    """Base configuration shared by the Step language-model configs.

    All constructor arguments are stored as attributes of the same name.
    ``share_expert_dim`` defaults to ``moe_intermediate_size * moe_top_k``
    when ``None``. ``bos_token_id`` defaults to ``1`` and ``eos_token_id`` to
    ``[2, 3]`` when ``None``. Remaining ``**kwargs`` go to
    ``PretrainedConfig``.

    Raises:
        ValueError: If ``alibi_slopes`` is given and its length differs from
            ``num_attention_heads``.
    """

    model_type = "step"

    def __init__(
        self,
        hidden_size: int = 5120,
        intermediate_size: int = 13312,
        num_attention_heads: int = 40,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 48,
        max_seq_len: int = 4096,
        vocab_size: int = 65536,
        rms_norm_eps: float = 1e-5,
        # 2 means 50% layers use MoE, interleaved with normal non-MoE layers.
        moe_every_n_layer: int = 2,
        use_moe: bool = False,
        moe_intermediate_size: int = 10240,
        moe_num_experts: int = 16,
        moe_top_k: int = 4,
        max_pos_interp_ratio: float = 1,
        alibi_slopes: Optional[List[float]] = None,
        moe_layer_offset: int = 0,
        moe_dynamic_exp_p: float = 1.0,
        rope_theta: float = 500000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        head_dim: Optional[int] = None,
        max_position_embedding: int = 16384,
        share_expert_dim: Optional[int] = None,
        allgather_dtype: Optional[str] = None,
        share_q_dim: Optional[int] = None,
        norm_expert_weight: bool = True,
        bos_token_id: Optional[Union[List[int], int]] = None,
        eos_token_id: Optional[Union[List[int], int]] = None,
        **kwargs,
    ) -> None:
        # Core transformer dimensions.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_hidden_layers = num_hidden_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps

        # Mixture-of-experts and positional options.
        self.use_moe = use_moe
        self.moe_intermediate_size = moe_intermediate_size
        self.moe_every_n_layer = moe_every_n_layer
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.max_pos_interp_ratio = max_pos_interp_ratio
        self.alibi_slopes = alibi_slopes
        self.moe_layer_offset = moe_layer_offset
        self.moe_dynamic_exp_p = moe_dynamic_exp_p

        # for step2 mini
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.head_dim = head_dim
        self.max_position_embedding = max_position_embedding
        self.share_expert_dim = (
            self.moe_intermediate_size * self.moe_top_k
            if share_expert_dim is None else share_expert_dim)
        self.share_q_dim = share_q_dim
        self.norm_expert_weight = norm_expert_weight
        self.allgather_dtype = allgather_dtype

        # Validate before handing the remaining kwargs to PretrainedConfig.
        self._verify_slopes()

        resolved_bos = 1 if bos_token_id is None else bos_token_id
        resolved_eos = [2, 3] if eos_token_id is None else eos_token_id
        super().__init__(bos_token_id=resolved_bos,
                         eos_token_id=resolved_eos,
                         **kwargs)

    def _verify_slopes(self):
        """Check that ``alibi_slopes``, when set, has one entry per head."""
        slopes = self.alibi_slopes
        if slopes is not None and len(slopes) != self.num_attention_heads:
            raise ValueError(
                f"Number of alibi_slopes ({len(slopes)}) does not match num_attention_heads ({self.num_attention_heads})"
            )
class Step1Config(StepConfig):
    """``StepConfig`` registered under the ``step1`` model type."""

    model_type = "step1"
class Step2Config(StepConfig):
    """``StepConfig`` for the ``step2`` model type.

    Adds a single option, ``use_offline_input_scales``, which is stored on
    the config; every other argument is forwarded to ``StepConfig``.
    """

    model_type = "step2"

    def __init__(
        self,
        use_offline_input_scales: bool = True,
        **kwargs,
    ):
        # Set before delegating, matching the base class, which also assigns
        # its attributes before calling the PretrainedConfig constructor.
        self.use_offline_input_scales = use_offline_input_scales
        super().__init__(**kwargs)
class Step2MiniConfig(StepConfig):
    """``StepConfig`` registered under the ``step2_mini`` model type."""

    model_type = "step2_mini"
\ No newline at end of file
vllm/transformers_utils/configs/step1f.py
0 → 100644
View file @
583034f1
# SPDX-License-Identifier: Apache-2.0
from
transformers
import
Qwen2Config
from
transformers.configuration_utils
import
PretrainedConfig
from
vllm.transformers_utils.configs.step
import
Step1Config
class Step1fAudioEncoderConfig(PretrainedConfig):
    """Configuration for the ``stepasr_encoder`` audio encoder.

    Every constructor argument is stored as an attribute of the same name;
    ``**kwargs`` are forwarded to ``PretrainedConfig``.
    """

    model_type = "stepasr_encoder"

    def __init__(
        self,
        n_mels: int = 128,
        n_audio_ctx: int = 1500,
        n_audio_state: int = 1280,
        n_audio_head: int = 20,
        n_audio_layer: int = 32,
        n_codebook_size: int = 4096,
        llm_dim: int = 3072,
        kernel_size: int = 3,
        adapter_stride: int = 2,
        adapter_state: int = 2048,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        # Attribute name -> value, assigned in declaration order.
        encoder_fields = {
            "n_mels": n_mels,
            "n_audio_ctx": n_audio_ctx,
            "n_audio_state": n_audio_state,
            "n_audio_head": n_audio_head,
            "n_audio_layer": n_audio_layer,
            "n_codebook_size": n_codebook_size,
            "llm_dim": llm_dim,
            "kernel_size": kernel_size,
            "adapter_stride": adapter_stride,
            "adapter_state": adapter_state,
        }
        for field_name, field_value in encoder_fields.items():
            setattr(self, field_name, field_value)
class Step1AudioConfig(PretrainedConfig):
    """Configuration for ``step1_audio`` checkpoints.

    Stores the language-model hyper-parameters and ``audio_token_id``, and
    builds two nested configs: ``audio_encoder_config``
    (``Step1fAudioEncoderConfig`` or ``None``) and ``text_config`` (a
    ``Step1Config`` tagged with the ``Step1ForCausalLM`` architecture).

    ``eos_token_id`` always contains 2 and 3; a user-supplied id or list of
    ids is merged into that default.
    """

    # for step1.5t
    model_type = "step1_audio"

    def __init__(
        self,
        hidden_size: int = 5120,
        intermediate_size: int = 13312,
        num_attention_heads: int = 40,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 48,
        max_seq_len: int = 4096,
        vocab_size: int = 65536,
        rms_norm_eps: float = 1e-5,
        audio_token_id: int = 29,
        eos_token_id=None,
        audio_encoder_config=None,
        **kwargs,
    ) -> None:
        # Merge the caller's EOS id(s) with the built-in defaults. A list is
        # de-duplicated through a set; a scalar is simply appended.
        default_eos = [2, 3]
        if eos_token_id is None:
            merged_eos = default_eos
        elif isinstance(eos_token_id, list):
            merged_eos = list(set(default_eos + eos_token_id))
        else:
            merged_eos = default_eos + [eos_token_id]

        super().__init__(eos_token_id=merged_eos, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.max_seq_len = max_seq_len
        self.rms_norm_eps = rms_norm_eps
        self.audio_token_id = audio_token_id

        # Optional nested encoder config, built only when a dict is supplied.
        if audio_encoder_config is None:
            self.audio_encoder_config = None
        else:
            self.audio_encoder_config = Step1fAudioEncoderConfig(
                **audio_encoder_config)

        # Nested language-model config; torch_dtype falls back to "bfloat16"
        # when this config does not carry one.
        text_kwargs = dict(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_attention_groups=num_attention_groups,
            num_hidden_layers=num_hidden_layers,
            max_seq_len=max_seq_len,
            vocab_size=vocab_size,
            rms_norm_eps=rms_norm_eps,
            architectures=["Step1ForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )
        self.text_config = Step1Config(**text_kwargs)
class StepAudioQwen2Config(PretrainedConfig):
    """Configuration for ``step_audio_qwen2`` checkpoints.

    Stores the language-model hyper-parameters and ``audio_token_id``, and
    builds two nested configs: ``audio_encoder_config``
    (``Step1fAudioEncoderConfig`` or ``None``) and ``text_config`` (a
    ``Qwen2Config`` tagged with the ``Qwen2ForCausalLM`` architecture).

    ``eos_token_id`` always contains 151643, 151645 and 151665; a
    user-supplied id or list of ids is merged into that default.

    Raises:
        ValueError: If ``num_attention_groups`` differs from
            ``num_key_value_heads``.
    """

    model_type = "step_audio_qwen2"

    def __init__(
        self,
        vocab_size=64012,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=48,
        num_attention_heads=32,
        num_attention_groups=4,
        num_key_value_heads=4,
        hidden_act="silu",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        rope_theta=1000000.0,
        rope_scaling=None,
        audio_token_id=151690,
        eos_token_id=None,
        audio_encoder_config=None,
        **kwargs,
    ) -> None:
        # Merge the caller's EOS id(s) with the built-in defaults. A list is
        # de-duplicated through a set; a scalar is simply appended.
        if eos_token_id is not None:
            if isinstance(eos_token_id, list):
                eos_token_id = list(
                    set([151643, 151645, 151665] + eos_token_id))
            else:
                eos_token_id = [151643, 151645, 151665, eos_token_id]
        else:
            eos_token_id = [151643, 151645, 151665]

        super().__init__(eos_token_id=eos_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_key_value_heads = num_key_value_heads

        # Explicit check instead of ``assert``: assertions are removed when
        # Python runs with -O, which would silently accept a bad config.
        if self.num_attention_groups != self.num_key_value_heads:
            raise ValueError(
                "num_attention_groups must be equal to num_key_value_heads")

        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        # Optional nested encoder config, built only when a dict is supplied.
        self.audio_encoder_config = (
            Step1fAudioEncoderConfig(**audio_encoder_config)
            if audio_encoder_config is not None else None)
        self.audio_token_id = audio_token_id

        # Nested language-model config; torch_dtype falls back to "bfloat16"
        # when this config does not carry one.
        self.text_config = Qwen2Config(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            architectures=["Qwen2ForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )
\ No newline at end of file
vllm/transformers_utils/detokenizer_utils.py
View file @
583034f1
...
...
@@ -4,6 +4,8 @@
from
typing
import
Optional
from
.tokenizer
import
AnyTokenizer
# from vllm.transformers_utils.tokenizers.sentencepiece_tokenizer import (
# SentencePieceTokenizer)
def
_replace_none_with_empty
(
tokens
:
list
[
Optional
[
str
]]):
...
...
@@ -171,6 +173,13 @@ def detokenize_incrementally(
# The prefix text is necessary only to defeat cleanup algorithms in
# the decode which decide to add a space or not depending on the
# surrounding ids.
# FIXME(ys): for step1 sentencepiece tokenizer, we need to handle the special tokens in convert_tokens_to_string
# if isinstance(tokenizer, SentencePieceTokenizer):
# prefix_text = tokenizer.convert_tokens_to_string(
# output_tokens[prefix_offset:read_offset], skip_special_tokens=skip_special_tokens)
# new_text = tokenizer.convert_tokens_to_string(
# output_tokens[prefix_offset:], skip_special_tokens=skip_special_tokens)
if
tokenizer
.
is_fast
or
not
tokenizer
.
get_added_vocab
():
prefix_text
=
tokenizer
.
convert_tokens_to_string
(
output_tokens
[
prefix_offset
:
read_offset
])
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment