Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
583034f1
Commit
583034f1
authored
Oct 20, 2025
by
zhuwenwen
Browse files
[models] support step3v
parent
0adf9cda
Changes
24
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
5936 additions
and
9 deletions
+5936
-9
vllm/config.py
vllm/config.py
+11
-1
vllm/entrypoints/openai/tool_parsers/__init__.py
vllm/entrypoints/openai/tool_parsers/__init__.py
+3
-0
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+9
-0
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+76
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+269
-3
vllm/model_executor/layers/logits_processor.py
vllm/model_executor/layers/logits_processor.py
+10
-4
vllm/model_executor/layers/quantization/groupwise_quant.py
vllm/model_executor/layers/quantization/groupwise_quant.py
+726
-0
vllm/model_executor/layers/quantization/quant_utils.py
vllm/model_executor/layers/quantization/quant_utils.py
+135
-0
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+7
-0
vllm/model_executor/models/mm_step1o.py
vllm/model_executor/models/mm_step1o.py
+1080
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+27
-0
vllm/model_executor/models/step1.py
vllm/model_executor/models/step1.py
+1442
-0
vllm/model_executor/models/step2_mini.py
vllm/model_executor/models/step2_mini.py
+807
-0
vllm/model_executor/models/step_encoder.py
vllm/model_executor/models/step_encoder.py
+447
-0
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+25
-1
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+29
-0
vllm/transformers_utils/configs/mmgpt.py
vllm/transformers_utils/configs/mmgpt.py
+558
-0
vllm/transformers_utils/configs/step.py
vllm/transformers_utils/configs/step.py
+102
-0
vllm/transformers_utils/configs/step1f.py
vllm/transformers_utils/configs/step1f.py
+164
-0
vllm/transformers_utils/detokenizer_utils.py
vllm/transformers_utils/detokenizer_utils.py
+9
-0
No files found.
vllm/config.py
View file @
583034f1
...
@@ -3418,6 +3418,8 @@ def _get_and_verify_max_len(
...
@@ -3418,6 +3418,8 @@ def _get_and_verify_max_len(
possible_keys
=
[
possible_keys
=
[
# OPT
# OPT
"max_position_embeddings"
,
"max_position_embeddings"
,
# step3
"max_position_embedding"
,
# GPT-2
# GPT-2
"n_positions"
,
"n_positions"
,
# MPT
# MPT
...
@@ -3491,7 +3493,13 @@ def _get_and_verify_max_len(
...
@@ -3491,7 +3493,13 @@ def _get_and_verify_max_len(
# loading HF config
# loading HF config
rope_type
=
rope_scaling
[
"rope_type"
]
rope_type
=
rope_scaling
[
"rope_type"
]
if
rope_type
not
in
(
"su"
,
"longrope"
,
"llama3"
):
if
rope_type
==
"ntk_bypart"
:
derived_max_model_len
=
min
(
derived_max_model_len
,
rope_scaling
[
"real_length"
]
*
rope_scaling
[
"scaling_factor"
]
)
if
"real_length"
in
rope_scaling
and
"scaling_factor"
in
rope_scaling
else
derived_max_model_len
elif
rope_type
not
in
(
"su"
,
"longrope"
,
"llama3"
):
if
disable_sliding_window
:
if
disable_sliding_window
:
# TODO(robertgshaw): Find a model that supports rope_scaling
# TODO(robertgshaw): Find a model that supports rope_scaling
# with sliding window to see if this case should be allowed.
# with sliding window to see if this case should be allowed.
...
@@ -3548,6 +3556,8 @@ def _get_and_verify_max_len(
...
@@ -3548,6 +3556,8 @@ def _get_and_verify_max_len(
logger
.
warning
(
logger
.
warning
(
"%s Make sure the value is correct and within the "
"%s Make sure the value is correct and within the "
"model context size."
,
msg
)
"model context size."
,
msg
)
if
getattr
(
hf_config
,
"max_position_embedding"
,
None
)
is
not
None
:
# step3/3v
hf_config
.
max_position_embedding
=
max_model_len
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"
{
msg
}
To allow overriding this maximum, set "
f
"
{
msg
}
To allow overriding this maximum, set "
...
...
vllm/entrypoints/openai/tool_parsers/__init__.py
View file @
583034f1
...
@@ -36,4 +36,7 @@ __all__ = [
...
@@ -36,4 +36,7 @@ __all__ = [
"xLAMToolParser"
,
"xLAMToolParser"
,
"MinimaxToolParser"
,
"MinimaxToolParser"
,
"Glm4MoeModelToolParser"
,
"Glm4MoeModelToolParser"
,
"Step1p5vMini2ToolParser"
,
"Step1p5vMini2MsToolParser"
,
"Step3ToolParser"
,
]
]
vllm/model_executor/layers/activation.py
View file @
583034f1
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
"""Custom activation functions."""
"""Custom activation functions."""
import
math
import
math
from
typing
import
Optional
from
typing
import
Optional
import
optimus
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -53,6 +54,14 @@ class FatreluAndMul(CustomOp):
...
@@ -53,6 +54,14 @@ class FatreluAndMul(CustomOp):
return
out
return
out
class OptimusSiluAndMul(nn.Module):
    """Activation module that forwards to the Optimus ``SiluDot_forward`` op.

    NOTE(review): the class name suggests a fused SiLU-and-multiply
    (SwiGLU-style) activation; the exact math lives in the external
    ``optimus`` extension and is not visible here.
    """

    def forward(
        self,
        x: torch.Tensor,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Invoke ``torch.ops.Optimus.SiluDot_forward`` on ``x``.

        Args:
            x: Input tensor, passed to the op unchanged.
            output: Optional tensor forwarded as the op's ``out`` keyword.

        Returns:
            The value returned by the Optimus op.
        """
        silu_dot = torch.ops.Optimus.SiluDot_forward
        return silu_dot(x, out=output)
@
CustomOp
.
register
(
"silu_and_mul"
)
@
CustomOp
.
register
(
"silu_and_mul"
)
class
SiluAndMul
(
CustomOp
):
class
SiluAndMul
(
CustomOp
):
"""An activation function for SwiGLU.
"""An activation function for SwiGLU.
...
...
vllm/model_executor/layers/layernorm.py
View file @
583034f1
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom normalization layers."""
"""Custom normalization layers."""
from
typing
import
Optional
,
Union
,
Tuple
from
typing
import
Optional
,
Union
,
Tuple
import
optimus
# noqa F401
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -298,6 +299,49 @@ class RMSNorm(CustomOp):
...
@@ -298,6 +299,49 @@ class RMSNorm(CustomOp):
return
s
return
s
class
OptimusRMSNorm
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
eps
:
float
=
1e-6
,
)
->
None
:
super
().
__init__
()
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
hidden_size
))
self
.
variance_epsilon
=
eps
def
forward
(
self
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
fp16_out
:
bool
=
False
)
->
torch
.
Tensor
:
if
residual
is
not
None
:
assert
output
is
None
from
vllm
import
_custom_ops
as
ops
assert
not
fp16_out
ops
.
fused_add_rms_norm
(
x
,
residual
,
self
.
weight
.
data
,
self
.
variance_epsilon
,
)
return
x
,
residual
else
:
if
fp16_out
:
if
output
is
None
:
output
=
torch
.
empty_like
(
x
).
half
()
else
:
output
=
output
.
half
()
# return torch.ops.Optimus.rms_norm(x,
# self.weight,
# self.variance_epsilon,
# out=output)
return
torch
.
nn
.
functional
.
rms_norm
(
x
,
self
.
weight
,
self
.
variance_epsilon
,
out
=
output
)
@
CustomOp
.
register
(
"gemma_rms_norm"
)
@
CustomOp
.
register
(
"gemma_rms_norm"
)
class
GemmaRMSNorm
(
CustomOp
):
class
GemmaRMSNorm
(
CustomOp
):
"""RMS normalization for Gemma.
"""RMS normalization for Gemma.
...
@@ -363,3 +407,35 @@ class GemmaRMSNorm(CustomOp):
...
@@ -363,3 +407,35 @@ class GemmaRMSNorm(CustomOp):
self
.
forward_static
)
self
.
forward_static
)
self
.
_is_compiled
=
True
self
.
_is_compiled
=
True
return
self
.
forward_native
(
x
,
residual
)
return
self
.
forward_native
(
x
,
residual
)
class OptimusLayerNorm(nn.Module):
    """Layer normalization with learnable ``weight`` and ``bias``.

    ``weight`` is initialised to ones and ``bias`` to zeros, both of shape
    ``(hidden_size,)``.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        """Create the layer.

        Args:
            hidden_size: Size of the (last) dimension being normalised.
            eps: Value stored as ``variance_epsilon`` and passed to
                ``layer_norm``.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply layer normalization to ``x``.

        Args:
            x: Input tensor.
            residual: Must be ``None``; a residual path is not supported.
            output: Optional destination buffer. When given, the result is
                copied into it and the buffer is returned.

        Returns:
            The normalised tensor.
        """
        assert residual is None
        normed = torch.nn.functional.layer_norm(
            x,
            self.weight.shape,  # normalized_shape must be the weight's shape
            self.weight,
            self.bias,
            eps=self.variance_epsilon,
        )
        if output is None:
            return normed
        # F.layer_norm has no ``out=`` keyword; fill the caller's buffer so
        # the ``output`` argument is not silently dropped.
        output.copy_(normed)
        return output
vllm/model_executor/layers/linear.py
View file @
583034f1
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
import
itertools
import
itertools
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
typing
import
Any
,
Literal
,
Optional
,
Union
from
typing
import
Any
,
Literal
,
Optional
,
Union
,
List
import
vllm.envs
as
envs
import
vllm.envs
as
envs
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -269,6 +269,40 @@ class UnquantizedLinearMethod(LinearMethodBase):
...
@@ -269,6 +269,40 @@ class UnquantizedLinearMethod(LinearMethodBase):
return
dispatch_unquantized_gemm
()(
x
,
layer
.
weight
,
bias
)
return
dispatch_unquantized_gemm
()(
x
,
layer
.
weight
,
bias
)
class UnquantizedMoELinearMethod(LinearMethodBase):
    """MoE Linear method without quantization.

    Only allocates the stacked per-expert weight; ``apply`` is not
    implemented.
    """

    def __init__(self):
        self.quant_config = None

    def create_weights(self,
                       layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int],
                       input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype,
                       num_experts: Optional[int] = None,
                       **extra_weight_attrs):
        """Allocate and register the ``weight`` parameter on ``layer``.

        The weight has shape
        ``(num_experts, sum(output_partition_sizes), input_size_per_partition)``
        and is created on the current CUDA device without gradients.

        Args:
            layer: Module that receives the ``weight`` parameter.
            input_size_per_partition: Size of the weight's last dimension.
            output_partition_sizes: Sizes summed to form the middle dimension.
            input_size: Unused here.
            output_size: Unused here.
            params_dtype: Dtype of the allocated weight.
            num_experts: Number of experts (leading dimension). Required.
            **extra_weight_attrs: Extra attributes attached to the weight.

        Raises:
            ValueError: If ``num_experts`` is ``None``.
        """
        if num_experts is None:
            # torch.empty(None, ...) would otherwise fail with an opaque
            # TypeError; surface the real problem instead.
            raise ValueError(
                "num_experts must be provided for UnquantizedMoELinearMethod")
        weight = Parameter(torch.empty(num_experts,
                                       sum(output_partition_sizes),
                                       input_size_per_partition,
                                       device=torch.cuda.current_device(),
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 2, "output_dim": 1})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights to the input tensor.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
class
LinearBase
(
torch
.
nn
.
Module
):
class
LinearBase
(
torch
.
nn
.
Module
):
"""Base linear layer.
"""Base linear layer.
...
@@ -783,6 +817,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -783,6 +817,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param
.
shard_id
.
append
(
loaded_shard_id
)
param
.
shard_id
.
append
(
loaded_shard_id
)
param
.
shard_id_map
[
loaded_shard_id
]
=
len
(
param
.
data_container
)
param
.
shard_id_map
[
loaded_shard_id
]
=
len
(
param
.
data_container
)
param
.
data_container
.
append
(
loaded_weight
)
param
.
data_container
.
append
(
loaded_weight
)
if
len
(
param
.
data_container
)
==
2
:
self
.
qweight
=
param
.
materialize_nested
()
return
return
param_data
=
param
.
data
param_data
=
param
.
data
...
@@ -986,6 +1022,175 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -986,6 +1022,175 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset
=
shard_offset
,
shard_offset
=
shard_offset
,
shard_size
=
shard_size
)
shard_size
=
shard_size
)
class MergedColumnParallelMoELinear(MergedColumnParallelLinear):
    """Merged column-parallel linear layer holding weights for several experts.

    Only ``torch.nn.Module.__init__`` is invoked; the parent linear
    constructors are not run, and the attributes used by this class are set
    up by hand.
    """

    def __init__(self,
                 num_experts: int,
                 input_size: int,
                 output_sizes: List[int],
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the layer and let the quant method create its weights.

        Args:
            num_experts: Number of experts, forwarded to ``create_weights``.
            input_size: Input feature size (not partitioned here).
            output_sizes: Output sizes of the merged projections; each must
                be divisible by the tensor-parallel world size.
            params_dtype: Parameter dtype; defaults to torch's default dtype.
            quant_config: Optional quantization config used to pick the
                quant method.
            prefix: Layer name prefix passed to ``get_quant_method``.
        """
        torch.nn.Module.__init__(self)

        self.num_experts = num_experts
        self.output_sizes = output_sizes
        self.input_size = input_size
        self.output_size = sum(output_sizes)

        tp_size = get_tensor_model_parallel_world_size()
        assert all(size % tp_size == 0 for size in output_sizes)
        self.output_size_per_partition = divide(self.output_size, tp_size)
        self.output_partition_sizes = [
            divide(size, tp_size) for size in self.output_sizes
        ]
        self.gather_output = False

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        if quant_config is None:
            self.quant_method = UnquantizedMoELinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)
        # FIXME(ys): hack for moe
        # A config may hand back the plain unquantized method; swap it for
        # the MoE variant, which accepts ``num_experts``.
        if isinstance(self.quant_method, UnquantizedLinearMethod):
            self.quant_method = UnquantizedMoELinearMethod()

        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size,
                                         self.output_partition_sizes,
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         self.num_experts,
                                         weight_loader=self.weight_loader)
        self.register_parameter("bias", None)

    def forward(self,
                input_,
                output: Optional[torch.Tensor] = None,
                expert_idx: int = -1):
        """Apply the quantized expert projection.

        Args:
            input_: Input passed to the quant method.
            output: Optional tensor forwarded to the quant method as
                ``output``.
            expert_idx: Expert index forwarded to the quant method.

        Returns:
            The quant method's result, or ``None`` when the layer is
            unquantized.
        """
        if isinstance(self.quant_method, UnquantizedMoELinearMethod):
            # use optimus moe_ffn outside
            return
        bias = None
        assert self.quant_method is not None
        output = self.quant_method.apply(self,
                                         input_,
                                         bias,
                                         expert_idx=expert_idx,
                                         output=output)
        return output
class QKVReplicatedLinear(ReplicatedLinear):
    """Replicated (non-sharded) fused QKV projection.

    The output is laid out as ``[q | k | v]`` along the output dimension
    with sizes ``num_heads * head_size``, ``num_kv_heads * head_size`` and
    ``num_kv_heads * head_size``. Only ``nn.Module.__init__`` is invoked;
    the parent linear constructors are not run.
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 return_bias: bool = True):
        """Build the layer and create its weights.

        Args:
            hidden_size: Input feature size.
            head_size: Size of one attention head.
            total_num_heads: Number of query heads.
            total_num_kv_heads: Number of key/value heads; falls back to
                ``total_num_heads`` when falsy.
            bias: Whether to create a bias parameter.
            skip_bias_add: Stored as ``self.skip_bias_add``.
            params_dtype: Parameter dtype; defaults to torch's default dtype.
            quant_config: Optional quantization config used to pick the
                quant method.
            prefix: Layer name prefix passed to ``get_quant_method``.
            return_bias: Stored as ``self.return_bias``.
        """
        nn.Module.__init__(self)
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.num_heads = total_num_heads
        self.num_kv_heads = (total_num_kv_heads
                             if total_num_kv_heads else total_num_heads)
        self.input_size = self.hidden_size
        self.output_size = (self.num_heads +
                            2 * self.num_kv_heads) * self.head_size
        self.skip_bias_add = skip_bias_add
        self.return_bias = return_bias

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        if quant_config is None:
            self.quant_method: Optional[
                QuantizeMethodBase] = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Copy a checkpoint tensor into ``param``.

        Args:
            param: Destination parameter (weight or bias).
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"`` to load one slice of
                the fused parameter, or ``None`` when ``loaded_weight`` is
                already the full packed tensor.
        """
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)

        if loaded_shard_id is None:
            # Loaded weight is already packed.
            assert param_data.shape == loaded_weight.shape
            param_data.copy_(loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]
        if output_dim is not None:
            q_size = self.num_heads * self.head_size
            kv_size = self.num_kv_heads * self.head_size
            # (offset, size) of each slice in the fused [q | k | v] layout.
            shard_layout = {
                "q": (0, q_size),
                "k": (q_size, kv_size),
                "v": (q_size + kv_size, kv_size),
            }
            shard_offset, shard_size = shard_layout[loaded_shard_id]

            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

            # With VLLM_USE_NN and no quantization the checkpoint tensor is
            # transposed below, so the slice is taken on the other axis of
            # the parameter. That only applies to parameters with more than
            # one dimension: the 1-D bias has no second axis and must be
            # narrowed on ``output_dim``.
            narrow_dim = output_dim
            if (envs.VLLM_USE_NN and not is_quantization
                    and param_data.dim() > 1):
                narrow_dim = int(not (output_dim))
            param_data = param_data.narrow(narrow_dim, shard_offset,
                                           shard_size)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVReplicatedLinear, assume the weight is the same "
                    "for all partitions.")

        if envs.VLLM_USE_NN and not is_quantization:
            # ``Tensor.t()`` returns 1-D tensors unchanged, so the bias is
            # unaffected.
            loaded_weight = loaded_weight.t()
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
class
QKVParallelLinear
(
ColumnParallelLinear
):
class
QKVParallelLinear
(
ColumnParallelLinear
):
"""Linear layers for the attention's QKV transformation.
"""Linear layers for the attention's QKV transformation.
...
@@ -1185,6 +1390,8 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -1185,6 +1390,8 @@ class QKVParallelLinear(ColumnParallelLinear):
param
.
shard_id
.
append
(
loaded_shard_id
)
param
.
shard_id
.
append
(
loaded_shard_id
)
param
.
shard_id_map
[
loaded_shard_id
]
=
len
(
param
.
data_container
)
param
.
shard_id_map
[
loaded_shard_id
]
=
len
(
param
.
data_container
)
param
.
data_container
.
append
(
loaded_weight
)
param
.
data_container
.
append
(
loaded_weight
)
if
len
(
param
.
data_container
)
==
3
:
self
.
qweight
=
param
.
materialize_nested
()
return
return
param_data
=
param
.
data
param_data
=
param
.
data
...
@@ -1495,7 +1702,7 @@ class RowParallelLinear(LinearBase):
...
@@ -1495,7 +1702,7 @@ class RowParallelLinear(LinearBase):
def
forward
(
def
forward
(
self
,
input_
,
self
,
input_
,
use_fused_silu_mul_quant
:
Optional
[
bool
]
=
False
use_fused_silu_mul_quant
:
Optional
[
bool
]
=
False
,
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
if
self
.
input_is_parallel
:
if
self
.
input_is_parallel
:
input_parallel
=
input_
input_parallel
=
input_
...
@@ -1758,3 +1965,62 @@ class QKVCrossParallelLinear(LinearBase):
...
@@ -1758,3 +1965,62 @@ class QKVCrossParallelLinear(LinearBase):
s
+=
f
", tp_size=
{
get_tensor_model_parallel_world_size
()
}
"
s
+=
f
", tp_size=
{
get_tensor_model_parallel_world_size
()
}
"
s
+=
", gather_output=False"
s
+=
", gather_output=False"
return
s
return
s
class RowParallelMoELinear(RowParallelLinear):
    """Row-parallel linear layer holding weights for several experts.

    The input dimension is divided by the tensor-parallel world size. Only
    ``torch.nn.Module.__init__`` is invoked; the parent linear constructors
    are not run.
    """

    def __init__(self,
                 num_experts: int,
                 input_size: int,
                 output_size: int,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Set up attributes and ask the quant method to create weights.

        Args:
            num_experts: Number of experts, forwarded to ``create_weights``.
            input_size: Full input feature size.
            output_size: Output feature size.
            params_dtype: Parameter dtype; torch's default dtype when None.
            quant_config: Optional quantization config used to choose the
                quant method.
            prefix: Layer name prefix passed to ``get_quant_method``.
        """
        torch.nn.Module.__init__(self)

        self.num_experts = num_experts
        self.input_size = input_size
        self.output_size = output_size
        self.reduce_results = False
        self.params_dtype = (torch.get_default_dtype()
                             if params_dtype is None else params_dtype)

        if quant_config is not None:
            method = quant_config.get_quant_method(self, prefix=prefix)
        else:
            method = UnquantizedMoELinearMethod()
        # FIXME(ys): hack for moe
        if isinstance(method, UnquantizedLinearMethod):
            method = UnquantizedMoELinearMethod()
        self.quant_method = method

        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)

        assert self.quant_method is not None
        self.quant_method.create_weights(
            self,
            self.input_size_per_partition,
            [self.output_size],
            self.input_size,
            self.output_size,
            self.params_dtype,
            self.num_experts,
            weight_loader=self.weight_loader,
        )
        self.register_parameter("bias", None)

    def forward(  # type: ignore[override]
            self,
            input_,
            residual=None,
            expert_idx: int = -1,
            output: Optional[torch.Tensor] = None):
        """Apply the quantized expert projection.

        Args:
            input_: Input passed to the quant method.
            residual: Accepted but not used by this method.
            expert_idx: Expert index forwarded to the quant method.
            output: Optional tensor forwarded to the quant method as
                ``output``.

        Returns:
            The quant method's result, or ``None`` when the layer is
            unquantized.
        """
        if not isinstance(self.quant_method, UnquantizedMoELinearMethod):
            assert self.quant_method is not None
            return self.quant_method.apply(self,
                                           input_,
                                           None,
                                           expert_idx=expert_idx,
                                           output=output)
        # use optimus moe_ffn outside
        return None
\ No newline at end of file
vllm/model_executor/layers/logits_processor.py
View file @
583034f1
...
@@ -36,7 +36,8 @@ class LogitsProcessor(nn.Module):
...
@@ -36,7 +36,8 @@ class LogitsProcessor(nn.Module):
org_vocab_size
:
Optional
[
int
]
=
None
,
org_vocab_size
:
Optional
[
int
]
=
None
,
scale
:
float
=
1.0
,
scale
:
float
=
1.0
,
logits_as_input
:
bool
=
False
,
logits_as_input
:
bool
=
False
,
soft_cap
:
Optional
[
float
]
=
None
)
->
None
:
soft_cap
:
Optional
[
float
]
=
None
,
need_fp32_logits
:
bool
=
False
)
->
None
:
"""
"""
Args:
Args:
scale: A scaling factor to apply to the logits.
scale: A scaling factor to apply to the logits.
...
@@ -52,6 +53,7 @@ class LogitsProcessor(nn.Module):
...
@@ -52,6 +53,7 @@ class LogitsProcessor(nn.Module):
self
.
soft_cap
=
soft_cap
self
.
soft_cap
=
soft_cap
# Whether to use gather or all-gather to gather the logits.
# Whether to use gather or all-gather to gather the logits.
self
.
use_all_gather
=
current_platform
.
use_all_gather
()
self
.
use_all_gather
=
current_platform
.
use_all_gather
()
self
.
need_fp32_logits
=
need_fp32_logits
def
forward
(
def
forward
(
self
,
self
,
...
@@ -106,6 +108,10 @@ class LogitsProcessor(nn.Module):
...
@@ -106,6 +108,10 @@ class LogitsProcessor(nn.Module):
embedding_bias
:
Optional
[
torch
.
Tensor
],
embedding_bias
:
Optional
[
torch
.
Tensor
],
)
->
Optional
[
torch
.
Tensor
]:
)
->
Optional
[
torch
.
Tensor
]:
# Get the logits for the next tokens.
# Get the logits for the next tokens.
if
self
.
need_fp32_logits
:
logits
=
torch
.
ops
.
OptimusMoe
.
matmul_fp32
(
hidden_states
,
lm_head
.
weight
.
t
())
else
:
logits
=
lm_head
.
quant_method
.
apply
(
lm_head
,
logits
=
lm_head
.
quant_method
.
apply
(
lm_head
,
hidden_states
,
hidden_states
,
bias
=
embedding_bias
)
bias
=
embedding_bias
)
...
...
vllm/model_executor/layers/quantization/groupwise_quant.py
0 → 100644
View file @
583034f1
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.fused_moe.optimus_moe
import
(
# noqa: F401
optimus_moe_int8
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
direct_register_custom_op
class GroupwiseQuantConfig(QuantizationConfig):
    """Config class for Groupwise Quantization.

    Stores the weight bit-width, group size and optional per-block
    overrides, and selects the quantization method for each layer.
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        symmetric: bool = False,
        bf16_blocks: Optional[list] = None,
        int8_blocks: Optional[list] = None,
        weight_dtype: Optional[
            str] = None,  # Literal["int8", "fp8_e4m3", "fp6", "int4"] = "int8",
        # FIXME: hack for mixed precision quantization
        extra_quant_configs: Optional[dict] = None,
    ) -> None:
        """Create the config.

        Args:
            weight_bits: Bit-width of the quantized weights.
            group_size: Quantization group size.
            symmetric: Stored as ``self.symmetric``.
            bf16_blocks: Block indices whose linear layers stay unquantized.
            int8_blocks: Stored as ``self.int8_blocks``.
            weight_dtype: Explicit weight dtype name. When falsy it is
                derived from ``weight_bits`` (8 -> "int8", 6 -> "fp6",
                4 -> "int4").
            extra_quant_configs: Optional per-layer override entries; see
                ``get_quant_method`` for the keys that are read.

        Raises:
            ValueError: If ``weight_dtype`` is not given and ``weight_bits``
                is not 4, 6 or 8.
        """
        # Run the base-class initialiser so any state it sets up exists.
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.symmetric = symmetric
        self.bf16_blocks = bf16_blocks if bf16_blocks else []
        self.int8_blocks = int8_blocks if int8_blocks else []
        self.extra_quant_configs = extra_quant_configs

        # Only 4-bit weights are packed (into 32-bit words).
        if self.weight_bits == 4:
            self.pack_factor = 32 // self.weight_bits
        else:
            self.pack_factor = 1

        if not weight_dtype:
            if weight_bits == 8:
                self.weight_dtype = "int8"  # fp8e4m3 must explicitly set
            elif weight_bits == 6:
                self.weight_dtype = "fp6"
            elif weight_bits == 4:
                self.weight_dtype = "int4"
            else:
                raise ValueError(f"Unsupported weight bits: {weight_bits}")
        else:
            self.weight_dtype = weight_dtype

    def __repr__(self) -> str:
        return (f"GroupwiseQuantConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size})")

    @classmethod
    def get_name(cls) -> str:
        """Return the name of this quantization method."""
        return "groupwise_quant"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the supported activation dtypes."""
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability."""
        # The Groupwise Quant kernel only supports Ampere or newer GPUs.
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the config file names searched for this method."""
        return [
            "quant_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GroupwiseQuantConfig":
        """Build a config from a parsed quantization config dict.

        Reads ``w_bit``/``bits`` and ``q_group_size``/``group_size`` as
        required keys, and ``bf16_blocks``, ``int8_blocks``, ``weight_type``
        and ``extra_quant_configs`` as optional keys.
        """
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        bf16_blocks = config.get("bf16_blocks", [])
        int8_blocks = config.get("int8_blocks", [])
        weight_dtype = config.get("weight_type")
        extra_quant_configs = config.get("extra_quant_configs", {})
        return cls(weight_bits,
                   group_size,
                   bf16_blocks=bf16_blocks,
                   int8_blocks=int8_blocks,
                   weight_dtype=weight_dtype,
                   extra_quant_configs=extra_quant_configs)

    def get_quant_method(self,
                         layer: torch.nn.Module,
                         prefix: str = "") -> Optional["QuantizeMethodBase"]:
        """Select the quantization method for ``layer``.

        Args:
            layer: Layer being constructed.
            prefix: Dotted layer name. For linear layers it must contain
                ``"layers.<index>."``; the index is the block index.

        Returns:
            A quant method for linear and fused-MoE layers, ``None`` for
            any other layer type.

        Raises:
            ValueError: For a fused-MoE layer when ``weight_bits`` is not 8.
        """
        if isinstance(layer, LinearBase):
            # NOTE(review): a prefix without "layers.<int>." raises
            # IndexError/ValueError here — confirm every linear layer built
            # with this config lives inside a decoder block.
            block_index = int(prefix.split("layers.")[1].split(".")[0])
            layer_name = prefix.split(".")[-1]
            if block_index in self.bf16_blocks:
                return UnquantizedLinearMethod()
            if self.extra_quant_configs:
                for config in self.extra_quant_configs:
                    if block_index in config[
                            "block_indices"] and layer_name in config[
                                "target_modules"]:
                        return GroupwiseQuantLinearMethod(
                            GroupwiseQuantConfig(
                                weight_bits=config["weight_bit"],
                                # ``group_size`` is optional per entry; fall
                                # back to the global value when absent/falsy.
                                group_size=config.get("group_size")
                                or self.group_size,
                                weight_dtype=config["weight_type"],
                                extra_quant_configs=self.extra_quant_configs))
                # no specific config for this layer means no quantization
                return UnquantizedLinearMethod()
            else:
                # Compatible with old config
                return GroupwiseQuantLinearMethod(self)
        elif isinstance(layer, FusedMoE):
            # For MoE layers, only support 8-bit quantization
            if self.weight_bits == 8:
                return GroupwiseInt8MoeMethod(self)
            else:
                raise ValueError(
                    f"Unsupported weight bits for MoE: {self.weight_bits}. Only 8-bit is supported."
                )
        return None
class
GroupwiseQuantLinearMethod
(
LinearMethodBase
):
"""Linear method for GroupwiseQuant.
Args:
quant_config: The groupwise_quant quantization config.
"""
def
__init__
(
self
,
quant_config
:
GroupwiseQuantConfig
)
->
None
:
self
.
quant_config
=
quant_config
self
.
sm
=
torch
.
cuda
.
get_device_capability
()
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
num_experts
:
Optional
[
int
]
=
None
,
**
extra_weight_attrs
):
assert
input_size_per_partition
%
self
.
quant_config
.
group_size
==
0
output_size_per_partition
=
sum
(
output_partition_sizes
)
assert
output_size_per_partition
%
self
.
quant_config
.
pack_factor
==
0
layer_keys
=
dir
(
layer
)
has_num_experts
=
any
(
"num_experts"
in
name
for
name
in
layer_keys
)
if
not
has_num_experts
:
layer
.
register_parameter
(
"num_experts"
,
None
)
if
self
.
quant_config
.
weight_bits
==
4
:
assert
input_size_per_partition
%
self
.
quant_config
.
group_size
==
0
assert
output_size_per_partition
%
self
.
quant_config
.
pack_factor
==
0
if
num_experts
:
weight_shape
=
[
num_experts
,
input_size_per_partition
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
]
scale_shape
=
[
num_experts
,
input_size_per_partition
//
self
.
quant_config
.
group_size
,
output_size_per_partition
,
]
else
:
weight_shape
=
[
input_size_per_partition
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
]
scale_shape
=
[
input_size_per_partition
//
self
.
quant_config
.
group_size
,
output_size_per_partition
,
]
qweight
=
Parameter
(
torch
.
empty
(
*
weight_shape
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
scales
=
Parameter
(
torch
.
empty
(
*
scale_shape
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
zeros
=
Parameter
(
torch
.
empty
(
*
scale_shape
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
if
num_experts
:
set_weight_attrs
(
qweight
,
{
"input_dim"
:
1
,
"output_dim"
:
2
,
"packed_dim"
:
2
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
set_weight_attrs
(
scales
,
{
"input_dim"
:
1
,
"output_dim"
:
2
,
})
set_weight_attrs
(
zeros
,
{
"input_dim"
:
1
,
"output_dim"
:
2
,
})
else
:
set_weight_attrs
(
qweight
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
set_weight_attrs
(
scales
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
})
set_weight_attrs
(
zeros
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
})
layer
.
register_parameter
(
"qweight"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"scales"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
layer
.
register_parameter
(
"zeros"
,
zeros
)
set_weight_attrs
(
zeros
,
extra_weight_attrs
)
elif
self
.
quant_config
.
weight_bits
==
8
:
assert
input_size_per_partition
%
self
.
quant_config
.
group_size
==
0
if
num_experts
:
weight_shape
=
[
num_experts
,
input_size_per_partition
,
output_size_per_partition
]
scale_shape
=
[
num_experts
,
input_size_per_partition
//
self
.
quant_config
.
group_size
,
output_size_per_partition
,
]
else
:
weight_shape
=
[
input_size_per_partition
,
output_size_per_partition
]
scale_shape
=
[
input_size_per_partition
//
self
.
quant_config
.
group_size
,
output_size_per_partition
,
]
qweight
=
Parameter
(
torch
.
empty
(
*
weight_shape
,
device
=
"cuda"
,
dtype
=
torch
.
int8
,
),
requires_grad
=
False
,
)
scales
=
Parameter
(
torch
.
empty
(
*
scale_shape
,
device
=
"cuda"
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
if
num_experts
:
set_weight_attrs
(
qweight
,
{
"input_dim"
:
1
,
"output_dim"
:
2
,
})
set_weight_attrs
(
scales
,
{
"input_dim"
:
1
,
"output_dim"
:
2
,
})
else
:
set_weight_attrs
(
qweight
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
})
set_weight_attrs
(
scales
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
})
layer
.
register_parameter
(
"qweight"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"scales"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
elif
self
.
quant_config
.
weight_bits
==
6
:
assert
input_size_per_partition
%
self
.
quant_config
.
group_size
==
0
if
num_experts
:
weight_shape
=
[
num_experts
,
output_size_per_partition
,
input_size_per_partition
,
]
scale_shape
=
[
num_experts
,
input_size_per_partition
//
self
.
quant_config
.
group_size
,
output_size_per_partition
,
]
else
:
weight_shape
=
[
output_size_per_partition
,
input_size_per_partition
]
scale_shape
=
[
input_size_per_partition
//
self
.
quant_config
.
group_size
,
output_size_per_partition
,
]
qweight
=
Parameter
(
torch
.
zeros
(
*
weight_shape
,
device
=
"cpu"
,
# hack for fp6 weight is stored in float16, to avoid cuda oom
dtype
=
torch
.
float16
,
),
requires_grad
=
False
,
)
scales
=
Parameter
(
torch
.
empty
(
*
scale_shape
,
device
=
"cuda"
,
dtype
=
torch
.
float16
,
),
requires_grad
=
False
,
)
if
num_experts
:
set_weight_attrs
(
qweight
,
{
"input_dim"
:
2
,
"output_dim"
:
1
,
})
set_weight_attrs
(
scales
,
{
"input_dim"
:
1
,
"output_dim"
:
2
,
})
else
:
set_weight_attrs
(
qweight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
,
})
set_weight_attrs
(
scales
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
})
layer
.
register_parameter
(
"qweight"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"scales"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
else
:
raise
NotImplementedError
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Repack the loaded quantized weights for the Optimus kernels.

    Layers that have no ``qweight`` attribute are left untouched.

    * 4-bit: ``GemmInt4GroupQuantWeight`` is run (per expert when
      ``layer.num_experts`` is truthy); its first result overwrites
      ``qweight`` in place, its second result is registered as ``qscales``,
      and the ``zeros`` / ``scales`` parameters are dropped from the layer.
    * 8-bit: ``qweight`` is overwritten in place with the result of
      ``FpAIntBPreprocessWeightGPU`` applied to its transposed copy.
    * 6-bit: ``qweight`` is replaced by a CUDA ``uint8`` parameter produced
      by ``fp6_preprocess_weight`` from a CPU copy of the weight.

    Args:
        layer: The module whose quantized parameters were just loaded.

    Raises:
        NotImplementedError: If ``weight_bits`` is not 4, 6 or 8.
    """
    if not hasattr(layer, "qweight"):
        return

    weight_bits = self.quant_config.weight_bits

    if weight_bits == 4:
        num_experts = layer.num_experts
        qweight = layer.qweight
        zeros = layer.zeros
        scales = layer.scales
        group_size = self.quant_config.group_size
        if num_experts:
            qscales_list = []
            for i in range(num_experts):
                qweight_processed, qscales = (
                    torch.ops.Optimus.GemmInt4GroupQuantWeight(
                        qweight[i], zeros[i], scales[i] + zeros[i],
                        group_size))
                qweight[i].copy_(qweight_processed)
                qscales_list.append(qscales)
            qscales = Parameter(torch.stack(qscales_list),
                                requires_grad=False)
        else:
            qweight_processed, qscales = (
                torch.ops.Optimus.GemmInt4GroupQuantWeight(
                    qweight, zeros, scales + zeros, group_size))
            qweight.copy_(qweight_processed)
            qscales = Parameter(qscales, requires_grad=False)
        layer.register_parameter("qscales", qscales)
        # Only ``qweight`` and ``qscales`` are read by ``apply`` afterwards.
        layer._parameters.pop("zeros")
        layer._parameters.pop("scales")

    elif weight_bits == 8:
        num_experts = layer.num_experts
        qweight = layer.qweight
        if num_experts:
            for i in range(num_experts):
                qweight[i].copy_(
                    torch.ops.Optimus.FpAIntBPreprocessWeightGPU(
                        qweight[i].t().contiguous(), torch.int8))
        else:
            qweight.copy_(
                torch.ops.Optimus.FpAIntBPreprocessWeightGPU(
                    qweight.t().contiguous(), torch.int8))

    elif weight_bits == 6:
        num_experts = layer.num_experts
        qweight = layer.qweight
        layer._parameters.pop("qweight")
        assert qweight.shape[-1] % 8 == 0
        if num_experts:
            # Packed fp6 takes 6/8 of a byte per element along the last dim.
            packed = torch.empty(qweight.shape[0],
                                 qweight.shape[1],
                                 qweight.shape[2] * 6 // 8,
                                 dtype=torch.uint8,
                                 device="cuda")
            for i in range(num_experts):
                packed[i] = torch.ops.Optimus.fp6_preprocess_weight(
                    qweight[i].cpu()).cuda()
        else:
            # No scratch buffer here: the preprocessed tensor is used
            # directly, so pre-allocating one would only waste GPU memory.
            packed = torch.ops.Optimus.fp6_preprocess_weight(
                qweight.cpu()).cuda()
        layer.register_parameter("qweight",
                                 Parameter(packed, requires_grad=False))

    else:
        raise NotImplementedError
def apply(self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None,
          residual: Optional[torch.Tensor] = None,
          output: Optional[torch.Tensor] = None,
          expert_idx: Optional[int] = None) -> torch.Tensor:
    """Run the quantized linear layer through the matching Optimus op.

    Args:
        layer: Module holding ``qweight`` plus ``qscales`` (4-bit) or
            ``scales`` (6/8-bit), and a ``num_experts`` attribute.
        x: Input activations.
        bias: Optional bias.
        residual: Optional residual that is added to the result.
        output: Optional pre-allocated output buffer.
        expert_idx: Index of the expert to use; required whenever
            ``layer.num_experts`` is truthy.

    Returns:
        The output tensor of the selected kernel, with bias/residual applied.

    Raises:
        NotImplementedError: If ``weight_bits`` is not 4, 6 or 8.
    """
    weight_bits = self.quant_config.weight_bits

    if weight_bits == 4:
        qweight = layer.qweight
        qscales = layer.qscales
        if layer.num_experts:
            assert expert_idx is not None, "expert_idx is None"
            qweight = qweight[expert_idx]
            qscales = qscales[expert_idx]
        # The registered wrapper takes (x, qweight, qscales, bias, out) and
        # supplies the kernel's extra ``None`` argument itself.  Passing that
        # ``None`` here as well would bind ``out`` twice.
        out = torch.ops.vllm.optimus_gemm_int4_group(x,
                                                     qweight,
                                                     qscales,
                                                     bias,
                                                     out=output)
        if residual is not None:
            out += residual
        return out

    if weight_bits == 8:
        qweight = layer.qweight
        scales = layer.scales
        if layer.num_experts:
            assert expert_idx is not None, "expert_idx is None"
            qweight = qweight[expert_idx]
            scales = scales[expert_idx]
        if residual is not None:
            # The residual is handed to the op both as its addend argument
            # and as its output buffer, so a distinct ``output`` is rejected.
            assert output is None or output is residual
            out = torch.ops.vllm.optimus_fp_aintb_gemm(
                x,
                qweight,
                torch.int8,  # Placeholder for dtype argument
                scales,
                residual,
                out=residual)
            if bias is not None:
                out += bias
        else:
            out = torch.ops.vllm.optimus_fp_aintb_gemm(
                x,
                qweight,
                torch.int8,  # Placeholder for dtype argument
                scales,
                bias,
                out=output)
        return out

    if weight_bits == 6:
        qweight = layer.qweight
        scales = layer.scales
        if layer.num_experts:
            assert expert_idx is not None, "expert_idx is None"
            qweight = qweight[expert_idx]
            scales = scales[expert_idx]
        if x.dtype != torch.bfloat16:
            # For non-bfloat16 inputs the output buffer is forced to bfloat16.
            if output is None:
                output = torch.empty(x.shape[0],
                                     qweight.shape[0],
                                     device=x.device,
                                     dtype=torch.bfloat16)
            else:
                output = output.to(torch.bfloat16)
        out = torch.ops.vllm.optimus_fp6_linear(
            x,
            qweight,
            scales,
            4,  # Placeholder for fp6_format_code
            out=output)
        if bias is not None:
            out += bias
        if residual is not None:
            out += residual
        return out

    raise NotImplementedError
class GroupwiseInt8MoeMethod(FusedMoEMethodBase):
    """Fused-MoE method backed by group-wise INT8 weights.

    Args:
        quant_config: Group-wise quantization config; its ``weight_bits``
            must be 8.
    """

    def __init__(self, quant_config: GroupwiseQuantConfig):
        self.quant_config = quant_config
        assert self.quant_config.weight_bits == 8, "Only 8-bit quantization is supported for GroupwiseInt8MoeMethod"

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int,
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register the INT8 expert weights and their group-wise scales.

        Four parameters are created on ``layer``: ``w13_weight``,
        ``w2_weight``, ``w13_weight_scale`` and ``w2_weight_scale``.  Each is
        tagged with ``input_dim=1`` / ``output_dim=2`` and with the extra
        weight attributes (extended by ``is_transposed`` and
        ``quant_method``).
        """
        extra_weight_attrs.update({
            "is_transposed":
            True,
            "quant_method":
            FusedMoeWeightScaleSupported.GROUP.value,
        })

        fused_width = 2 * intermediate_size_per_partition
        # (name, rows, cols, dtype, rows-are-divided-by-group-size)
        layout = (
            ("w13_weight", hidden_size, fused_width, torch.int8, False),
            ("w2_weight", intermediate_size_per_partition, hidden_size,
             torch.int8, False),
            ("w13_weight_scale", hidden_size, fused_width, params_dtype,
             True),
            ("w2_weight_scale", intermediate_size_per_partition, hidden_size,
             params_dtype, True),
        )
        for param_name, rows, cols, dtype, grouped in layout:
            if grouped:
                # One scale row per quantization group along the input dim.
                rows = rows // self.quant_config.group_size
            param = torch.nn.Parameter(torch.empty(num_experts,
                                                   rows,
                                                   cols,
                                                   dtype=dtype),
                                       requires_grad=False)
            layer.register_parameter(param_name, param)
            set_weight_attrs(param, {
                "input_dim": 1,
                "output_dim": 2,
            })
            set_weight_attrs(param, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Preprocess every expert's INT8 weights in place.

        For each expert, ``w13_weight`` and then ``w2_weight`` are
        overwritten with ``FpAIntBPreprocessWeightGPU`` applied to their
        transposed, contiguous copy.
        """
        preprocess = torch.ops.Optimus.FpAIntBPreprocessWeightGPU
        expert_count = layer.w13_weight.shape[0]
        for expert in range(expert_count):
            # Gate and up projections (stored fused).
            fused = layer.w13_weight[expert]
            fused.copy_(preprocess(fused.t().contiguous(), torch.int8))
            # Down projection.
            down = layer.w2_weight[expert]
            down.copy_(preprocess(down.t().contiguous(), torch.int8))

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:
        """Dispatch to ``torch.ops.vllm.optimus_moe_int8``.

        Only ``x``, ``router_logits``, ``top_k``, ``renormalize``,
        ``global_num_experts`` and ``activation`` are forwarded; the other
        arguments are accepted but not used by this method.
        """
        kernel_kwargs = dict(
            hidden_states=x,
            router_logits=router_logits,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            top_k=top_k,
            global_num_experts=global_num_experts,
            norm_expert_weight=renormalize,
            activation=activation,
        )
        return torch.ops.vllm.optimus_moe_int8(**kernel_kwargs)
# Wrapper and Fake Functions for Optimus::GemmInt4Group
def optimus_gemm_int4_group(x: torch.Tensor, qweight: torch.Tensor,
                            qscales: torch.Tensor,
                            bias: Optional[torch.Tensor],
                            out: Optional[torch.Tensor]) -> torch.Tensor:
    """Forward to ``torch.ops.Optimus.GemmInt4Group``.

    The kernel's fifth positional argument is always passed as ``None``;
    ``out`` is forwarded by keyword.
    """
    int4_gemm = torch.ops.Optimus.GemmInt4Group
    return int4_gemm(x, qweight, qscales, bias, None, out=out)
def optimus_gemm_int4_group_fake(x: torch.Tensor, qweight: torch.Tensor,
                                 qscales: torch.Tensor,
                                 bias: Optional[torch.Tensor],
                                 out: Optional[torch.Tensor]) -> torch.Tensor:
    """Fake implementation: allocate an output with the expected metadata.

    The result keeps every leading dimension of ``x`` and takes its last
    dimension from ``qscales``; dtype and device follow ``x``.  A fresh
    tensor is returned whether or not ``out`` is given.
    """
    result_shape = [*x.shape[:-1], qscales.shape[-1]]
    return torch.empty(result_shape, dtype=x.dtype, device=x.device)
# Wrapper and Fake Functions for Optimus::FpAIntBGemm
def optimus_fp_aintb_gemm(x: torch.Tensor, qweight: torch.Tensor,
                          dtype_arg: torch.dtype, scales: torch.Tensor,
                          bias_or_residual: Optional[torch.Tensor],
                          out: Optional[torch.Tensor]) -> torch.Tensor:
    """Forward to ``torch.ops.Optimus.FpAIntBGemm``.

    The kernel is always called with the ``"identity"`` string as its sixth
    positional argument; ``out`` is forwarded by keyword.
    """
    fp_aintb_gemm = torch.ops.Optimus.FpAIntBGemm
    return fp_aintb_gemm(x,
                         qweight,
                         dtype_arg,
                         scales,
                         bias_or_residual,
                         "identity",
                         out=out)
def optimus_fp_aintb_gemm_fake(x: torch.Tensor, qweight: torch.Tensor,
                               dtype_arg: torch.dtype, scales: torch.Tensor,
                               bias_or_residual: Optional[torch.Tensor],
                               out: Optional[torch.Tensor]) -> torch.Tensor:
    """Fake implementation: allocate an output with the expected metadata.

    The result keeps every leading dimension of ``x`` and takes its last
    dimension from ``qweight``; dtype and device follow ``x``.  A fresh
    tensor is returned whether or not ``out`` is given.
    """
    result_shape = [*x.shape[:-1], qweight.shape[-1]]
    return torch.empty(result_shape, dtype=x.dtype, device=x.device)
# Wrapper and Fake Functions for Optimus::fp6_linear
def optimus_fp6_linear(x: torch.Tensor, qweight: torch.Tensor,
                       scales: torch.Tensor, fp6_format_code: int,
                       out: Optional[torch.Tensor]) -> torch.Tensor:
    """Forward to ``torch.ops.Optimus.fp6_linear``, passing ``out`` by keyword."""
    fp6_linear = torch.ops.Optimus.fp6_linear
    return fp6_linear(x, qweight, scales, fp6_format_code, out=out)
def optimus_fp6_linear_fake(x: torch.Tensor, qweight: torch.Tensor,
                            scales: torch.Tensor, fp6_format_code: int,
                            out: Optional[torch.Tensor]) -> torch.Tensor:
    """Fake implementation: allocate an output with the expected metadata.

    The result keeps every leading dimension of ``x`` and takes its last
    dimension from ``scales``.  Its dtype is always ``torch.bfloat16``:
    bfloat16 inputs keep their dtype and every other dtype is mapped to it.
    A fresh tensor is returned whether or not ``out`` is given.
    """
    result_shape = [*x.shape[:-1], scales.shape[-1]]
    return torch.empty(result_shape, dtype=torch.bfloat16, device=x.device)
def _register_optimus_ops() -> None:
    """Register the three Optimus wrappers with their fake implementations.

    Registration order and arguments are: int4 group GEMM (mutates ``out``),
    fp/int-B GEMM (mutates ``out`` and ``bias_or_residual``), fp6 linear
    (mutates ``out``).
    """
    registrations = (
        ("optimus_gemm_int4_group", optimus_gemm_int4_group, ["out"],
         optimus_gemm_int4_group_fake),
        ("optimus_fp_aintb_gemm", optimus_fp_aintb_gemm,
         ["out", "bias_or_residual"], optimus_fp_aintb_gemm_fake),
        ("optimus_fp6_linear", optimus_fp6_linear, ["out"],
         optimus_fp6_linear_fake),
    )
    for op_name, op_func, mutated, fake_impl in registrations:
        direct_register_custom_op(
            op_name=op_name,
            op_func=op_func,
            mutates_args=mutated,
            fake_impl=fake_impl,
        )


_register_optimus_ops()
\ No newline at end of file
vllm/model_executor/layers/quantization/quant_utils.py
0 → 100644
View file @
583034f1
import
torch
@torch.jit.script
def cal_scale(amax, fp_max, scale):
    """Derive a power-of-two scale and its reciprocal from ``amax``.

    The exponent is ``floor(log2(fp_max / amax))``.  The scale is
    ``2 ** exponent`` (computed as ``2 ** |exponent|`` and inverted when the
    exponent is negative).  Wherever ``amax`` is not a positive finite
    number, the incoming ``scale`` is used instead of the power of two.

    Returns:
        A ``(scale, scale_inv)`` pair where ``scale_inv`` is ``1 / scale``.
    """
    exponent_margin = 0
    exp = torch.floor(torch.log2(fp_max / amax)) - exponent_margin
    magnitude = torch.round(torch.pow(2, torch.abs(exp)))
    # Fall back to the provided scale for zero, negative, NaN or inf amax.
    usable = torch.logical_and(amax > 0.0, torch.isfinite(amax))
    magnitude = torch.where(usable, magnitude, scale)
    scale = torch.where(exp < 0, torch.reciprocal(magnitude), magnitude)
    return scale, torch.reciprocal(scale)
# Registry of created singletons, keyed by the decorated class.
instances = {}


def singleton(cls):
    """Class decorator that caches the first instance of ``cls``.

    The decorated name becomes a factory function.  The first call builds
    ``cls(*args, **kwargs)`` and stores it; every later call returns that
    stored object and ignores its arguments.
    """

    def get_instance(*args, **kwargs):
        try:
            return instances[cls]
        except KeyError:
            pass
        created = instances[cls] = cls(*args, **kwargs)
        return created

    return get_instance


def reset_singleton():
    """Forget all cached singletons by rebinding the registry to a new dict."""
    global instances
    instances = dict()
@singleton
class QuantFp8:
    """FP8 (``float8_e4m3fn``) per-tensor quantization helper.

    Decorated with ``singleton``: only the first construction runs
    ``__init__``, so the cached instance keeps the device it was first
    created with.
    """

    def __init__(self, device):
        self.fp_max = torch.tensor([448.0], device=device)
        self.device = device
        self.scale = torch.tensor([1.0], device=self.device)

    @staticmethod
    def quantize_v1(weight, bits):
        """Quantize ``weight`` to FP8 using only plain torch operations.

        Returns:
            ``(qweight, scale)`` where ``qweight`` is ``float8_e4m3fn`` and
            ``scale`` is the reciprocal of the factor ``weight`` was
            multiplied by.

        Raises:
            ValueError: If ``bits`` is not 8.
        """
        if bits != 8:
            raise ValueError(f"Unsupported bit width: {bits}")

        target = weight.device
        amax = weight.abs().max()
        fp_max = torch.tensor([448.0]).to(target)
        margin = 0
        fallback = torch.tensor([1.0]).to(target)
        exp = torch.floor(torch.log2(fp_max / amax)) - margin
        magnitude = torch.round(torch.pow(2, torch.abs(exp)))
        # Keep the neutral scale when amax is not a positive finite number.
        usable = torch.logical_and(amax > 0.0, torch.isfinite(amax))
        magnitude = torch.where(usable, magnitude, fallback)
        factor = torch.where(exp < 0, 1 / magnitude, magnitude)
        qweight = (weight.to(torch.float32) * factor).to(torch.float8_e4m3fn)
        return qweight, torch.reciprocal(factor)

    def quantize(self, weight, bits, weight_scale, use_offline_input_scales):
        """Quantize ``weight`` to FP8 through the ``OptimusFp8`` ops.

        The scale is computed from the tensor's absolute maximum unless a
        ``weight_scale`` is given and ``use_offline_input_scales`` is truthy,
        in which case ``weight_scale`` is used directly.

        Returns:
            ``(qweight, scale_inv)``.

        Raises:
            ValueError: If ``bits`` is not 8.
        """
        if bits != 8:
            raise ValueError(f"Unsupported bit width: {bits}")

        amax = torch.empty(1, dtype=torch.float32, device=self.device)
        scale = torch.tensor([1.0], device=self.device)
        # Writes the absolute maximum of ``weight`` into ``amax``.
        torch.ops.OptimusFp8.abs_max_nan_to_inf(weight, amax)
        use_offline = weight_scale is not None and use_offline_input_scales
        if use_offline:
            scale = weight_scale
            scale_inv = torch.reciprocal(weight_scale)
        else:
            scale, scale_inv = cal_scale(amax, self.fp_max, scale)
        qweight = torch.ops.OptimusFp8.quantize(weight, scale, None,
                                                torch.float8_e4m3fn)
        return qweight, scale_inv

    def get_quant_scale(self, tensor):
        """Return the power-of-two scale ``cal_scale`` derives for ``tensor``."""
        amax = torch.empty(1, dtype=torch.float32, device=tensor.device)
        torch.ops.OptimusFp8.abs_max_nan_to_inf(tensor, amax)
        return cal_scale(amax, self.fp_max, self.scale)[0]
def quantize(weight, bits, weight_scale=None, use_offline_input_scales=True):
    """Quantize ``weight`` with the shared ``QuantFp8`` instance.

    Returns whatever ``QuantFp8.quantize`` returns for the same arguments.
    """
    quantizer = QuantFp8(weight.device)
    result = quantizer.quantize(weight, bits, weight_scale,
                                use_offline_input_scales)
    return result
def dequant(weight, weight_scales):
    """Dequantize ``weight`` to bfloat16 via ``torch.ops.OptimusFp8.dequantize``."""
    dequantize_op = torch.ops.OptimusFp8.dequantize
    return dequantize_op(weight, weight_scales, torch.bfloat16)
def experts_dequant(weights, weight_scales):
    """Dequantize stacked expert weights one expert at a time.

    Returns a bfloat16 tensor with the same shape and device as ``weights``;
    slice ``i`` along the first dimension is
    ``dequant(weights[i], weight_scales[i])``.
    """
    expert_count = weights.shape[0]
    restored = torch.empty(*weights.shape,
                           device=weights.device,
                           dtype=torch.bfloat16)
    for expert in range(expert_count):
        restored[expert] = dequant(weights[expert], weight_scales[expert])
    return restored
def experts_quantize(weight, bits):
    """Quantize stacked expert weights to FP8 one expert at a time.

    Returns:
        ``(qweight_experts, scales)``: a ``float8_e4m3fn`` tensor shaped like
        ``weight`` and a float32 vector holding one scale per expert.

    Raises:
        ValueError: If ``bits`` is not 8.
    """
    if bits != 8:
        raise ValueError(f"Unsupported bit width: {bits}")

    expert_count = weight.shape[0]
    qweight_experts = torch.empty(*weight.shape,
                                  dtype=torch.float8_e4m3fn,
                                  device=weight.device)
    scales = torch.empty(expert_count,
                         dtype=torch.float32,
                         device=weight.device)
    for expert in range(expert_count):
        expert_qweight, expert_scale = quantize(weight[expert], bits)
        qweight_experts[expert] = expert_qweight
        scales[expert] = expert_scale
    return qweight_experts, scales
def dynamic_fp8_pertensor_quantize(tensor):
    """Return the per-tensor FP8 scale computed by the shared ``QuantFp8``."""
    return QuantFp8(tensor.device).get_quant_scale(tensor)
\ No newline at end of file
vllm/model_executor/model_loader/weight_utils.py
View file @
583034f1
...
@@ -797,3 +797,10 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
...
@@ -797,3 +797,10 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
# If there were no matches, return the untouched param name
# If there were no matches, return the untouched param name
return
name
return
name
def fp8_input_scales_loader(path: str):
    """Yield ``(name, slice)`` for every entry of the safetensors file.

    The file is opened with ``safe_open`` and stays open while the generator
    is being consumed; each value comes from ``get_slice``.
    """
    with safe_open(path, framework="pt") as handle:
        yield from ((key, handle.get_slice(key))
                    for key in handle.keys())  # noqa: SIM118
vllm/model_executor/models/mm_step1o.py
0 → 100644
View file @
583034f1
# SPDX-License-Identifier: Apache-2.0
import
os
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
functools
import
cached_property
from
itertools
import
product
from
math
import
ceil
,
sqrt
from
typing
import
Any
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
,
Union
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
PIL
import
Image
from
transformers
import
BatchFeature
,
PretrainedConfig
,
TensorType
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
)
from
vllm.model_executor.layers.linear
import
RowParallelLinear
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolerOutput
,
PoolingType
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.step_encoder
import
StepCLIPVisionModel
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
,
NestedTensors
)
from
vllm.multimodal.parse
import
ImageSize
,
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.step_image_preprocessor
import
StepPreprocessor
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
SentencePieceTokenizer
)
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.interfaces_base
import
VllmModelForPooling
from
.utils
import
(
flatten_bn
,
init_vllm_registered_model
,
is_pp_missing_parameter
,
maybe_prefix
,
merge_multimodal_embeddings
)
def _read_bool_env(var_name: str) -> bool:
    """Return True when the variable is set to "true" or "1" (any case)."""
    return os.getenv(var_name, "false").lower() in ["true", "1"]


# Used when an image does not explicitly request "high" or "low" detail.
DEFAULT_HIGH_RESOLUTION = _read_bool_env("VLLM_DEFAULT_HIGH_RESOLUTION")
VISION_MODEL_USE_DP = _read_bool_env("VLLM_VISION_MODEL_USE_DP")

print(f"DEFAULT_HIGH_RESOLUTION: {DEFAULT_HIGH_RESOLUTION}")
print(f"VISION_MODEL_USE_DP: {VISION_MODEL_USE_DP}")
class MMStep1oImagePixelInputs(TypedDict):
    """Typed dictionary for image inputs tagged ``type == "pixel_values"``.

    The shapes given in the field comments are the ones stated by the
    original author; they are not enforced by this class.
    """
    type: Literal["pixel_values"]
    pixel_values: torch.Tensor
    # (batch_size * num_images, num_channels, height, width)
    patch_pixel_values: Optional[torch.Tensor]
    # (batch_size * num_patches, num_channels, patch_size, patch_size)
    num_patches: List[int]
    # (batch_size * num_patches)
    # NOTE(review): presumably one patch count per image - confirm.
class MMStep1oImageEmbeddingInputs(TypedDict):
    """Typed dictionary for image inputs tagged ``type == "image_embeds"``.

    The shape given in the field comment is the one stated by the original
    author; it is not enforced by this class.
    """
    type: Literal["image_embeds"]
    image_embeds: torch.Tensor
    # (batch_size * num_images * image_feature_size, hidden_size)
# Image inputs arrive either as pixel tensors or as precomputed embeddings.
MMStep1oImageInputs = Union[MMStep1oImagePixelInputs,
                            MMStep1oImageEmbeddingInputs]

# (image, patches, per-patch newline mask).  The mask is a list of booleans
# (True where a patch is followed by a newline token) or None when the image
# has no patches, matching what ImagePatcher.__call__ returns.
ImageWithPatches = Tuple[Image.Image, list[Image.Image], list[bool] | None]
class ImagePatcher:
    """Split an image into square windows ("patches") for high-resolution input.

    Calling an instance returns the (possibly padded and resized) image, the
    list of cropped patches, and a boolean mask marking the patches that end
    a row (the last patch is never marked).
    """

    def determine_window_size(self, long: int, short: int) -> int:
        """Choose the window edge from the long and short image sides.

        Returns 0 when the image should not be split at all.
        """
        aspect = long / short
        if long > 728:
            return min(short, 504) if aspect > 4 else 504
        return short if aspect > 1.5 else 0

    @staticmethod
    def _axis_starts(extent: int, window: int, stride: int):
        """Return ``(count, starts)`` of window offsets along one axis.

        When more than one window is needed and the last one would overrun
        ``extent``, it is shifted back so that it ends exactly at ``extent``.
        """
        if extent <= window:
            count = 1
        else:
            count = ceil((extent - window) / stride + 1)
        starts = [stride * k for k in range(count)]
        if len(starts) > 1 and starts[-1] + window > extent:
            starts[-1] = extent - window
        return count, starts

    def slide_window(
        self,
        width: int,
        height: int,
        sizes: list[tuple[int, int]],
        steps: list[tuple[int, int]],
        img_rate_thr: float = 0.6,
    ) -> Tuple[List[Tuple[int, int, int, int]], Tuple[int, int]]:
        """Enumerate sliding windows over a ``width`` x ``height`` area.

        Returns:
            A list of ``(x, y, w, h)`` boxes in row-major order, and the
            ``(x_num, y_num)`` window counts of the last size/step pair.
        """
        assert 1 >= img_rate_thr >= 0, "The `in_rate_thr` should lie in 0~1"
        boxes_per_size = []

        # Sliding windows.
        for size, step in zip(sizes, steps):
            size_w, size_h = size
            step_w, step_h = step

            x_num, x_start = self._axis_starts(width, size_w, step_w)
            y_num, y_start = self._axis_starts(height, size_h, step_h)

            # Top-left corners as (x, y) rows, iterating rows first.
            corners = np.array([(x, y) for y in y_start for x in x_start],
                               dtype=int)
            boxes_per_size.append(
                np.concatenate([corners, corners + size], axis=1))

        all_boxes = np.concatenate(boxes_per_size, axis=0)
        xywh = [(int(x1), int(y1), int(x2 - x1), int(y2 - y1))
                for x1, y1, x2, y2 in all_boxes]
        return xywh, (x_num, y_num)

    def square_pad(self, img: Image.Image) -> Image.Image:
        """Pad ``img`` with zeros to a square, keeping it at the top-left."""
        w, h = img.size
        if w == h:
            return img
        edge = max(w, h)
        canvas = Image.new(img.mode, (edge, edge), 0)
        canvas.paste(img, (0, 0))
        return canvas

    def get_image_size_for_padding(self, img_width: int,
                                   img_height: int) -> Tuple[int, int]:
        """Return the square size for tiny, very elongated images.

        Images whose shorter side is below 32 and whose aspect ratio exceeds
        4:1 (either way) map to a square of the longer side; all other sizes
        are returned unchanged.
        """
        ratio = img_width / img_height
        too_small = min(img_height, img_width) < 32
        too_elongated = ratio > 4 or ratio < 1 / 4
        if not (too_small and too_elongated):
            return img_width, img_height
        edge = max(img_height, img_width)
        return edge, edge

    def get_image_size_for_preprocess(self, img_width: int,
                                      img_height: int) -> Tuple[int, int]:
        """Scale the size down so that the longer side is at most 3024."""
        longest = max(img_height, img_width)
        if longest <= 3024:
            return img_width, img_height
        scale_factor = 3024 / longest
        return int(img_width * scale_factor), int(img_height * scale_factor)

    @staticmethod
    def _snap_to_window(extent: int, window_size: int):
        """Snap ``extent`` to a multiple of ``window_size``.

        Extents shorter than one window are kept.  Otherwise the window count
        is rounded up when the fractional part exceeds 0.2, else rounded down.
        """
        ratio = extent / window_size
        if ratio < 1:
            return extent
        fraction = ratio - extent // window_size
        count = int(ratio) + 1 if fraction > 0.2 else int(ratio)
        return window_size * count

    def get_image_size_for_crop(self, img_width: int, img_height: int,
                                window_size: int):
        """Return the image size to resize to before cropping windows."""
        width_new = self._snap_to_window(img_width, window_size)
        height_new = self._snap_to_window(img_height, window_size)
        return int(width_new), int(height_new)

    def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int):
        """Crop the ``tw`` x ``th`` region whose top-left is column ``j``, row ``i``."""
        box = (j, i, j + tw, i + th)
        return img.crop(box)

    def get_num_patches(self, img_width: int,
                        img_height: int) -> Tuple[int, int]:
        """Return ``(num_patches, num_newlines)`` for an image of this size.

        Mirrors the size computations of ``__call__`` without touching pixels.
        """
        img_width, img_height = self.get_image_size_for_padding(
            img_width, img_height)
        img_width, img_height = self.get_image_size_for_preprocess(
            img_width, img_height)
        window_size = self.determine_window_size(max(img_height, img_width),
                                                 min(img_height, img_width))
        if window_size == 0:
            return 0, 0

        img_width, img_height = self.get_image_size_for_crop(
            img_width, img_height, window_size)
        window = (window_size, window_size)
        center_list, (x_num, y_num) = self.slide_window(
            img_width, img_height, [window], [window])
        patch_count = len(center_list)
        full_rows = (patch_count - 1) // x_num + 1
        # The final row does not get a trailing newline.
        if patch_count > 0 and patch_count % x_num == 0:
            full_rows -= 1
        return patch_count, full_rows

    def __call__(
        self, img: Image.Image
    ) -> Tuple[Image.Image, List[Image.Image], List[bool] | None]:
        """Pad/resize ``img`` and cut it into windows.

        Returns:
            ``(image, patches, newline_mask)``.  ``patches`` is empty and the
            mask is ``None`` when the image is not split.
        """
        img_width, img_height = img.size
        new_img_width, new_img_height = self.get_image_size_for_padding(
            img_width, img_height)
        if (new_img_width, new_img_height) != (img_width, img_height):
            img = self.square_pad(img)
            img_width, img_height = img.size

        new_img_width, new_img_height = self.get_image_size_for_preprocess(
            img_width, img_height)
        img = img.resize((new_img_width, new_img_height),
                         Image.Resampling.BILINEAR)
        window_size = self.determine_window_size(
            max(new_img_height, new_img_width),
            min(new_img_height, new_img_width))
        if window_size == 0:
            return img, [], None

        new_img_width, new_img_height = self.get_image_size_for_crop(
            new_img_width, new_img_height, window_size)
        if (new_img_width, new_img_height) != (img_width, img_height):
            img_for_crop = img.resize((new_img_width, new_img_height),
                                      Image.Resampling.BILINEAR)
        else:
            img_for_crop = img

        window = (window_size, window_size)
        center_list, (x_num, y_num) = self.slide_window(
            new_img_width, new_img_height, [window], [window])
        patches = [
            self.patch_crop(img_for_crop, y, x, patch_h, patch_w)
            for x, y, patch_w, patch_h in center_list
        ]
        if len(patches) == 0:
            return img, patches, None

        # A patch ends a row when its 1-based index is a multiple of the
        # column count; the very last patch is never flagged.
        last = len(patches) - 1
        newline_mask = [(idx + 1) % x_num == 0 and idx != last
                        for idx in range(len(patches))]
        return img, patches, newline_mask
class Step1oProcessor:
    """Turn text and images into token ids and pixel tensors.

    Every image contributes one ``<im_start>…<im_end>`` span of
    ``num_image_feature_size`` image tokens.  In high-resolution mode it also
    contributes one ``<patch_start>…<patch_end>`` span per patch, optionally
    followed by ``<patch_newline>``; the patch spans come first.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: AnyTokenizer,
    ) -> None:
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        self.image_size = 728
        self.patch_size = 504
        self.image_preprocessor = StepPreprocessor(self.image_size,
                                                   "bilinear",
                                                   self.patch_size)

        self.num_image_feature_size = 169
        self.num_patch_feature_size = 81
        self.image_token = "<im_patch>"
        self.image_feature_placeholder = (self.image_token *
                                          self.num_image_feature_size)
        self.patch_feature_placeholder = (self.image_token *
                                          self.num_patch_feature_size)

        self.patcher = ImagePatcher()

    @property
    def image_token_id(self) -> int:
        """Vocabulary id of the image placeholder token."""
        return self.tokenizer.get_vocab()[self.image_token]

    @staticmethod
    def _wants_high_resolution(detail) -> bool:
        """Map a ``detail`` hint to the high-resolution switch.

        ``"high"`` and ``"low"`` are explicit; any other value falls back to
        ``DEFAULT_HIGH_RESOLUTION``.
        """
        if detail == "high":
            return True
        if detail == "low":
            return False
        return DEFAULT_HIGH_RESOLUTION

    def get_num_image_tokens(self,
                             img_width: int,
                             img_height: int,
                             detail: str = "auto") -> int:
        """Return how many tokens an image of this size expands to."""
        if self._wants_high_resolution(detail):
            num_patches, num_newlines = self.patcher.get_num_patches(
                img_width, img_height)
        else:
            num_patches = 0
            num_newlines = 0

        # Each span is wrapped in a start and an end token, hence the "+ 2".
        return (num_patches * (self.num_patch_feature_size + 2) +
                self.num_image_feature_size + 2 + num_newlines)

    def _split_images(self,
                      images: list[Image.Image]) -> list[ImageWithPatches]:
        """Patch each image when high resolution applies to it.

        The per-image hint is read from ``img.info["detail"]``.
        """
        result = []
        for img in images:
            detail = img.info.get("detail", None)
            if self._wants_high_resolution(detail):
                result.append(self.patcher(img))
            else:
                result.append((img, [], None))
        return result

    def _convert_images_to_pixel_values(
        self,
        images: list[Image.Image],
        is_patch: bool = False,
    ) -> list[torch.Tensor]:
        """Run the image preprocessor and collect its ``pixel_values``."""
        return [
            self.image_preprocessor.preprocess(
                img, is_patch=is_patch)["pixel_values"] for img in images
        ]

    def _get_patch_repl(
        self,
        num_patches: int,
        patch_newline_mask: list[bool] | None,
    ) -> Tuple[str, list[int]]:
        """Build the replacement text and token ids for ``num_patches`` patches.

        ``patch_newline_mask`` may be ``None`` (no newline tokens).  When it
        is given, its length must equal ``num_patches``.
        """
        if num_patches <= 0:
            return "", []
        # Validate once up front; a None mask is allowed and means
        # "no newlines", matching the truthiness check in the loop below.
        if patch_newline_mask is not None:
            assert len(patch_newline_mask) == num_patches

        to_id = self.tokenizer.convert_tokens_to_ids
        patch_text = f"<patch_start>{self.patch_feature_placeholder}<patch_end>"
        patch_ids = ([to_id("<patch_start>")] +
                     [self.image_token_id] * self.num_patch_feature_size +
                     [to_id("<patch_end>")])
        newline_id = None
        if patch_newline_mask and any(patch_newline_mask):
            newline_id = to_id("<patch_newline>")

        text_parts = []
        token_ids = []
        for i in range(num_patches):
            text_parts.append(patch_text)
            token_ids.extend(patch_ids)
            if patch_newline_mask and patch_newline_mask[i]:
                text_parts.append("<patch_newline>")
                token_ids.append(newline_id)
        return "".join(text_parts), token_ids

    def _get_image_repl(
        self,
        num_images: int,
    ) -> Tuple[str, list[int]]:
        """Build the replacement text and token ids for ``num_images`` images."""
        text = f"<im_start>{self.image_feature_placeholder}<im_end>"
        token_ids = [
            self.tokenizer.convert_tokens_to_ids("<im_start>")
        ] + [self.image_token_id] * self.num_image_feature_size + [
            self.tokenizer.convert_tokens_to_ids("<im_end>")
        ]
        return text * num_images, token_ids * num_images

    def _get_image_repl_features(
        self,
        num_images: int,
        num_patches: int,
        patch_new_line_idx: Optional[list[bool]],
    ) -> Tuple[str, list[int]]:
        """Concatenate the patch replacement (if any) with the image one."""
        if num_patches > 0:
            patch_repl, patch_repl_ids = self._get_patch_repl(
                num_patches, patch_new_line_idx)
        else:
            patch_repl = ""
            patch_repl_ids = []
        image_repl, image_repl_ids = self._get_image_repl(num_images)
        return patch_repl + image_repl, patch_repl_ids + image_repl_ids

    def replace_placeholder(self, text: str, placeholder: str,
                            repls: list[str]) -> str:
        """Replace the i-th occurrence of ``placeholder`` with ``repls[i]``.

        Raises:
            ValueError: If the occurrence count differs from ``len(repls)``.
        """
        parts = text.split(placeholder)

        if len(parts) - 1 != len(repls):
            raise ValueError(
                "The number of placeholders does not match the number of replacements."
            )

        result = [parts[0]]
        for i, repl in enumerate(repls):
            result.append(repl)
            result.append(parts[i + 1])

        return "".join(result)

    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = None,
        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        """Tokenize ``text`` and preprocess ``images`` into a ``BatchFeature``.

        With a ``SentencePieceTokenizer`` and at least one image, the
        returned ``input_ids`` contain only the image replacement ids.
        """
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

        if len(images) == 0:
            image_inputs = {}
            if isinstance(self.tokenizer, SentencePieceTokenizer):
                assert len(text) == 1
                text_inputs = {
                    "input_ids":
                    torch.tensor([
                        self.tokenizer.encode(text[0],
                                              add_special_tokens=True)
                    ],
                                 dtype=torch.long)
                }  # step-tokenizer does not support text input for special tokens
            else:
                text_inputs = self.tokenizer(text)
        else:
            splitted_images_data = self._split_images(images)
            pixel_values_lst = []
            patch_pixel_values_lst = []
            patch_newline_mask_lst = []
            image_repl_str_lst = []
            image_repl_ids_lst = []
            num_patches = []
            for raw_img, img_patches, patch_newline_mask in splitted_images_data:
                pixel_values_lst.extend(
                    self._convert_images_to_pixel_values([raw_img]))

                if len(img_patches) > 0:
                    patch_pixel_values_lst.extend(
                        self._convert_images_to_pixel_values(img_patches,
                                                             is_patch=True))
                num_patches.append(len(img_patches))

                image_repl_str, image_repl_ids = self._get_image_repl_features(
                    1, len(img_patches), patch_newline_mask)
                image_repl_str_lst.append(image_repl_str)
                image_repl_ids_lst.extend(image_repl_ids)

                if patch_newline_mask is not None:
                    patch_newline_mask_lst.extend(patch_newline_mask)

            image_inputs = {
                "pixel_values": torch.cat(pixel_values_lst),
                "num_patches": num_patches,
            }
            if patch_pixel_values_lst:
                image_inputs["patch_pixel_values"] = torch.cat(
                    patch_pixel_values_lst)
            if patch_newline_mask_lst:
                image_inputs["patch_newline_mask"] = torch.tensor(
                    patch_newline_mask_lst, dtype=torch.bool)

            if isinstance(self.tokenizer, SentencePieceTokenizer):
                text_inputs = {
                    "input_ids":
                    torch.tensor(image_repl_ids_lst,
                                 dtype=torch.long).unsqueeze(0)
                }  # step-tokenizer does not support text input for special tokens
            else:
                text = [
                    self.replace_placeholder(t, self.image_token,
                                             image_repl_str_lst) for t in text
                ]
                text_inputs = self.tokenizer(text)

        return BatchFeature(
            {
                **text_inputs,
                **image_inputs,
            },
            tensor_type=return_tensors,
        )
class Step1oProcessingInfo(BaseProcessingInfo):
    """Processing metadata for the Step1o multimodal model (images only)."""

    def get_hf_processor(self) -> Step1oProcessor:
        """Build a ``Step1oProcessor`` from the HF config and tokenizer."""
        return Step1oProcessor(
            self.get_hf_config(),
            self.get_tokenizer(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Report the supported modalities; ``None`` means no fixed limit."""
        return {"image": None}

    def get_max_image_tokens(self) -> int:
        """Return the token count of the image size with most features."""
        hf_processor = self.get_hf_processor()
        largest = self.get_image_size_with_most_features()
        return hf_processor.get_num_image_tokens(largest.width,
                                                 largest.height)

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the per-item token upper bound for each modality."""
        return {"image": self.get_max_image_tokens()}

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image size used for worst-case token estimation."""
        return ImageSize(728, 728)

    def get_num_mm_tokens(self, mm_data: MultiModalDataDict) -> int:
        """Count the placeholder tokens needed for all images in ``mm_data``.

        Args:
            mm_data: Must contain exactly one key, ``"image"``, mapped to a
                single image or a list/tuple of images. Each image exposes
                ``width``, ``height`` and an ``info`` dict whose optional
                ``"detail"`` entry is forwarded to the processor.

        Raises:
            ValueError: If ``mm_data`` holds anything other than the single
                ``"image"`` key.
        """
        if len(mm_data) != 1 or "image" not in mm_data:
            raise ValueError(
                "mm_data could only contain one key 'image' for step1o")

        image_data = mm_data["image"]
        if not isinstance(image_data, (list, tuple)):
            image_data = [image_data]

        # Build the processor once instead of once per image.
        hf_processor = self.get_hf_processor()
        return sum(
            hf_processor.get_num_image_tokens(
                img.width, img.height, detail=img.info.get("detail", None))
            for img in image_data)
class Step1oDummyInputsBuilder(BaseDummyInputsBuilder[Step1oProcessingInfo]):
    """Produces dummy prompts and images for the Step1o model."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<im_patch>`` placeholder per requested image."""
        image_count = mm_counts.get("image", 0)
        return "<im_patch>" * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images at the size with the most features."""
        width, height = self.info.get_image_size_with_most_features()
        image_count = mm_counts.get("image", 0)
        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=image_count)
        return {"image": dummy_images}
class Step1oMultiModalProcessor(BaseMultiModalProcessor[Step1oProcessingInfo]):
    """Multimodal processor wiring Step1o image inputs into the prompt."""

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Replace each image placeholder token with its replacement ids.

        The replacement sequence for an image is obtained from the HF
        processor's ``_get_image_repl_features`` and depends on how many
        patches that image was split into.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        placeholder_id = hf_processor.image_token_id
        patches_per_image = out_mm_kwargs["num_patches"].tolist()

        def get_replacement_step1o(item_idx: int):
            img_out = out_mm_kwargs.get_item("image", item_idx)
            num_patches = patches_per_image[item_idx]

            if num_patches > 0:
                patch_total = num_patches
                newline_mask = img_out["patch_newline_mask"].data.tolist()
            else:
                # Un-split image: no patches and therefore no newline mask.
                patch_total = 0
                newline_mask = None

            _, image_repl_ids = hf_processor._get_image_repl_features(
                1, patch_total, newline_mask)

            return PromptUpdateDetails.select_token_id(
                seq=image_repl_ids,
                embed_token_id=placeholder_id,
            )

        image_replacement = PromptReplacement(
            modality="image",
            target=[placeholder_id],
            replacement=get_replacement_step1o,
        )
        return [image_replacement]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output maps onto individual images."""
        # Falls back to an empty tensor when the processor emitted no
        # "num_patches" entry.
        num_patches = hf_inputs.get("num_patches", torch.empty(0))

        per_image = MultiModalFieldConfig.batched
        per_patch = MultiModalFieldConfig.flat_from_sizes
        return {
            "pixel_values": per_image("image"),
            "patch_pixel_values": per_patch("image", num_patches),
            "num_patches": per_image("image"),
            "patch_newline_mask": per_patch("image", num_patches),
        }
@
MULTIMODAL_REGISTRY
.
register_processor
(
Step1oMultiModalProcessor
,
info
=
Step1oProcessingInfo
,
dummy_inputs
=
Step1oDummyInputsBuilder
)
class
MMGPTStep1oForCausalLM
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
    """Build the vision tower, the projector stack and the language model.

    Args:
        vllm_config: Engine configuration; the HF config and multimodal
            config are read from ``vllm_config.model_config``.
        prefix: Parameter-name prefix applied to the sub-modules.
    """
    super().__init__()

    model_config = vllm_config.model_config
    config = model_config.hf_config
    self.config = config
    self.multimodal_config = model_config.multimodal_config

    vision_cfg = config.vision_tower_config
    vit_width = vision_cfg.hidden_size
    proj_width = vision_cfg.output_hidden_size

    # Sub-modules are created in the same order as before so parameter
    # registration order is unchanged.
    self.vision_model = StepCLIPVisionModel(
        vision_cfg,
        None,
        prefix=maybe_prefix(prefix, "vision_model"),
        need_dp=VISION_MODEL_USE_DP,
    )
    self.vit_downsampler = nn.Conv2d(
        vit_width,
        proj_width,
        kernel_size=2,
        stride=config.understand_projector_stride,
    )
    self.vit_downsampler2 = nn.Conv2d(
        proj_width,
        proj_width * 2,
        kernel_size=3,
        stride=2,
        padding=1,
    )
    self.vit_large_projector = nn.Linear(
        proj_width * 2,
        config.hidden_size,
        bias=config.projector_bias,
    )
    self.language_model = init_vllm_registered_model(
        vllm_config=vllm_config,
        hf_config=config.text_config,
        prefix=maybe_prefix(prefix, "language_model"),
    )

    # Expose the language model's factory directly on this wrapper.
    self.make_empty_intermediate_tensors = (
        self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
    """Return the language model's sampler, or a default one if it has none."""
    language_model = self.language_model
    if not hasattr(language_model, "sampler"):
        return get_sampler()
    return language_model.sampler
@property
def device(self):
    """Device of the first parameter yielded by ``self.parameters()``."""
    first_param = next(self.parameters())
    return first_param.device
@property
def dtype(self):
    """Dtype of the first parameter yielded by ``self.parameters()``."""
    first_param = next(self.parameters())
    return first_param.dtype
def _parse_and_validate_image_input(
        self, **kwargs: object) -> Optional[MMStep1oImageInputs]:
    """Extract image inputs from the forward kwargs.

    Pixel inputs take precedence over precomputed embeddings. Returns
    ``None`` when neither ``pixel_values`` nor ``image_embeds`` is given.

    Raises:
        ValueError: If ``image_embeds`` has fewer than two dimensions.
    """
    pixel_values = kwargs.pop("pixel_values", None)
    patch_pixel_values = kwargs.pop("patch_pixel_values", None)
    num_patches = kwargs.pop("num_patches", None)
    image_embeds = kwargs.pop("image_embeds", None)

    if pixel_values is not None:
        pixel_values = flatten_bn(pixel_values, concat=True)
        if pixel_values.dim() >= 3:
            # Collapse every leading dimension, keeping the last three.
            pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])

        if patch_pixel_values is not None:
            patch_pixel_values = flatten_bn(patch_pixel_values, concat=True)
            patch_pixel_values = patch_pixel_values.view(
                -1, *patch_pixel_values.shape[-3:])
            # Handle empty patch_pixel_values by setting to None
            if patch_pixel_values.shape[0] == 0:
                patch_pixel_values = None

        num_patches = flatten_bn(num_patches, concat=True).tolist()

        global_pixels = pixel_values.to(self.dtype).to(self.device)
        if patch_pixel_values is None:
            patch_pixels = None
        else:
            patch_pixels = patch_pixel_values.to(self.dtype).to(self.device)

        return MMStep1oImagePixelInputs(
            type="pixel_values",
            pixel_values=global_pixels,
            patch_pixel_values=patch_pixels,
            num_patches=num_patches,
        )

    if image_embeds is None:
        return None

    if image_embeds.dim() < 2:
        raise ValueError(
            f"Unexpected shape for image_embeds: {image_embeds.shape}")

    # Flatten to two dimensions, keeping the last axis.
    image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
    return MMStep1oImageEmbeddingInputs(
        type="image_embeds",
        image_embeds=image_embeds.to(self.dtype).to(self.device),
    )
def
_process_image_features
(
self
,
image_features
:
torch
.
Tensor
)
->
torch
.
Tensor
:
B
,
P
=
image_features
.
shape
[:
2
]
HW
=
int
(
sqrt
(
P
))
image_features
=
image_features
.
permute
(
0
,
2
,
1
).
view
(
B
,
-
1
,
HW
,
HW
)
image_features
=
self
.
vit_downsampler
(
image_features
)
image_features
=
self
.
vit_downsampler2
(
image_features
)
n_dim
=
image_features
.
size
(
1
)
image_features
=
image_features
.
view
(
B
,
n_dim
,
-
1
).
permute
(
0
,
2
,
1
)
image_features
=
self
.
vit_large_projector
(
image_features
)
return
image_features
def _get_vision_model_output(self,
                             input_tensor: torch.Tensor) -> torch.Tensor:
    """Run the vision model, optionally splitting the batch across TP ranks.

    In both paths the first four positions along dim 1 of the vision
    model's first output are dropped.
    """
    if not (VISION_MODEL_USE_DP
            and get_tensor_model_parallel_world_size() > 1):
        return self.vision_model(input_tensor)[0][:, 4:]

    tp_rank = get_tensor_model_parallel_rank()
    tp_size = get_tensor_model_parallel_world_size()
    batch_size = input_tensor.shape[0]

    # Every rank processes an equally sized slice (ceil division); ranks
    # whose slice runs past the batch end keep uninitialised padding rows.
    chunk_size = (batch_size + tp_size - 1) // tp_size
    first = tp_rank * chunk_size
    last = min(first + chunk_size, batch_size)

    local_input_tensor = torch.empty(chunk_size,
                                     *input_tensor.shape[1:],
                                     dtype=input_tensor.dtype,
                                     device=input_tensor.device)
    if last > first:
        local_input_tensor[:last - first].copy_(input_tensor[first:last])

    local_features = self.vision_model(local_input_tensor)[0][:, 4:]
    total_features = tensor_model_parallel_all_gather(
        local_features.contiguous(), dim=0)
    # Trim the gathered result back to the real batch size.
    return total_features[:batch_size]
def
_process_image_input
(
self
,
image_input
:
MMStep1oImageInputs
)
->
tuple
[
torch
.
Tensor
,
...]:
if
image_input
[
"type"
]
==
"image_embeds"
:
image_features
=
image_input
[
"image_embeds"
]
else
:
image_features
=
self
.
_get_vision_model_output
(
image_input
[
"pixel_values"
])
patch_image_features
=
self
.
_get_vision_model_output
(
image_input
[
"patch_pixel_values"
])
if
image_input
[
"patch_pixel_values"
]
is
not
None
else
None
num_patches
=
image_input
[
"num_patches"
]
image_features
=
self
.
_process_image_features
(
image_features
)
patch_image_features
=
self
.
_process_image_features
(
patch_image_features
)
if
patch_image_features
is
not
None
else
None
merged_image_features
=
[]
cur_patch_idx
=
0
for
i
,
num_patch
in
enumerate
(
num_patches
):
cur_feature
=
[]
if
num_patch
>
0
:
patch_slice
=
patch_image_features
[
cur_patch_idx
:
cur_patch_idx
+
num_patch
]
cur_feature
.
append
(
patch_slice
.
view
(
-
1
,
patch_slice
.
shape
[
-
1
]))
cur_feature
.
append
(
image_features
[
i
].
view
(
-
1
,
image_features
.
shape
[
-
1
]))
cur_patch_idx
+=
num_patch
merged_image_features
.
append
(
torch
.
cat
(
cur_feature
)
if
len
(
cur_feature
)
>
1
else
cur_feature
[
0
])
return
merged_image_features
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
    """Return vision embeddings for the given kwargs, or ``None`` if there are no images."""
    image_input = self._parse_and_validate_image_input(**kwargs)
    if image_input is not None:
        return self._process_image_input(image_input)
    return None
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    vision_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
    """Embed ``input_ids``, merging in vision embeddings when provided.

    With vision embeddings, only non-image token ids are passed to the
    language model's embedding lookup; image positions are filled by
    ``merge_multimodal_embeddings``.
    """
    embed_tokens = self.language_model.model.get_input_embeddings
    if vision_embeddings is None:
        return embed_tokens(input_ids)

    image_token_id = self.config.image_token_id
    text_mask = input_ids != image_token_id
    text_embeds = embed_tokens(input_ids[text_mask])

    # Allocate the full sequence buffer, then scatter the text embeddings
    # into their positions; image slots stay uninitialised until the merge.
    merged = torch.empty(input_ids.shape[0],
                         text_embeds.shape[-1],
                         dtype=text_embeds.dtype,
                         device=text_embeds.device)
    merged[text_mask] = text_embeds

    return merge_multimodal_embeddings(input_ids, merged, vision_embeddings,
                                       image_token_id)
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run the language model on text and image inputs.

    When ``intermediate_tensors`` is given, ``inputs_embeds`` is discarded.
    Otherwise, if no embeddings were supplied, they are computed here from
    ``input_ids`` and the multimodal kwargs.
    """
    if intermediate_tensors is not None:
        inputs_embeds = None
    elif inputs_embeds is None:
        # always pass the input via `inputs_embeds`
        # to make sure the computation graph is consistent
        mm_embeds = self.get_multimodal_embeddings(**kwargs)
        inputs_embeds = self.get_input_embeddings(input_ids, mm_embeds)
        input_ids = None

    return self.language_model(input_ids,
                               positions,
                               intermediate_tensors,
                               inputs_embeds=inputs_embeds)
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
    """Delegate logit computation to the wrapped language model."""
    language_model = self.language_model
    return language_model.compute_logits(hidden_states, sampling_metadata)
def sample(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    """Delegate sampling to the wrapped language model."""
    language_model = self.language_model
    return language_model.sample(logits, sampling_metadata)
def maybe_remap_params(self, name):
    """Map a checkpoint parameter name onto this module's parameter name.

    Only a leading prefix is rewritten:

    * ``model.``        -> ``language_model.model.``
    * ``lm_head``       -> ``language_model.lm_head``
    * ``vision_model.`` -> ``vision_model.vision_model.``

    Names matching none of these prefixes are returned unchanged.
    """
    prefix_map = (
        ("model.", "language_model.model."),
        ("lm_head", "language_model.lm_head"),
        ("vision_model.", "vision_model.vision_model."),
    )
    for old_prefix, new_prefix in prefix_map:
        if name.startswith(old_prefix):
            # Rewrite the prefix only; `str.replace` would also rewrite any
            # later occurrence of the same text inside the name.
            return new_prefix + name[len(old_prefix):]
    return name
def load_weights_1o(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load a step1o-style checkpoint into this module's parameters.

    Checkpoint shards for q/k/v and gate/up projections are routed into the
    fused ``qkv_proj`` / ``gate_up_proj`` parameters; everything else is
    loaded by name.

    Raises:
        RuntimeError: If any parameter of this module received no weight.
    """
    # (fused parameter suffix, checkpoint shard suffix, shard id)
    fused_shards = [
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params = set()

    for ckpt_name, tensor in weights:
        name = self.maybe_remap_params(ckpt_name)
        for fused_suffix, shard_suffix, shard_id in fused_shards:
            if shard_suffix in name:
                name = name.replace(shard_suffix, fused_suffix)
                target = params_dict[name]
                target.weight_loader(target, tensor, shard_id)
                loaded_params.add(name)
                break
        else:
            # Not a fused shard: load the tensor under its own name.
            target = params_dict[name]
            loader = getattr(target, "weight_loader", default_weight_loader)
            loader(target, tensor)
            loaded_params.add(name)

    expected_params = set(params_dict)
    if expected_params != loaded_params:
        param_name_example = list(expected_params - loaded_params)[0]
        raise RuntimeError(
            f"Some parameters like {param_name_example} are not in the checkpoint and will falsely use random initialization"
        )
def load_weights_3v(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load a step3v-style checkpoint into this module's parameters.

    Each checkpoint tensor is tried against, in order: the fused
    gate/up mapping, the MoE expert mapping, the QKV slice mapping, and
    finally a plain by-name load.

    Raises:
        RuntimeError: If any parameter of this module received no weight.
    """
    q_width = self.config.share_q_dim
    kv_width = self.config.head_dim
    qkv_width = q_width + kv_width * 2
    # (param_name, shard_name, start_offset, end_offset); offsets are
    # expressed in units of `qkv_width` and rescaled to the real parameter
    # size with integer arithmetic, which avoids the float truncation of
    # `int(fraction * dim)`.
    qkv_params_mapping = [
        (".qkv_proj", ".q_proj", 0, q_width),
        (".qkv_proj", ".k_proj", q_width, q_width + kv_width),
        (".qkv_proj", ".v_proj", q_width + kv_width, qkv_width),
    ]
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params = set()

    # (param_name, checkpoint weight name, shard_id) for fused-MoE experts.
    if self.language_model.model.use_fused_moe:
        quant_config = self.language_model.model.vllm_config.quant_config
        if (quant_config is not None
                and quant_config.get_name() == "groupwise_quant"):
            expert_params_mapping = [
                (".moe.experts.w13_weight", ".moe.gate_proj.qweight", "w1"),
                (".moe.experts.w13_weight", ".moe.up_proj.qweight", "w3"),
                (".moe.experts.w2_weight", ".moe.down_proj.qweight", "w2"),
                (".moe.experts.w13_weight_scale", ".moe.gate_proj.scales",
                 "w1"),
                (".moe.experts.w13_weight_scale", ".moe.up_proj.scales",
                 "w3"),
                (".moe.experts.w2_weight_scale", ".moe.down_proj.scales",
                 "w2"),
            ]
        else:
            expert_params_mapping = [
                (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
                (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
                (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"),
            ]
    else:
        expert_params_mapping = []

    # Checkpoint names claimed by the expert mapping must not be caught by
    # the generic gate/up stacking rule.
    disable_moe_stacked_params = [data[1] for data in expert_params_mapping]

    for name, loaded_weight in weights:
        name = self.maybe_remap_params(name)
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            if any(disable_moe_stacked_param in name
                   for disable_moe_stacked_param in
                   disable_moe_stacked_params):
                continue
            name = name.replace(weight_name, param_name)
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
            break
        else:
            for mapping in expert_params_mapping:
                param_name, weight_name, shard_id = mapping
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                # Skip loading extra bias for GPTQ models.
                if ((name.endswith(".bias") or name.endswith("_bias"))
                        and name not in params_dict):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                # The checkpoint tensor is indexed by expert along dim 0.
                for expert_id in range(loaded_weight.shape[0]):
                    loaded_weight_expert = loaded_weight[expert_id]
                    weight_loader(param,
                                  loaded_weight_expert,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                loaded_params.add(name)
                break
            else:
                for (param_name, weight_name, start_offset,
                     end_offset) in qkv_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    dim = param.shape[param.output_dim]
                    begin_idx = start_offset * dim // qkv_width
                    stop_idx = end_offset * dim // qkv_width
                    param_slice = param.narrow(param.output_dim, begin_idx,
                                               stop_idx - begin_idx)
                    param_slice.copy_(loaded_weight)
                    loaded_params.add(name)
                    break
                else:
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
                    loaded_params.add(name)

    params_need_to_load = set(params_dict)
    if params_need_to_load != loaded_params:
        param_name_example = list(params_need_to_load - loaded_params)[0]
        raise RuntimeError(
            f"Some parameters like {param_name_example} are not in the checkpoint and will falsely use random initialization"
        )
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Dispatch weight loading based on ``self.config.model_type``.

    Raises:
        ValueError: If the model type is not one of ``"step1o"``,
            ``"mmgpt_qwen2_v2"`` or ``"step3v"``.
    """
    model_type = self.config.model_type
    if model_type in ["step1o", "mmgpt_qwen2_v2"]:
        self.load_weights_1o(weights)
    elif model_type == "step3v":
        self.load_weights_3v(weights)
    else:
        # Report the value that was actually tested above; the message used
        # to read `self.multimodal_config.model_type`, a different object
        # from the one the dispatch is based on.
        raise ValueError(f"Unsupported model type: {model_type}")
class MMGPTStep1oRewardModel(MMGPTStep1oForCausalLM, VllmModelForPooling):
    """Step1o multimodal model with a linear score head instead of an LM head."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None, "Pooler config must be provided for classification models"

        # Drop the causal-LM-only attributes from the wrapped language model
        # if it has them.
        for attr in ("sampler", "lm_head"):
            if hasattr(self.language_model, attr):
                delattr(self.language_model, attr)

        # Classification score head. `num_labels` is read from the top-level
        # HF config, the input width from its text config.
        self.score = RowParallelLinear(
            config.text_config.hidden_size,
            config.num_labels,
            quant_config=quant_config,
            input_is_parallel=False,
            bias=False,
            prefix=maybe_prefix(prefix, "score"))

        # Pooler defaults: ALL pooling, no normalization, no softmax.
        self._pooler = Pooler.from_config_with_defaults(
            pooler_config,
            pooling_type=PoolingType.ALL,
            normalize=False,
            softmax=False,
        )

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Apply the configured pooler to the model output."""
        return self._pooler(hidden_states, pooling_metadata)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Run the base model and map its hidden states through the score head."""
        # Get hidden states from the base model (without the LM head)
        hidden_states = super().forward(input_ids,
                                        positions,
                                        intermediate_tensors,
                                        inputs_embeds=inputs_embeds,
                                        **kwargs)
        # Apply the classification head
        logits, _ = self.score(hidden_states)
        return logits

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights, skipping any LM-head tensors in the checkpoint."""
        # Names arrive here before `maybe_remap_params` runs, so an LM head
        # may appear as "lm_head..." (later remapped) or already as
        # "language_model.lm_head."; both spellings must be dropped because
        # the LM head was removed in `__init__`.
        weights_iterator = (
            (name, data) for name, data in weights
            if not name.startswith("lm_head")
            and "language_model.lm_head." not in name)

        # The base loader handles everything else; "score" weights are
        # loaded under their own name.
        super().load_weights(weights_iterator)
\ No newline at end of file
vllm/model_executor/models/registry.py
View file @
583034f1
...
@@ -134,6 +134,11 @@ _TEXT_GENERATION_MODELS = {
...
@@ -134,6 +134,11 @@ _TEXT_GENERATION_MODELS = {
# [Encoder-decoder]
# [Encoder-decoder]
"BartModel"
:
(
"bart"
,
"BartForConditionalGeneration"
),
"BartModel"
:
(
"bart"
,
"BartForConditionalGeneration"
),
"BartForConditionalGeneration"
:
(
"bart"
,
"BartForConditionalGeneration"
),
"BartForConditionalGeneration"
:
(
"bart"
,
"BartForConditionalGeneration"
),
# step model
"Step1ForCausalLM"
:
(
"step1"
,
"Step1ForCausalLM"
),
"Step2ForCausalLM"
:
(
"step1"
,
"Step1ForCausalLM"
),
"Step1MoEForCausalLM"
:
(
"step1"
,
"Step1ForCausalLM"
),
"Step2MiniForCausalLM"
:
(
"step2_mini"
,
"Step2MiniForCausalLM"
),
}
}
_EMBEDDING_MODELS
=
{
_EMBEDDING_MODELS
=
{
...
@@ -174,6 +179,19 @@ _EMBEDDING_MODELS = {
...
@@ -174,6 +179,19 @@ _EMBEDDING_MODELS = {
# input and output. I am adding it here because it piggy-backs on embedding
# input and output. I am adding it here because it piggy-backs on embedding
# models for the time being.
# models for the time being.
"PrithviGeoSpatialMAE"
:
(
"prithvi_geospatial_mae"
,
"PrithviGeoSpatialMAE"
),
"PrithviGeoSpatialMAE"
:
(
"prithvi_geospatial_mae"
,
"PrithviGeoSpatialMAE"
),
# step model
"Step1ForSequenceClassification"
:
(
"step1"
,
"Step1ForSequenceClassification"
),
"Step2ForClassification"
:
(
"step1"
,
"Step1ForSequenceClassification"
),
"Step2ForSequenceClassification"
:
(
"step2"
,
"Step2ForSequenceClassification"
),
"Step2MiniForClassification"
:
(
"step2_mini"
,
"Step2MiniForSequenceClassification"
),
"MMGPTQwen2RewardModel"
:
(
"mm_step1o"
,
"MMGPTStep1oRewardModel"
),
# Technically PrithviGeoSpatialMAE is a model that works on images, both in
# input and output. I am adding it here because it piggy-backs on embedding
# models for the time being.
"PrithviGeoSpatialMAE"
:
(
"prithvi_geospatial_mae"
,
"PrithviGeoSpatialMAE"
),
}
}
_CROSS_ENCODER_MODELS
=
{
_CROSS_ENCODER_MODELS
=
{
...
@@ -251,6 +269,15 @@ _SPECULATIVE_DECODING_MODELS = {
...
@@ -251,6 +269,15 @@ _SPECULATIVE_DECODING_MODELS = {
"Glm4MoeMTPModel"
:
(
"glm4_moe_mtp"
,
"Glm4MoeMTP"
),
"Glm4MoeMTPModel"
:
(
"glm4_moe_mtp"
,
"Glm4MoeMTP"
),
"MedusaModel"
:
(
"medusa"
,
"Medusa"
),
"MedusaModel"
:
(
"medusa"
,
"Medusa"
),
"MLPSpeculatorPreTrainedModel"
:
(
"mlp_speculator"
,
"MLPSpeculator"
),
"MLPSpeculatorPreTrainedModel"
:
(
"mlp_speculator"
,
"MLPSpeculator"
),
# step model
"MMGPTStep1ForCausalLMV2"
:
(
"mm_step1p5c_1u"
,
"MMGPTStep1ForCausalLMV2"
),
"MMGPTStep1ForCausalLMV3"
:
(
"mm_step1p5c_1u"
,
"MMGPTStep1ForCausalLMV3"
),
"MMGPTStep1ForCausalLMV4"
:
(
"mm_step1o"
,
"MMGPTStep1oForCausalLM"
),
"MMGPTQwen2ForCausalLM"
:
(
"mm_step1p5c_1u"
,
"MMGPTStep1ForCausalLMV3"
),
"MMGPTQwen2ForCausalLMV2"
:
(
"mm_step1o"
,
"MMGPTStep1oForCausalLM"
),
"MMGPTStep3vForCausalLM"
:
(
"mm_step1o"
,
"MMGPTStep1oForCausalLM"
),
"Step1AudioForCausalLM"
:
(
"mm_step_audio"
,
"MMGPTStep1fForCausalLM"
),
"StepAudioForCausalLMV2"
:
(
"mm_step_audio"
,
"MMGPTStep1fForCausalLM"
),
}
}
_TRANSFORMERS_MODELS
=
{
_TRANSFORMERS_MODELS
=
{
...
...
vllm/model_executor/models/step1.py
0 → 100644
View file @
583034f1
# SPDX-License-Identifier: Apache-2.0
import
math
import
os
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
# from optimus import moe_expert_histogram as optimus_moe_expert_histogram
# from optimus import moe_gather as optimus_moe_gather
# from optimus import moe_scatter as optimus_moe_scatter
from
torch
import
nn
from
vllm.attention
import
Attention
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.activation
import
OptimusSiluAndMul
,
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
OptimusRMSNorm
,
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
MergedColumnParallelMoELinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
,
RowParallelMoELinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.quant_utils
import
(
dynamic_fp8_pertensor_quantize
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
fp8_input_scales_loader
)
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
.interfaces
import
SupportsPP
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
# Sequence parallelism is hard-disabled for now.
# FIXME: os.getenv("DISABLE_SEQUENCE_PARALLEL", "0") == "1"
DISABLE_SEQUENCE_PARALLEL = True

# Overridable through the SEQUENCE_PARALLEL_THRESHOLD environment variable;
# an unset variable or the literal "0" selects the default of 512.
SEQUENCE_PARALLEL_THRESHOLD = (
    int(os.getenv("SEQUENCE_PARALLEL_THRESHOLD"))
    if os.getenv("SEQUENCE_PARALLEL_THRESHOLD", "0") != "0" else 512)

GEMM_COMM_OVERLAP_RATIO = 0.5
MLP_BATCH_SIZE = 8192
def
_get_alibi_slopes
(
n_heads
):
n
=
2
**
math
.
floor
(
math
.
log2
(
n_heads
))
# nearest 2**n to n_heads
m0
=
2.0
**
(
-
8.0
/
n
)
slopes
=
np
.
power
(
m0
,
np
.
arange
(
1
,
n
+
1
))
if
n
<
n_heads
:
m1
=
2.0
**
(
-
4.0
/
n
)
mm
=
np
.
power
(
m1
,
np
.
arange
(
1
,
1
+
2
*
(
n_heads
-
n
),
2
))
slopes
=
np
.
concatenate
([
slopes
,
mm
])
return
slopes
def
_get_ntk_alibi_slopes
(
max_pos_interp_ratio
,
slopes
):
if
max_pos_interp_ratio
==
1.0
:
return
slopes
smax
,
smin
=
slopes
.
max
(),
slopes
.
min
()
D0
=
np
.
log2
(
smax
)
-
np
.
log2
(
smin
)
W1
=
(
np
.
log2
(
smax
)
-
np
.
log2
(
slopes
))
/
D0
ratios
=
np
.
power
(
max_pos_interp_ratio
,
W1
)
return
slopes
/
(
ratios
**
0.5
)
class
Step1MoEMLP
(
nn
.
Module
):
def __init__(self,
             num_experts: int,
             top_k: int,
             top_p: float,
             hidden_size: int,
             intermediate_size: int,
             hidden_act="",
             quant_config: Optional[QuantizationConfig] = None,
             norm_expert_weight=True,
             prefix: str = "",
             enable_cudagraph: bool = False):
    """Create the router gate and the per-expert gate/up and down projections.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """
    super().__init__()

    # The router gate is never quantized (quant_config=None).
    self.gate = ReplicatedLinear(input_size=hidden_size,
                                 output_size=num_experts,
                                 bias=False,
                                 quant_config=None,
                                 prefix=f"{prefix}.gate")
    self.top_k = top_k
    self.top_p = top_p
    self.num_experts = num_experts

    tp_world_size = get_tensor_model_parallel_world_size()
    assert intermediate_size % tp_world_size == 0

    self.gate_up_proj = MergedColumnParallelMoELinear(
        num_experts,
        hidden_size, [intermediate_size] * 2,
        quant_config=quant_config,
        prefix=f"{prefix}.gate_up_proj")

    if hidden_act != "silu":
        raise ValueError(f"Unsupported activation: {hidden_act}. "
                         "Only silu is supported for now.")

    # The Optimus activation is selected only when the per-rank
    # intermediate width is a multiple of 64.
    per_rank_width = intermediate_size / tp_world_size
    self.act_fn = (OptimusSiluAndMul()
                   if per_rank_width % 64 == 0 else SiluAndMul())

    self.down_proj = RowParallelMoELinear(num_experts,
                                          intermediate_size,
                                          hidden_size,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.down_proj")

    self.tp_rank = get_tensor_model_parallel_rank()
    self.quant_config = quant_config
    self.norm_expert_weight = norm_expert_weight
    self.enable_cudagraph = enable_cudagraph
    self.need_fp32_gate = False
def get_expert_output(self, inputs: torch.Tensor,
                      expert_token_cnt: torch.Tensor, token_nums: int):
    """Run the expert feed-forward computation on ``inputs``.

    Three execution paths exist, chosen from the quantization setup:

    * quantized with 8-bit weights and at most 1024 input rows: fused
      Optimus kernels (two GEMMs around the activation when
      ``enable_cudagraph`` is set, otherwise a single ``moe_ffn_quant``
      call);
    * any other quantized configuration: a Python loop over experts;
    * not quantized: a single ``moe_ffn`` call.

    Args:
        inputs: Token activations. The per-expert loop slices them in
            consecutive ranges, so rows are presumably grouped by expert
            already -- TODO confirm against the caller.
        expert_token_cnt: Number of rows per expert.
        token_nums: Passed through to the fused kernels unchanged.
    """
    # Quantized path: requires a quant config on the module and on both
    # projection layers' quant methods.
    if self.quant_config and getattr(
            self.gate_up_proj.quant_method, "quant_config",
            None) and getattr(self.down_proj.quant_method, "quant_config",
                              None):
        if inputs.size(0) <= 1024 and self.gate_up_proj.quant_method.quant_config.weight_bits == 8 and self.down_proj.quant_method.quant_config.weight_bits == 8:
            if self.enable_cudagraph:
                # Two separate GEMM kernels with the activation in between.
                tmp = torch.ops.Optimus.MoeFpAIntBGemm(
                    inputs, self.gate_up_proj.qweight,
                    self.gate_up_proj.qweight.dtype,
                    self.gate_up_proj.scales, expert_token_cnt, token_nums,
                    None)
                tmp = self.act_fn(tmp)
                tmp = torch.ops.Optimus.MoeFpAIntBGemm(
                    tmp, self.down_proj.qweight,
                    self.down_proj.qweight.dtype, self.down_proj.scales,
                    expert_token_cnt, token_nums, None)
                return tmp
            else:
                # Single fused kernel; `inputs` is also passed as `out`
                # (presumably the result overwrites it -- TODO confirm).
                quant_output_ = torch.ops.OptimusMoe.moe_ffn_quant(
                    inputs,
                    self.gate_up_proj.qweight.dtype,
                    self.gate_up_proj.qweight,
                    self.gate_up_proj.scales,
                    self.down_proj.qweight,
                    self.down_proj.scales,
                    expert_token_cnt,
                    token_nums,
                    out=inputs)
                return quant_output_
        else:
            # Fallback: iterate over experts on the host.
            expert_token_cnt = expert_token_cnt.to("cpu").tolist()
            start = 0
            end = 0
            if getattr(
                    self.gate_up_proj.quant_method, "quant_config", None
            ) and self.gate_up_proj.quant_method.quant_config.weight_bits == 6:
                # 6-bit weights: results go to a fresh bfloat16 buffer.
                output = torch.empty_like(inputs,
                                          dtype=torch.bfloat16,
                                          device=inputs.device)
            else:
                # Otherwise `output` aliases `inputs`; each expert reads its
                # slice before the down projection is given the same slice
                # as its `output` argument.
                output = inputs
            for i in range(len(expert_token_cnt)):
                cur_token_cnt = expert_token_cnt[i]
                if (cur_token_cnt <= 0):
                    continue
                end += cur_token_cnt
                tmp = self.gate_up_proj(inputs[start:end], expert_idx=i)
                tmp = self.act_fn(tmp)
                tmp = self.down_proj(tmp,
                                     expert_idx=i,
                                     output=output[start:end])
                start += cur_token_cnt
            return output
    else:
        # Unquantized path: one fused kernel over the full-precision weights.
        moe_output = torch.ops.OptimusMoe.moe_ffn(inputs,
                                                  self.gate_up_proj.weight,
                                                  self.down_proj.weight,
                                                  expert_token_cnt,
                                                  token_nums)
        return moe_output
def forward(
    self,
    x,
    residual=None,
    layernorm=None,
    disable_allreduce=False,
    user_output=None,
):
    """Route tokens through the mixture-of-experts block.

    Args:
        x: input activations; ``x.shape[0]`` is passed on as the token
            count.
        residual: optional tensor added to the output on TP rank 0 only,
            before the all-reduce.
        layernorm: optional callable applied to ``x`` first. It is asked
            for ``fp16_out`` when the gate/up projection is 6-bit.
        disable_allreduce: skip the tensor-parallel all-reduce.
        user_output: optional tensor the final result is copied into.

    Returns:
        The combined expert output.
    """
    if layernorm is not None:
        if self.gate_up_proj.quant_method:
            fp16_out = (getattr(self.gate_up_proj.quant_method,
                                "quant_config", None)
                        and self.gate_up_proj.quant_method.quant_config.
                        weight_bits == 6)
        else:
            fp16_out = False
        x = layernorm(x, fp16_out=fp16_out)
    num_tokens = x.shape[0]

    # Router logits.
    if self.need_fp32_gate:
        six_bit = (getattr(self.gate_up_proj.quant_method, "quant_config",
                           None)
                   and self.gate_up_proj.quant_method.quant_config.
                   weight_bits == 6)
        gate_input = x.to(torch.bfloat16) if six_bit else x
        logits = torch.ops.OptimusMoe.matmul_fp32(gate_input,
                                                  self.gate.weight.t())
    else:
        logits = self.gate(x)[0]

    # Gating: produce per-token expert weights, per-expert token counts
    # and the scatter index used to group rows by expert.
    if self.top_p < 1.0:
        top_k_index, expert_weight, scatter_index = (
            torch.ops.OptimusMoe.topk_topp_gating(logits, self.top_k,
                                                  self.top_p,
                                                  self.norm_expert_weight))
        expert_token_cnt = torch.ops.OptimusMoe.expert_histogram(
            top_k_index, self.num_experts)
        scatter_index = torch.ops.OptimusMoe.index_compute(
            top_k_index, expert_token_cnt, out=scatter_index)
    else:
        expert_weight, expert_token_cnt, scatter_index = (
            torch.ops.OptimusMoe.gating_histogram_index(
                logits, self.top_k, 1.0, self.norm_expert_weight))

    # Scatter -> experts -> weighted gather (shared by both gating modes).
    grouped = torch.ops.OptimusMoe.scatter(x, scatter_index)
    expert_output = self.get_expert_output(grouped, expert_token_cnt,
                                           num_tokens)
    output = torch.ops.OptimusMoe.gather(expert_output, scatter_index,
                                         expert_weight)

    # Only rank 0 contributes the residual so that it is counted once
    # after the all-reduce sums the per-rank partial outputs.
    if self.tp_rank == 0 and residual is not None:
        output += residual
    if not disable_allreduce:
        output = tensor_model_parallel_all_reduce(output)
    if user_output is not None:
        user_output.copy_(output)
    return output
class Step1MLP(nn.Module):
    """Dense gated feed-forward block: gate/up projection, SiLU-and-mul
    activation, then down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        use_optimus_silu: bool = True,
        prefix: str = "",
    ) -> None:
        """Build the projections and pick the activation implementation.

        Raises:
            ValueError: if ``hidden_act`` is anything other than "silu".
        """
        super().__init__()
        merged_sizes = [intermediate_size] * 2
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_sizes,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.down_proj")
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        activation_cls = OptimusSiluAndMul if use_optimus_silu else SiluAndMul
        self.act_fn = activation_cls()

    def forward(self,
                x,
                residual=None,
                layernorm=None,
                disable_allreduce=False,
                user_output=None):
        """Run the MLP.

        ``layernorm`` (if given) is applied to ``x`` first and is asked for
        ``fp16_out`` when the gate/up projection is 6-bit quantized.
        ``residual``, ``user_output`` and ``disable_allreduce`` are handed to
        the down projection, whose first return value is returned.
        """
        if layernorm is not None:
            quant_method = self.gate_up_proj.quant_method
            if getattr(quant_method, "quant_config", None):
                fp16_out = quant_method.quant_config.weight_bits == 6
            else:
                fp16_out = False
            x = layernorm(x, fp16_out=fp16_out)
        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        out, _ = self.down_proj(activated,
                                residual,
                                output=user_output,
                                disable_allreduce=disable_allreduce)
        return out
class Step1Attention(nn.Module):
    """Self-attention with ALiBi slopes, sharded across tensor-parallel
    ranks."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        slopes: Optional[List[float]] = None,
        max_pos_interp_ratio: float = 1.0,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create the QKV/output projections and the attention backend.

        Args:
            hidden_size: model width.
            num_heads: total number of query heads (across all TP ranks).
            num_kv_heads: total number of key/value heads.
            slopes: optional explicit ALiBi slopes, one per query head.
                When omitted they come from ``_get_alibi_slopes``.
            max_pos_interp_ratio: passed to ``_get_ntk_alibi_slopes`` to
                rescale the slopes.
            cache_config: forwarded to ``Attention``.
            quant_config: forwarded to the linear layers.
            prefix: parameter-name prefix of this module.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # Create the alibi slopes and slice them.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        if slopes is None:
            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
            alibi_slopes = _get_ntk_alibi_slopes(max_pos_interp_ratio,
                                                 alibi_slopes)
            alibi_slopes = alibi_slopes[head_start:head_end]
        else:
            assert len(slopes) == self.total_num_heads
            alibi_slopes = _get_ntk_alibi_slopes(max_pos_interp_ratio,
                                                 slopes).tolist()
            # Slice the interpolated slopes. Previously the raw ``slopes``
            # were sliced here, which silently discarded the
            # ``max_pos_interp_ratio`` rescaling computed just above.
            alibi_slopes = alibi_slopes[head_start:head_end]

        scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scaling,
                              self.num_kv_heads,
                              alibi_slopes,
                              alibi_sqrt=True,
                              cache_config=cache_config,
                              prefix=f"{prefix}.attn")

    def forward(self,
                positions: torch.Tensor,
                hidden_states: torch.Tensor,
                residual: Optional[torch.Tensor] = None,
                layernorm: Optional[nn.Module] = None,
                disable_allreduce=False,
                user_output=None) -> torch.Tensor:
        """Run attention.

        ``positions`` is ignored. ``layernorm`` (if given) is applied to the
        input and is asked for ``fp16_out`` when the QKV projection is 6-bit
        quantized. ``residual``, ``user_output`` and ``disable_allreduce`` are
        handed to the output projection, whose first return value is
        returned.
        """
        del positions  # Unused.
        if layernorm is not None:
            quant_method = self.qkv_proj.quant_method
            if getattr(quant_method, "quant_config", None):
                fp16_out = quant_method.quant_config.weight_bits == 6
            else:
                fp16_out = False
            hidden_states = layernorm(hidden_states, fp16_out=fp16_out)
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        attn_output = self.attn(q, k, v)
        out, _ = self.o_proj(attn_output,
                             residual,
                             disable_allreduce=disable_allreduce,
                             output=user_output)
        return out
class Step1DecoderLayer(nn.Module):
    """One transformer layer: attention followed by either a dense MLP or a
    mixture-of-experts block, each preceded by an RMS norm."""

    def __init__(self,
                 model_config: ModelConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the layer.

        The layer index is parsed from ``prefix`` (the number following
        ``"layers."``) and decides, together with the MoE settings of the HF
        config, whether this layer owns ``self.moe`` or ``self.mlp``.
        """
        super().__init__()
        self.enable_cudagraph = not model_config.enforce_eager
        config = model_config.hf_config
        self.hidden_size = config.hidden_size

        self.self_attn = Step1Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_attention_groups,
            slopes=config.alibi_slopes,
            max_pos_interp_ratio=config.max_pos_interp_ratio,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )

        after_layers = prefix.split("layers.")[1]
        layer_idx = int(after_layers.split(".")[0])
        self.use_moe = config.use_moe and (
            layer_idx +
            config.moe_layer_offset) % config.moe_every_n_layer == 0

        if not self.use_moe:
            self.mlp = Step1MLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act="silu",
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.moe = Step1MoEMLP(
                config.moe_num_experts,
                config.moe_top_k,
                config.moe_dynamic_exp_p,
                hidden_size=self.hidden_size,
                intermediate_size=config.moe_intermediate_size,
                hidden_act="silu",
                quant_config=quant_config,
                prefix=f"{prefix}.moe",
                enable_cudagraph=self.enable_cudagraph,
            )

        # The Optimus norm is only chosen when the width is a multiple of 64.
        if config.hidden_size % 64 == 0:
            ln_cls = OptimusRMSNorm
        else:
            ln_cls = RMSNorm
        self.input_layernorm = ln_cls(config.hidden_size,
                                      eps=config.rms_norm_eps)
        self.post_attention_layernorm = ln_cls(config.hidden_size,
                                               eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply attention and then the feed-forward block.

        In both sub-blocks the current hidden states are passed as the
        input and as the residual, together with the matching norm.
        """
        # Self Attention
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            residual=hidden_states,
            layernorm=self.input_layernorm,
        )

        # Fully Connected
        ffn = self.moe if self.use_moe else self.mlp
        return ffn(attn_out, attn_out, self.post_attention_layernorm)
# @support_torch_compile
class
Step1Model
(
nn
.
Module
):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
    """Build embeddings, decoder layers and the final norm.

    Embedding and final norm are replaced by ``PPMissingLayer`` on
    pipeline ranks that do not need them. LoRA is not supported (asserted).
    """
    super().__init__()
    model_config = vllm_config.model_config
    config = model_config.hf_config
    cache_config = vllm_config.cache_config
    quant_config = vllm_config.quant_config
    assert vllm_config.lora_config is None

    self.config = config
    self.allgather_dtype = None  # FIXME(ys): disable fp8 allgather
    self.padding_idx = config.pad_token_id
    self.vocab_size = config.vocab_size

    pp_group = get_pp_group()
    needs_embedding = pp_group.is_first_rank or (
        config.tie_word_embeddings and pp_group.is_last_rank)
    if needs_embedding:
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            quant_config=quant_config,
        )
    else:
        self.embed_tokens = PPMissingLayer()

    def _make_layer(prefix: str):
        # ``prefix`` here is the per-layer prefix supplied by make_layers.
        return Step1DecoderLayer(model_config=model_config,
                                 cache_config=cache_config,
                                 quant_config=quant_config,
                                 prefix=prefix)

    self.start_layer, self.end_layer, self.layers = make_layers(
        config.num_hidden_layers,
        _make_layer,
        prefix=f"{prefix}.layers",
    )

    if pp_group.is_last_rank:
        ln_cls = OptimusRMSNorm if config.hidden_size % 64 == 0 else RMSNorm
        self.norm = ln_cls(config.hidden_size, eps=config.rms_norm_eps)
    else:
        self.norm = PPMissingLayer()

    self.make_empty_intermediate_tensors = (
        make_empty_intermediate_tensors_factory(["hidden_states"],
                                                config.hidden_size))

    # Tuning knobs for the sequence-parallel / overlapped forward paths.
    if DISABLE_SEQUENCE_PARALLEL:
        self.sequence_parallel_threshold = None
    else:
        self.sequence_parallel_threshold = SEQUENCE_PARALLEL_THRESHOLD
    self.overlap_ratio = GEMM_COMM_OVERLAP_RATIO
    self.mlp_batch_size = MLP_BATCH_SIZE
    self.tp_size = get_tensor_model_parallel_world_size()
    self.use_moe = config.use_moe
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Look up ``input_ids`` in ``self.embed_tokens`` and return the result."""
    embeddings = self.embed_tokens(input_ids)
    return embeddings
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Dispatch to the right forward implementation.

    With more than one pipeline rank everything is delegated to
    ``forward_pp``. Otherwise the inputs are embedded (unless
    ``inputs_embeds`` is supplied) and handed to the MoE or dense
    hidden-state path depending on ``self.use_moe``.
    """
    if get_pp_group().world_size > 1:
        return self.forward_pp(input_ids, positions, intermediate_tensors,
                               inputs_embeds)

    hidden_states = (inputs_embeds if inputs_embeds is not None else
                     self.get_input_embeddings(input_ids))
    run = (self.forward_hidden_states_moe
           if self.use_moe else self.forward_hidden_states)
    return run(hidden_states, positions)
def forward_pp(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Pipeline-parallel forward over this rank's slice of layers.

    The first rank starts from the embeddings (or ``inputs_embeds``);
    later ranks start from ``intermediate_tensors["hidden_states"]``.
    Every rank but the last returns ``IntermediateTensors``; the last
    rank applies the final norm and returns the tensor.
    """
    pp_group = get_pp_group()
    if not pp_group.is_first_rank:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
    elif inputs_embeds is not None:
        hidden_states = inputs_embeds
    else:
        hidden_states = self.get_input_embeddings(input_ids)

    for layer_idx in range(self.start_layer, self.end_layer):
        hidden_states = self.layers[layer_idx](positions, hidden_states)

    if get_pp_group().is_last_rank:
        return self.norm(hidden_states)
    return IntermediateTensors({
        "hidden_states": hidden_states,
    })
def forward_hidden_states_moe(
    self,
    hidden_states: torch.Tensor,
    positions: torch.Tensor,
) -> torch.Tensor:
    """Run all layers for a MoE model, picking an execution strategy.

    When tensor parallelism is active and the number of rows exceeds
    ``sequence_parallel_threshold``, the work goes to
    ``forward_overlap_v2`` (TP size above 8) or ``forward_split_ffn``.
    Otherwise the layers are applied one after another, followed by the
    final norm.
    """
    num_rows = hidden_states.shape[0]
    threshold = self.sequence_parallel_threshold
    long_sequence = (self.tp_size > 1 and threshold is not None
                     and threshold < num_rows)

    if long_sequence:
        if self.tp_size > 8:
            return self.forward_overlap_v2(hidden_states, positions)
        # TODO(xwx): overlap mlp layer of MoE model
        return self.forward_split_ffn(hidden_states, positions)

    for layer in self.layers:
        hidden_states = layer(positions, hidden_states)
    return self.norm(hidden_states)
def forward_overlap_v2(
    self,
    hidden_states: torch.Tensor,
    positions: torch.Tensor,
) -> torch.Tensor:
    """Layer loop that splits the rows into two parts processed on two
    CUDA streams, with a manual all-reduce after each sub-block.

    The rows are zero-padded to a multiple of the TP size. The first
    ``dim_0`` rows (``overlap_ratio`` of the total, rounded up to a
    multiple of the TP size) run on the current stream; the remaining
    rows run on a side stream. Each rank keeps the residual only for its
    own chunk of rows and adds it before the all-reduce, so the sum over
    ranks contains the residual once. ``positions`` is unused.

    Returns:
        The normalized hidden states with the padding rows removed.
    """
    del positions
    cfg = self.config
    tp_size = get_tensor_model_parallel_world_size()
    rank = get_tensor_model_parallel_rank()

    # pad to multiple of tp_size with 0
    pad_len = (-hidden_states.shape[0]) % tp_size
    if pad_len:
        padding = torch.zeros(pad_len,
                              hidden_states.shape[1],
                              dtype=hidden_states.dtype,
                              device=hidden_states.device)
        hidden_states = torch.cat([hidden_states, padding])
    S = hidden_states.shape[0]
    assert S % tp_size == 0
    hidden_states = hidden_states.view(S, -1)

    # round up to multiple of tp_size
    dim_0 = int((S * self.overlap_ratio + tp_size - 1) // tp_size * tp_size)
    tail_len = S - dim_0

    # Scratch buffer sized for the per-rank MLP width; the QKV output
    # reuses its leading part.
    buffer = torch.empty(S * int(cfg.intermediate_size / tp_size),
                         dtype=hidden_states.dtype,
                         device=hidden_states.device)
    if tp_size > 8:
        assert tp_size % 8 == 0, f"tp_size should be an integer multiple of 8,but cur tp_size={tp_size}"
        kv_repeat = tp_size // 8
    else:
        kv_repeat = 1
    head_dim = cfg.hidden_size // cfg.num_attention_heads
    qkv_width = int(
        (cfg.num_attention_heads + cfg.num_attention_groups * kv_repeat * 2)
        // tp_size * head_dim)
    qkv_buffer = buffer[:S * qkv_width].view(S, -1)

    rows_per_rank = S // tp_size
    residual = torch.empty(rows_per_rank,
                           cfg.hidden_size,
                           dtype=hidden_states.dtype,
                           device=hidden_states.device)
    rows_0 = dim_0 // tp_size
    rows_1 = rows_per_rank - rows_0
    residual_0 = residual[:rows_0]
    residual_1 = residual[rows_0:]

    # Views of the two row groups and of this rank's own rows inside each.
    head = hidden_states[:dim_0]
    tail = hidden_states[dim_0:]
    own_0 = hidden_states[rank * rows_0:(rank + 1) * rows_0]
    own_1 = hidden_states[dim_0 + rank * rows_1:dim_0 + (rank + 1) * rows_1]
    qkv_head = qkv_buffer[:dim_0]
    qkv_tail = qkv_buffer[dim_0:]

    tp_group = get_tensor_model_parallel_group().device_group
    step = self.mlp_batch_size
    s1 = torch.cuda.Stream(device=residual.device)

    for layer in self.layers:
        attn = layer.self_attn
        ffn = layer.moe if layer.use_moe else layer.mlp

        # Attention Forward
        residual_0.copy_(own_0)
        layer.input_layernorm(head, output=head)
        attn.qkv_proj(head, output=qkv_head)
        with torch.cuda.stream(s1):
            residual_1.copy_(own_1)
            layer.input_layernorm(tail, output=tail)
            attn.qkv_proj(tail, output=qkv_tail)
        torch.cuda.current_stream().wait_stream(s1)

        q, k, v = qkv_buffer.split([attn.q_size, attn.kv_size, attn.kv_size],
                                   dim=-1)
        if pad_len > 0:
            # Padding rows are kept out of attention and re-added as zeros.
            attn_output = attn.attn(q[:-pad_len], k[:-pad_len], v[:-pad_len])
            attn_output = torch.cat([
                attn_output,
                torch.zeros(pad_len,
                            attn_output.shape[1],
                            dtype=attn_output.dtype,
                            device=attn_output.device)
            ],
                                    dim=0)
        else:
            attn_output = attn.attn(q, k, v)

        attn.o_proj(attn_output[:dim_0], output=head, disable_allreduce=True)
        own_0.add_(residual_0)
        torch.distributed.all_reduce(head, group=tp_group)
        # NOTE(review): s1 is never told to wait for the current stream
        # before it reads attn_output below — confirm the ordering is
        # guaranteed elsewhere.
        with torch.cuda.stream(s1):
            attn.o_proj(attn_output[dim_0:],
                        output=tail,
                        disable_allreduce=True)
            own_1.add_(residual_1)
            torch.distributed.all_reduce(tail, group=tp_group)
        del attn_output

        # Feed-forward on the first row group, in batches of mlp_batch_size.
        residual_0.copy_(own_0)
        layer.post_attention_layernorm(head, output=head)
        for start in range(0, dim_0, step):
            end = min(start + step, dim_0)
            ffn(hidden_states[start:end],
                disable_allreduce=True,
                user_output=hidden_states[start:end])
        own_0.add_(residual_0)
        torch.distributed.all_reduce(head, group=tp_group)

        # Same for the second row group on the side stream.
        with torch.cuda.stream(s1):
            residual_1.copy_(own_1)
            layer.post_attention_layernorm(tail, output=tail)
            for offset in range(0, tail_len, step):
                start = dim_0 + offset
                end = dim_0 + min(offset + step, tail_len)
                ffn(hidden_states[start:end],
                    disable_allreduce=True,
                    user_output=hidden_states[start:end])
            own_1.add_(residual_1)
            torch.distributed.all_reduce(tail, group=tp_group)
        torch.cuda.current_stream().wait_stream(s1)

    del buffer, qkv_buffer, residual
    self.norm(hidden_states, output=hidden_states)
    return hidden_states[:S - pad_len]
def forward_split_ffn(
    self,
    hidden_states: torch.Tensor,
    positions: torch.Tensor,
) -> torch.Tensor:
    """Layer loop with in-place sub-blocks, a manual all-reduce after
    each, and the feed-forward block run in row batches.

    Each rank saves the residual only for its own slice of hidden-dim
    columns and adds it back before the all-reduce, so the sum over ranks
    contains the residual once. The feed-forward block processes
    ``mlp_batch_size`` rows at a time.

    Returns:
        The normalized hidden states (normalization is written in place).
    """
    seq_len = hidden_states.shape[0]
    tp_size = get_tensor_model_parallel_world_size()
    rank = get_tensor_model_parallel_rank()
    cols_per_rank = self.config.hidden_size // tp_size
    first_col = rank * cols_per_rank
    tp_group = get_tensor_model_parallel_group().device_group
    step = self.mlp_batch_size

    residual = torch.empty(seq_len,
                           cols_per_rank,
                           dtype=hidden_states.dtype,
                           device=hidden_states.device)

    for layer in self.layers:
        # View of this rank's columns; shares storage with hidden_states.
        own_cols = hidden_states.narrow(1, first_col, cols_per_rank)

        # Attention, written back into hidden_states without all-reduce.
        residual.copy_(own_cols)
        layer.input_layernorm(hidden_states, output=hidden_states)
        layer.self_attn(positions,
                        hidden_states,
                        residual=None,
                        layernorm=None,
                        disable_allreduce=True,
                        user_output=hidden_states)
        own_cols.add_(residual)
        torch.distributed.all_reduce(hidden_states, group=tp_group)

        # Feed-forward, batch by batch.
        residual.copy_(own_cols)
        layer.post_attention_layernorm(hidden_states, output=hidden_states)
        hidden_states = hidden_states.view(seq_len, -1)
        for start in range(0, seq_len, step):
            end = min(start + step, seq_len)
            rows = hidden_states[start:end]
            if layer.use_moe:
                hidden_states[start:end] = layer.moe(rows,
                                                     disable_allreduce=True)
            else:
                layer.mlp(rows, disable_allreduce=True, user_output=rows)
        own_cols.add_(residual)
        torch.distributed.all_reduce(hidden_states, group=tp_group)

    del residual
    self.norm(hidden_states, output=hidden_states)
    return hidden_states
def
forward_hidden_states
(
self
,
hidden_states
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
S
=
hidden_states
.
shape
[
0
]
if
self
.
tp_size
>
1
and
self
.
sequence_parallel_threshold
is
not
None
and
self
.
sequence_parallel_threshold
<
S
:
tp_size
=
get_tensor_model_parallel_world_size
()
rank
=
get_tensor_model_parallel_rank
()
if
S
%
tp_size
!=
0
:
# pad to multiple of tp_size with 0
pad_len
=
tp_size
-
S
%
tp_size
hidden_states
=
torch
.
cat
([
hidden_states
,
torch
.
zeros
(
pad_len
,
hidden_states
.
shape
[
1
],
dtype
=
hidden_states
.
dtype
,
device
=
hidden_states
.
device
)
])
S
=
hidden_states
.
shape
[
0
]
else
:
pad_len
=
0
assert
S
%
tp_size
==
0
chunk_size
=
S
//
tp_size
residual
=
torch
.
empty
(
chunk_size
,
self
.
config
.
hidden_size
,
dtype
=
hidden_states
.
dtype
,
device
=
hidden_states
.
device
)
dim_0
=
int
((
S
*
self
.
overlap_ratio
+
tp_size
-
1
)
//
tp_size
*
tp_size
)
# round up to multiple of tp_size
dim_1
=
S
-
dim_0
mlp_dim
=
int
(
self
.
config
.
intermediate_size
/
tp_size
)
qkv_dim
=
int
(
(
self
.
config
.
num_attention_heads
+
self
.
config
.
num_attention_groups
*
2
)
//
tp_size
*
(
self
.
config
.
hidden_size
//
self
.
config
.
num_attention_heads
))
if
self
.
allgather_dtype
is
not
None
:
fp8_dim
=
int
(
self
.
config
.
hidden_size
/
2
)
max_buffer_dim
=
max
(
mlp_dim
,
qkv_dim
,
fp8_dim
)
else
:
max_buffer_dim
=
max
(
mlp_dim
,
qkv_dim
)
buffer
=
torch
.
empty
(
S
*
max_buffer_dim
,
dtype
=
hidden_states
.
dtype
,
device
=
hidden_states
.
device
)
buffer_0
=
buffer
[:
dim_0
*
max_buffer_dim
]
buffer_1
=
buffer
[
dim_0
*
max_buffer_dim
:]
mlp_buffer_0
=
buffer_0
[:
dim_0
*
mlp_dim
].
view
(
dim_0
,
-
1
)
mlp_buffer_1
=
buffer_1
[:
dim_1
*
mlp_dim
].
view
(
dim_1
,
-
1
)
qkv_buffer
=
buffer
[
dim_0
*
max_buffer_dim
-
dim_0
*
qkv_dim
:
dim_0
*
max_buffer_dim
+
dim_1
*
qkv_dim
].
view
(
S
,
-
1
)
chunk_size_0
=
dim_0
//
tp_size
chunk_size_1
=
chunk_size
-
chunk_size_0
residual_intersect_0
=
residual
[:
chunk_size_0
]
hidden_states_0
=
hidden_states
[:
dim_0
]
hidden_states_1
=
hidden_states
[
dim_0
:]
hidden_states_intersect_0
=
hidden_states
[
rank
*
chunk_size_0
:(
rank
+
1
)
*
chunk_size_0
]
residual_intersect_1
=
residual
[
chunk_size_0
:]
hidden_states_intersect_1
=
hidden_states
[
dim_0
+
rank
*
chunk_size_1
:
dim_0
+
(
rank
+
1
)
*
chunk_size_1
]
if
self
.
allgather_dtype
is
not
None
:
hidden_states_fp8_0
=
buffer_0
[:
dim_0
*
fp8_dim
]
hidden_states_fp8_0
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_0
,
torch
.
uint8
).
reshape
(
dim_0
,
self
.
config
.
hidden_size
)
hidden_states_fp8_1
=
buffer_1
[:
dim_1
*
fp8_dim
]
hidden_states_fp8_1
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_1
,
torch
.
uint8
).
reshape
(
dim_1
,
self
.
config
.
hidden_size
)
hidden_states_fp8_intersect_0
=
hidden_states_fp8_0
[
rank
*
chunk_size_0
:(
rank
+
1
)
*
chunk_size_0
]
hidden_states_fp8_intersect_1
=
hidden_states_fp8_1
[
rank
*
chunk_size_1
:(
rank
+
1
)
*
chunk_size_1
]
s1
=
torch
.
cuda
.
Stream
(
device
=
residual
.
device
)
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
# Attention Forward
if
i
==
0
:
residual_intersect_0
.
copy_
(
hidden_states_intersect_0
)
layer
.
input_layernorm
(
hidden_states
[:
dim_0
],
output
=
hidden_states
[:
dim_0
])
layer
.
self_attn
.
qkv_proj
(
hidden_states
[:
dim_0
],
output
=
qkv_buffer
[:
dim_0
])
with
torch
.
cuda
.
stream
(
s1
):
if
i
==
0
:
residual_intersect_1
.
copy_
(
hidden_states_intersect_1
)
layer
.
input_layernorm
(
hidden_states
[
dim_0
:],
output
=
hidden_states
[
dim_0
:])
else
:
if
self
.
allgather_dtype
is
not
None
:
if
self
.
allgather_dtype
==
"static_fp8e4m3"
:
qkv_input_scale_1
=
torch
.
full
(
[
1
],
layer
.
self_attn
.
qkv_proj
.
input_scales
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
torch
.
ops
.
OptimusFp8
.
rms_norm_quantize_infer
(
residual_intersect_1
,
layer
.
input_layernorm
.
weight
,
qkv_input_scale_1
,
out
=
hidden_states_fp8_intersect_1
)
elif
self
.
allgather_dtype
==
"dynamic_fp8e4m3"
:
layer
.
input_layernorm
(
residual_intersect_1
,
output
=
hidden_states_intersect_1
)
qkv_input_scale_1
=
dynamic_fp8_pertensor_quantize
(
hidden_states_intersect_1
)
torch
.
distributed
.
all_reduce
(
qkv_input_scale_1
,
torch
.
distributed
.
ReduceOp
.
MIN
,
group
=
get_tensor_model_parallel_group
(
).
device_group
)
hidden_states_fp8_intersect_1
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_intersect_1
,
torch
.
float8_e4m3fn
)
torch
.
ops
.
OptimusFp8
.
quantize
(
hidden_states_intersect_1
,
qkv_input_scale_1
,
out
=
hidden_states_fp8_intersect_1
)
hidden_states_fp8_intersect_1
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_intersect_1
,
torch
.
uint8
)
else
:
raise
ValueError
(
f
"Unsupported allgather_dtype:
{
self
.
allgather_dtype
}
"
)
torch
.
distributed
.
all_gather_into_tensor
(
hidden_states_fp8_1
,
hidden_states_fp8_intersect_1
,
group
=
get_tensor_model_parallel_group
(
).
device_group
)
hidden_states_fp8_1
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_1
,
torch
.
float8_e4m3fn
)
torch
.
ops
.
OptimusFp8
.
dequantize
(
hidden_states_fp8_1
,
qkv_input_scale_1
.
reciprocal
(),
torch
.
bfloat16
,
out
=
hidden_states
[
dim_0
:])
hidden_states_fp8_1
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_1
,
torch
.
uint8
)
else
:
layer
.
input_layernorm
(
residual_intersect_1
,
output
=
hidden_states_intersect_1
)
torch
.
distributed
.
all_gather_into_tensor
(
hidden_states
[
dim_0
:],
hidden_states_intersect_1
,
group
=
get_tensor_model_parallel_group
(
).
device_group
)
layer
.
self_attn
.
qkv_proj
(
hidden_states
[
dim_0
:],
output
=
qkv_buffer
[
dim_0
:])
torch
.
cuda
.
current_stream
().
wait_stream
(
s1
)
q
,
k
,
v
=
qkv_buffer
.
split
([
layer
.
self_attn
.
q_size
,
layer
.
self_attn
.
kv_size
,
layer
.
self_attn
.
kv_size
],
dim
=-
1
)
if
pad_len
>
0
:
attn_output
=
layer
.
self_attn
.
attn
(
q
[:
S
-
pad_len
],
k
[:
S
-
pad_len
],
v
[:
S
-
pad_len
])
attn_output
=
torch
.
cat
([
attn_output
,
torch
.
zeros
(
pad_len
,
attn_output
.
shape
[
1
],
dtype
=
attn_output
.
dtype
,
device
=
attn_output
.
device
)],
dim
=
0
)
else
:
attn_output
=
layer
.
self_attn
.
attn
(
q
,
k
,
v
)
layer
.
self_attn
.
o_proj
(
attn_output
[:
dim_0
],
output
=
hidden_states
[:
dim_0
],
disable_allreduce
=
True
)
hidden_states_intersect_0
.
add_
(
residual_intersect_0
)
torch
.
distributed
.
reduce_scatter_tensor
(
residual_intersect_0
,
hidden_states
[:
dim_0
],
group
=
get_tensor_model_parallel_group
().
device_group
)
with
torch
.
cuda
.
stream
(
s1
):
layer
.
self_attn
.
o_proj
(
attn_output
[
dim_0
:],
output
=
hidden_states
[
dim_0
:],
disable_allreduce
=
True
)
hidden_states_intersect_1
.
add_
(
residual_intersect_1
)
del
attn_output
if
self
.
allgather_dtype
is
not
None
:
if
self
.
allgather_dtype
==
"static_fp8e4m3"
:
gate_up_input_scale_0
=
torch
.
full
(
[
1
],
layer
.
mlp
.
gate_up_proj
.
input_scales
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
torch
.
ops
.
OptimusFp8
.
rms_norm_quantize_infer
(
residual_intersect_0
,
layer
.
post_attention_layernorm
.
weight
,
gate_up_input_scale_0
,
out
=
hidden_states_fp8_intersect_0
)
elif
self
.
allgather_dtype
==
"dynamic_fp8e4m3"
:
layer
.
post_attention_layernorm
(
residual_intersect_0
,
output
=
hidden_states_intersect_0
)
gate_up_input_scale_0
=
dynamic_fp8_pertensor_quantize
(
hidden_states_intersect_0
)
torch
.
distributed
.
all_reduce
(
gate_up_input_scale_0
,
torch
.
distributed
.
ReduceOp
.
MIN
,
group
=
get_tensor_model_parallel_group
(
).
device_group
)
hidden_states_fp8_intersect_0
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_intersect_0
,
torch
.
float8_e4m3fn
)
torch
.
ops
.
OptimusFp8
.
quantize
(
hidden_states_intersect_0
,
gate_up_input_scale_0
,
out
=
hidden_states_fp8_intersect_0
)
hidden_states_fp8_intersect_0
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_intersect_0
,
torch
.
uint8
)
else
:
raise
ValueError
(
f
"Unsupported allgather_dtype:
{
self
.
allgather_dtype
}
"
)
torch
.
distributed
.
all_gather_into_tensor
(
hidden_states_fp8_0
,
hidden_states_fp8_intersect_0
,
group
=
get_tensor_model_parallel_group
().
device_group
)
else
:
layer
.
post_attention_layernorm
(
residual_intersect_0
,
output
=
hidden_states_intersect_0
)
torch
.
distributed
.
all_gather_into_tensor
(
hidden_states
[:
dim_0
],
hidden_states_intersect_0
,
group
=
get_tensor_model_parallel_group
().
device_group
)
with
torch
.
cuda
.
stream
(
s1
):
torch
.
distributed
.
reduce_scatter_tensor
(
residual_intersect_1
,
hidden_states
[
dim_0
:],
group
=
get_tensor_model_parallel_group
().
device_group
)
if
self
.
allgather_dtype
is
not
None
:
hidden_states_fp8_0
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_0
,
torch
.
float8_e4m3fn
)
torch
.
ops
.
OptimusFp8
.
dequantize
(
hidden_states_fp8_0
,
gate_up_input_scale_0
.
reciprocal
(),
torch
.
bfloat16
,
out
=
hidden_states
[:
dim_0
])
hidden_states_fp8_0
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_0
,
torch
.
uint8
)
num_batch_size
=
(
dim_0
+
self
.
mlp_batch_size
-
1
)
//
self
.
mlp_batch_size
for
idx
in
range
(
num_batch_size
):
start
=
idx
*
self
.
mlp_batch_size
end
=
min
((
idx
+
1
)
*
self
.
mlp_batch_size
,
dim_0
)
w0_out_0
,
_
=
layer
.
mlp
.
gate_up_proj
(
hidden_states_0
[
start
:
end
])
layer
.
mlp
.
act_fn
(
w0_out_0
,
output
=
mlp_buffer_0
[
start
:
end
])
del
w0_out_0
with
torch
.
cuda
.
stream
(
s1
):
if
self
.
allgather_dtype
is
not
None
:
if
self
.
allgather_dtype
==
"static_fp8e4m3"
:
gate_up_input_scale_1
=
torch
.
full
(
[
1
],
layer
.
mlp
.
gate_up_proj
.
input_scales
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
torch
.
ops
.
OptimusFp8
.
rms_norm_quantize_infer
(
residual_intersect_1
,
layer
.
post_attention_layernorm
.
weight
,
gate_up_input_scale_1
,
out
=
hidden_states_fp8_intersect_1
)
elif
self
.
allgather_dtype
==
"dynamic_fp8e4m3"
:
layer
.
post_attention_layernorm
(
residual_intersect_1
,
output
=
hidden_states_intersect_1
)
gate_up_input_scale_1
=
dynamic_fp8_pertensor_quantize
(
hidden_states_intersect_1
)
torch
.
distributed
.
all_reduce
(
gate_up_input_scale_1
,
torch
.
distributed
.
ReduceOp
.
MIN
,
group
=
get_tensor_model_parallel_group
(
).
device_group
)
hidden_states_fp8_intersect_1
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_intersect_1
,
torch
.
float8_e4m3fn
)
torch
.
ops
.
OptimusFp8
.
quantize
(
hidden_states_intersect_1
,
gate_up_input_scale_1
,
out
=
hidden_states_fp8_intersect_1
)
hidden_states_fp8_intersect_1
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_intersect_1
,
torch
.
uint8
)
else
:
raise
ValueError
(
f
"Unsupported allgather_dtype:
{
self
.
allgather_dtype
}
"
)
torch
.
distributed
.
all_gather_into_tensor
(
hidden_states_fp8_1
,
hidden_states_fp8_intersect_1
,
group
=
get_tensor_model_parallel_group
(
).
device_group
)
else
:
layer
.
post_attention_layernorm
(
residual_intersect_1
,
output
=
hidden_states_intersect_1
)
torch
.
distributed
.
all_gather_into_tensor
(
hidden_states
[
dim_0
:],
hidden_states_intersect_1
,
group
=
get_tensor_model_parallel_group
(
).
device_group
)
layer
.
mlp
.
down_proj
(
mlp_buffer_0
,
output
=
hidden_states
[:
dim_0
],
disable_allreduce
=
True
)
hidden_states_intersect_0
.
add_
(
residual_intersect_0
)
if
i
<
len
(
self
.
layers
)
-
1
:
torch
.
distributed
.
reduce_scatter_tensor
(
residual_intersect_0
,
hidden_states
[:
dim_0
],
group
=
get_tensor_model_parallel_group
().
device_group
)
else
:
torch
.
distributed
.
all_reduce
(
hidden_states
[:
dim_0
],
group
=
get_tensor_model_parallel_group
().
device_group
)
with
torch
.
cuda
.
stream
(
s1
):
if
self
.
allgather_dtype
is
not
None
:
hidden_states_fp8_1
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_1
,
torch
.
float8_e4m3fn
)
torch
.
ops
.
OptimusFp8
.
dequantize
(
hidden_states_fp8_1
,
gate_up_input_scale_1
.
reciprocal
(),
torch
.
bfloat16
,
out
=
hidden_states
[
dim_0
:])
hidden_states_fp8_1
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_1
,
torch
.
uint8
)
num_batch_size
=
(
dim_1
+
self
.
mlp_batch_size
-
1
)
//
self
.
mlp_batch_size
for
idx
in
range
(
num_batch_size
):
start
=
idx
*
self
.
mlp_batch_size
end
=
min
((
idx
+
1
)
*
self
.
mlp_batch_size
,
dim_1
)
w0_out_1
,
_
=
layer
.
mlp
.
gate_up_proj
(
hidden_states_1
[
start
:
end
])
layer
.
mlp
.
act_fn
(
w0_out_1
,
output
=
mlp_buffer_1
[
start
:
end
])
del
w0_out_1
if
i
<
len
(
self
.
layers
)
-
1
:
next_layer
=
self
.
layers
[
i
+
1
]
if
self
.
allgather_dtype
is
not
None
:
if
self
.
allgather_dtype
==
"static_fp8e4m3"
:
qkv_input_scale_0
=
torch
.
full
(
[
1
],
next_layer
.
self_attn
.
qkv_proj
.
input_scales
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
torch
.
ops
.
OptimusFp8
.
rms_norm_quantize_infer
(
residual_intersect_0
,
next_layer
.
input_layernorm
.
weight
,
qkv_input_scale_0
,
out
=
hidden_states_fp8_intersect_0
)
elif
self
.
allgather_dtype
==
"dynamic_fp8e4m3"
:
next_layer
.
input_layernorm
(
residual_intersect_0
,
output
=
hidden_states_intersect_0
)
qkv_input_scale_0
=
dynamic_fp8_pertensor_quantize
(
hidden_states_intersect_0
)
torch
.
distributed
.
all_reduce
(
qkv_input_scale_0
,
torch
.
distributed
.
ReduceOp
.
MIN
,
group
=
get_tensor_model_parallel_group
(
).
device_group
)
hidden_states_fp8_intersect_0
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_intersect_0
,
torch
.
float8_e4m3fn
)
torch
.
ops
.
OptimusFp8
.
quantize
(
hidden_states_intersect_0
,
qkv_input_scale_0
,
out
=
hidden_states_fp8_intersect_0
)
hidden_states_fp8_intersect_0
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_intersect_0
,
torch
.
uint8
)
else
:
raise
ValueError
(
f
"Unsupported allgather_dtype:
{
self
.
allgather_dtype
}
"
)
torch
.
distributed
.
all_gather_into_tensor
(
hidden_states_fp8_0
,
hidden_states_fp8_intersect_0
,
group
=
get_tensor_model_parallel_group
(
).
device_group
)
else
:
next_layer
.
input_layernorm
(
residual_intersect_0
,
output
=
hidden_states_intersect_0
)
torch
.
distributed
.
all_gather_into_tensor
(
hidden_states
[:
dim_0
],
hidden_states_intersect_0
,
group
=
get_tensor_model_parallel_group
(
).
device_group
)
with
torch
.
cuda
.
stream
(
s1
):
layer
.
mlp
.
down_proj
(
mlp_buffer_1
,
output
=
hidden_states
[
dim_0
:],
disable_allreduce
=
True
)
hidden_states_intersect_1
.
add_
(
residual_intersect_1
)
if
i
<
len
(
self
.
layers
)
-
1
:
torch
.
distributed
.
reduce_scatter_tensor
(
residual_intersect_1
,
hidden_states
[
dim_0
:],
group
=
get_tensor_model_parallel_group
(
).
device_group
)
else
:
torch
.
distributed
.
all_reduce
(
hidden_states
[
dim_0
:],
group
=
get_tensor_model_parallel_group
(
).
device_group
)
if
i
<
len
(
self
.
layers
)
-
1
:
next_layer
=
self
.
layers
[
i
+
1
]
if
self
.
allgather_dtype
is
not
None
:
hidden_states_fp8_0
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_0
,
torch
.
float8_e4m3fn
)
torch
.
ops
.
OptimusFp8
.
dequantize
(
hidden_states_fp8_0
,
qkv_input_scale_0
.
reciprocal
(),
torch
.
bfloat16
,
out
=
hidden_states
[:
dim_0
])
hidden_states_fp8_0
=
torch
.
ops
.
OptimusFp8
.
as_type
(
hidden_states_fp8_0
,
torch
.
uint8
)
next_layer
.
self_attn
.
qkv_proj
(
hidden_states
[:
dim_0
],
output
=
qkv_buffer
[:
dim_0
])
torch
.
cuda
.
current_stream
().
wait_stream
(
s1
)
del
buffer
,
residual
self
.
norm
(
hidden_states
,
output
=
hidden_states
)
return
hidden_states
[:
S
-
pad_len
]
else
:
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
hidden_states
=
layer
(
positions
,
hidden_states
,
)
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
class Step1PretrainedModel(nn.Module, SupportsPP):
    """Base class providing checkpoint loading for the Step1 model heads."""

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this module's parameters.

        Checkpoint names containing ``.q_proj``/``.k_proj``/``.v_proj`` are
        routed into the fused ``.qkv_proj`` parameter, and ``.gate_proj``/
        ``.up_proj`` into ``.gate_up_proj``, using the parameter's own
        ``weight_loader`` with the matching shard id. Every other tensor is
        loaded with the parameter's ``weight_loader`` when it has one, else
        with ``default_weight_loader``.

        Args:
            weights: iterable of ``(checkpoint_name, tensor)`` pairs.

        Raises:
            RuntimeError: if any parameter whose name does not mention
                ``vision_model``, ``latent_query_tokens`` or ``sam_model``
                received no weight from the checkpoint.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                # Not a shard of a fused parameter (or the fused parameter
                # is not present on this pipeline rank).
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)

        # Parameters mentioning these sub-models are exempt from the
        # completeness check below.
        params_need_to_load = {
            name
            for name in params_dict
            if not ("vision_model" in name or "latent_query_tokens" in name
                    or "sam_model" in name)
        }
        # Test the difference itself: comparing the two sets with != would
        # also trigger when extra (exempt) parameters were loaded, in which
        # case the difference is empty and indexing it raises IndexError.
        missing_params = params_need_to_load - loaded_params
        if missing_params:
            param_name_example = min(missing_params)
            raise RuntimeError(
                f"Some parameters like {param_name_example} are not in the checkpoint and will falsely use random initialization"
            )

    def load_fp8_input_scales(self, input_scales_path):
        """Attach static FP8 input scales to the layers' projections.

        Each ``(name, tensor)`` yielded by ``fp8_input_scales_loader`` is
        mapped to ``self.model.layers[idx]`` where ``idx`` is the third
        dot-separated component of the name (after removing any
        ``refrence_model.`` text). The scalar value is stored as
        ``input_scales`` on ``self_attn.qkv_proj`` or ``mlp.gate_up_proj``
        depending on which of those strings the name contains; other names
        are ignored.
        """
        for name, loaded_weight in fp8_input_scales_loader(input_scales_path):
            # The misspelled prefix is the literal key used by the scale
            # files, so it must be matched as-is.
            if name.startswith("refrence_model."):
                name = name.replace("refrence_model.", "")
            idx = int(name.split(".")[2])
            layer = self.model.layers[idx]
            if "qkv_proj" in name:
                layer.self_attn.qkv_proj.input_scales = loaded_weight[:].item()
            elif "gate_up_proj" in name:
                layer.mlp.gate_up_proj.input_scales = loaded_weight[:].item()
class Step1ForCausalLM(Step1PretrainedModel):
    """Step1 backbone with a language-modelling head, logits processor and sampler."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        lora_cfg = vllm_config.lora_config

        self.config = hf_config
        self.model = Step1Model(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))

        if not get_pp_group().is_last_rank:
            # Only the last pipeline rank owns the output head.
            self.lm_head = PPMissingLayer()
        else:
            vocab_size = hf_config.vocab_size
            if lora_cfg:
                vocab_size += lora_cfg.lora_extra_vocab_size
            self.unpadded_vocab_size = vocab_size

            # We need bigger padding if using lora for kernel
            # compatibility
            if lora_cfg:
                head_padding = lora_cfg.lora_vocab_padding_size
            else:
                head_padding = DEFAULT_VOCAB_PADDING_SIZE
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                hf_config.hidden_size,
                org_num_embeddings=hf_config.vocab_size,
                padding_size=head_padding,
                prefix=maybe_prefix(prefix, "lm_head"),
            )

            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size,
                hf_config.vocab_size,
                getattr(hf_config, "logit_scale", 1.0),
                need_fp32_logits=False)
            self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return whatever it returns."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits through ``lm_head``."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the configured sampler."""
        return self.sampler(logits, sampling_metadata)
class Step1ForSequenceClassification(Step1PretrainedModel):
    """Step1 Transformer with a sequence classification head."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        # Pass the backbone arguments by keyword, as Step1ForCausalLM does:
        # this works whether or not Step1Model's constructor is
        # keyword-only, whereas a positional call fails if it is.
        self.model = Step1Model(vllm_config=vllm_config, prefix=prefix)
        config = vllm_config.model_config.hf_config
        assert len(config.id2label.keys()) == config.num_labels
        if get_pp_group().is_last_rank:
            # The classification head and pooler are only built on the last
            # pipeline rank.
            self.score = ReplicatedLinear(config.hidden_size,
                                          config.num_labels,
                                          bias=False)
            pooler_config = vllm_config.model_config.pooler_config
            self._pooler = Pooler.from_config_with_defaults(
                pooler_config,
                pooling_type=PoolingType.ALL,
                normalize=False,
                softmax=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its output unchanged."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Score the hidden states, then pool the resulting logits."""
        logits, _ = self.score(hidden_states)
        ret = self._pooler(logits, pooling_metadata)
        return ret
\ No newline at end of file
vllm/model_executor/models/step2_mini.py
0 → 100644
View file @
583034f1
# SPDX-License-Identifier: Apache-2.0
"""Inference-only Jurassic model."""
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.distributed
import
(
get_dp_group
,
get_pp_group
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.step1
import
Step1MoEMLP
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
.interfaces
import
SupportsPP
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
logger = init_logger(__name__)

# Globally shared CUDA graph memory pool, similar to the implementation in
# model_runner.py. It stays None until the first graph capture assigns it.
_graph_memory_pool: Optional[Tuple[int, int]] = None
class FusedMoEBlock(nn.Module):
    """Router gate plus fused experts, with the tensor-parallel reduce done here."""

    def __init__(self,
                 config: ModelConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        expert_count = config.moe_num_experts
        if self.tp_size > expert_count:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than the number of experts {expert_count}."
            )
        assert config.moe_dynamic_exp_p == 1, "Only support dynamic exp p=1"

        # reduce_results=False: the all-reduce is applied in forward().
        self.experts = FusedMoE(
            num_experts=expert_count,
            top_k=config.moe_top_k,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_expert_weight,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
        )
        # The router is built without quantization.
        self.gate = ReplicatedLinear(
            config.hidden_size,
            expert_count,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and restore the input shape."""
        # NOTE: hidden_states can have either 1D or 2D shape.
        input_shape = hidden_states.shape
        tokens = hidden_states.view(-1, input_shape[-1])
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(tokens)[0]
        expert_out = self.experts(hidden_states=tokens,
                                  router_logits=router_logits)
        if self.tp_size > 1:
            expert_out = tensor_model_parallel_all_reduce(expert_out)
        return expert_out.view(input_shape)
class Step2MiniMLP(nn.Module):
    """Gated feed-forward block: merged gate/up projection, SiLU-and-mul, down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Gate and up projections share one merged column-parallel layer.
        merged_sizes = [intermediate_size] * 2
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_sizes,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()
        self.prefix = prefix
        self.hidden_size = hidden_size

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation and down projection in turn."""
        projected = self.gate_up_proj(hidden_states)[0]
        activated = self.act_fn(projected)
        return self.down_proj(activated)[0]
class Step2MiniAttention(nn.Module):
    """Attention block whose query passes through a shared low-dimensional
    projection, an RMSNorm and a per-head expansion (``wq``) before RoPE."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        norm_eps: float,
        rope_theta: int,
        share_q_dim: Optional[int] = None,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embedding: int = 8192,
        head_dim: int = 256,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        world_size = get_tensor_model_parallel_world_size()

        # Query heads are split evenly across tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert num_heads % world_size == 0
        self.num_heads = num_heads // world_size

        self.total_num_kv_heads = num_kv_heads
        # With at least as many KV heads as ranks the heads are partitioned;
        # with fewer they are replicated. Either way one count must divide
        # the other.
        kv_heads_compatible = (num_kv_heads % world_size == 0
                               if num_kv_heads >= world_size else
                               world_size % num_kv_heads == 0)
        assert kv_heads_compatible
        self.num_kv_heads = max(1, num_kv_heads // world_size)

        self.head_dim = head_dim
        self.kv_size = self.num_kv_heads * head_dim
        # A falsy share_q_dim (None or 0) falls back to a single head_dim.
        self.q_size = share_q_dim or head_dim

        self.qkv_proj = ReplicatedLinear(
            hidden_size,
            self.q_size + self.kv_size * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            num_heads * head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.inter_norm = RMSNorm(self.q_size, eps=norm_eps)
        self.wq = ColumnParallelLinear(
            self.q_size,
            head_dim * num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wq",
        )
        self.rotary_emb = get_rope(head_dim,
                                   rotary_dim=head_dim,
                                   max_position=max_position_embedding,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling)
        self.attn = Attention(self.num_heads,
                              head_dim,
                              head_dim**-0.5,
                              self.num_kv_heads,
                              cache_config=cache_config,
                              prefix=f"{prefix}.attn")
        self.prefix = prefix

    def forward(self, positions: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute attention for ``hidden_states`` and return the ``o_proj`` output."""
        fused = self.qkv_proj(hidden_states)[0]
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused.split(split_sizes, dim=-1)
        # Normalise the shared query, then expand it to all heads.
        query = self.wq(self.inter_norm(query.contiguous()))[0]
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value)
        return self.o_proj(context)[0]
class Step2MiniDecoderLayer(nn.Module):
    """One Step2-mini decoder layer: attention followed by a dense MLP or a
    shared-expert + MoE feed-forward.

    When ``should_capture_graph`` is set (data-parallel world size > 1 on a
    CUDA-like platform) the layer runs as three stages (``forward_1``,
    ``forward_2``, ``forward_3``), each replayed from a CUDA graph for small
    token counts, with the attention call and the MoE expert call executed
    eagerly between them. Otherwise ``forward_old`` is used.
    """

    def __init__(self,
                 config: ModelConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 use_fused_moe: bool = False,
                 prefix: str = "") -> None:
        """Build the layer.

        Args:
            config: model config; its ``hf_config`` supplies the sizes.
            cache_config: forwarded to the attention block.
            quant_config: forwarded to the linear / MoE sub-layers.
            use_fused_moe: build ``FusedMoEBlock`` instead of
                ``Step1MoEMLP`` for MoE layers.
            prefix: parameter-name prefix; must contain ``layers.<idx>``,
                from which the layer index is parsed.
        """
        super().__init__()
        config = config.hf_config
        self.hidden_size = config.hidden_size
        rope_scaling = getattr(config, "rope_scaling", None)
        self.self_attn = Step2MiniAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            norm_eps=config.rms_norm_eps,
            max_position_embedding=config.max_position_embedding,
            head_dim=config.head_dim,
            share_q_dim=config.share_q_dim,
            rope_theta=config.rope_theta,
            rope_scaling=rope_scaling,
            prefix=f"{prefix}.self_attn")
        self.use_moe = False
        layer_idx = int(prefix.split("layers.")[1].split(".")[0])
        moe_layers_enum = getattr(config, "moe_layers_enum", None)
        if moe_layers_enum is not None:
            moe_layers_idx = [
                int(i) for i in moe_layers_enum.strip().split(',')
            ]
        else:
            # Default to 1dense.
            moe_layers_idx = list(range(1, config.num_hidden_layers))
        if layer_idx in moe_layers_idx:
            if not use_fused_moe:
                # FIXME: TODO: enable cudagraph
                self.moe = Step1MoEMLP(
                    config.moe_num_experts,
                    config.moe_top_k,
                    config.moe_dynamic_exp_p,
                    hidden_size=self.hidden_size,
                    intermediate_size=config.moe_intermediate_size,
                    hidden_act="silu",
                    quant_config=quant_config,
                    norm_expert_weight=config.norm_expert_weight,
                    prefix=f"{prefix}.moe",
                    enable_cudagraph=False)
            else:
                self.moe = FusedMoEBlock(config=config,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.moe")
            self.share_expert = Step2MiniMLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.share_expert_dim,
                hidden_act="silu",
                quant_config=quant_config,
                prefix=f"{prefix}.share_expert")
            self.use_moe = True
        else:
            self.mlp = Step2MiniMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
                                    hidden_act="silu",
                                    quant_config=quant_config,
                                    prefix=f"{prefix}.mlp")
        self.use_fused_moe = use_fused_moe
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
        self.prefix = prefix

        # CUDA graph parameters - simplified version using the shared
        # memory pool.
        self.should_capture_graph = (get_dp_group().world_size > 1
                                     and current_platform.is_cuda_alike())
        self.cuda_graphs_captured = False
        # Each runner maps a padded token count to
        # (graph, *static input buffers, *static output buffers).
        self.graph_runners_fwd1: dict[int, Tuple[torch.cuda.CUDAGraph,
                                                 torch.Tensor, torch.Tensor,
                                                 torch.Tensor, torch.Tensor,
                                                 torch.Tensor]] = {}
        self.graph_runners_fwd2: dict[int, Tuple[torch.cuda.CUDAGraph,
                                                 torch.Tensor, torch.Tensor,
                                                 torch.Tensor,
                                                 torch.Tensor]] = {}
        # The last element (router logits) is None for dense layers.
        self.graph_runners_fwd3: dict[int, Tuple[torch.cuda.CUDAGraph,
                                                 torch.Tensor, torch.Tensor,
                                                 torch.Tensor,
                                                 Optional[torch.Tensor]]] = {}
        self.max_graph_tokens = 64
        self.graph_token_step = 32
        self.decoder_captured_sizes = list(
            range(self.graph_token_step, self.max_graph_tokens + 1,
                  self.graph_token_step)) if self.should_capture_graph else []

    @torch.inference_mode()
    def _capture_cuda_graph(self, device: torch.device, hs_dtype: torch.dtype,
                            pos_dtype: torch.dtype):
        """Capture the three stage graphs for every size in
        ``decoder_captured_sizes`` (largest first); no-op if already captured
        or capturing is disabled."""
        global _graph_memory_pool
        if self.cuda_graphs_captured or not self.should_capture_graph:
            return

        # Capture on a side stream, using the globally shared memory pool.
        stream = torch.cuda.Stream()
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            for total_tokens in reversed(self.decoder_captured_sizes):
                # --- Capture forward_1 ---
                graph_fwd1 = torch.cuda.CUDAGraph()
                # Create the static input buffers.
                static_positions = torch.ones((total_tokens, ),
                                              dtype=pos_dtype,
                                              device=device)
                static_hidden_states = torch.randn(
                    (total_tokens, self.hidden_size),
                    dtype=hs_dtype,
                    device=device)
                # Warmup forward_1
                self._forward_1_impl(static_positions, static_hidden_states)
                # Capture forward_1 with torch.cuda.graph() and the shared
                # memory pool.
                with torch.cuda.graph(graph_fwd1,
                                      pool=_graph_memory_pool,
                                      stream=stream):
                    static_q_fwd1, static_k_fwd1, static_v_fwd1 = (
                        self._forward_1_impl(static_positions,
                                             static_hidden_states))
                # Publish the pool of the first captured graph so that all
                # later captures share it.
                if _graph_memory_pool is None:
                    _graph_memory_pool = graph_fwd1.pool()
                self.graph_runners_fwd1[total_tokens] = (
                    graph_fwd1, static_positions, static_hidden_states,
                    static_q_fwd1, static_k_fwd1, static_v_fwd1)

                # --- Capture forward_2 ---
                graph_fwd2 = torch.cuda.CUDAGraph()
                # Create the static input buffers.
                attn_output_size = (self.self_attn.num_heads *
                                    self.self_attn.head_dim)
                static_attn_output = torch.randn(
                    (total_tokens, attn_output_size),
                    dtype=hs_dtype,
                    device=device)
                static_residual = torch.randn(
                    (total_tokens, self.hidden_size),
                    dtype=hs_dtype,
                    device=device)
                # Warmup forward_2
                self._forward_2_impl(static_attn_output, static_residual)
                with torch.cuda.graph(graph_fwd2,
                                      pool=_graph_memory_pool,
                                      stream=stream):
                    static_hs_out_fwd2, static_residual_out_fwd2 = (
                        self._forward_2_impl(static_attn_output,
                                             static_residual))
                self.graph_runners_fwd2[total_tokens] = (
                    graph_fwd2, static_attn_output, static_residual,
                    static_hs_out_fwd2, static_residual_out_fwd2)

                # --- Capture forward_3 ---
                graph_fwd3 = torch.cuda.CUDAGraph()
                # Create the static input buffers.
                static_hidden_states_fwd3 = torch.randn(
                    (total_tokens, self.hidden_size),
                    dtype=hs_dtype,
                    device=device)
                static_residual_fwd3 = torch.randn(
                    (total_tokens, self.hidden_size),
                    dtype=hs_dtype,
                    device=device)
                # Warmup forward_3
                self._forward_3_impl(static_hidden_states_fwd3,
                                     static_residual_fwd3)
                with torch.cuda.graph(graph_fwd3,
                                      pool=_graph_memory_pool,
                                      stream=stream):
                    static_ffn_output_fwd3, static_router_logits_fwd3 = (
                        self._forward_3_impl(static_hidden_states_fwd3,
                                             static_residual_fwd3))
                self.graph_runners_fwd3[total_tokens] = (
                    graph_fwd3, static_hidden_states_fwd3,
                    static_residual_fwd3, static_ffn_output_fwd3,
                    static_router_logits_fwd3)

        torch.cuda.current_stream().wait_stream(stream)
        self.cuda_graphs_captured = True

    def _ensure_cuda_graphs_captured(self, device: torch.device,
                                     hs_dtype: torch.dtype,
                                     pos_dtype: torch.dtype):
        """Capture the graphs on first use when capturing is enabled."""
        if not self.cuda_graphs_captured and self.should_capture_graph:
            self._capture_cuda_graph(device, hs_dtype, pos_dtype)

    def _select_graph(self, runners: dict, num_tokens: int):
        """Return the captured entry of ``runners`` that fits ``num_tokens``,
        or None when the stage has to run eagerly.

        The token count is rounded up to a multiple of ``graph_token_step``
        to form the lookup key.
        """
        if not self.cuda_graphs_captured:
            return None
        if num_tokens > self.max_graph_tokens:
            return None
        step = self.graph_token_step
        graph_key = (num_tokens + step - 1) // step * step
        return runners.get(graph_key)

    # Separate implementation logic from graph handling
    def _forward_1_impl(self, positions: torch.Tensor,
                        hidden_states: torch.Tensor):
        """Stage 1: input norm, QKV projection, query expansion and RoPE.

        Returns the ``(q, k, v)`` tensors to feed to the attention call.
        """
        hidden_states = self.input_layernorm(hidden_states)
        qkv, _ = self.self_attn.qkv_proj(hidden_states)
        q, k, v = qkv.split([
            self.self_attn.q_size, self.self_attn.kv_size,
            self.self_attn.kv_size
        ],
                            dim=-1)
        q = self.self_attn.inter_norm(q.contiguous())
        q = self.self_attn.wq(q)[0]
        q, k = self.self_attn.rotary_emb(positions, q, k)
        return q, k, v

    def forward_1(self, positions: torch.Tensor, hidden_states: torch.Tensor):
        """Run stage 1, replaying a captured graph when one fits.

        On the graph path the returned tensors are slices of the graph's
        static output buffers.
        """
        if self.should_capture_graph:
            self._ensure_cuda_graphs_captured(hidden_states.device,
                                              hidden_states.dtype,
                                              positions.dtype)
            actual_tokens = hidden_states.shape[0]
            graph_data = self._select_graph(self.graph_runners_fwd1,
                                            actual_tokens)
            if graph_data is not None:
                (graph, static_pos_view, static_hs_view, static_q, static_k,
                 static_v) = graph_data
                static_pos_view[:actual_tokens].copy_(positions)
                static_hs_view[:actual_tokens].copy_(hidden_states)
                graph.replay()
                return (static_q[:actual_tokens], static_k[:actual_tokens],
                        static_v[:actual_tokens])
        # Fallback to eager execution
        return self._forward_1_impl(positions, hidden_states)

    # Separate implementation logic from graph handling
    def _forward_2_impl(self, attn_output: torch.Tensor,
                        residual: torch.Tensor):
        """Stage 2: output projection, residual add and post-attention norm.

        Returns ``(normed_hidden_states, new_residual)``.
        """
        hidden_states, _ = self.self_attn.o_proj(attn_output)
        hidden_states += residual
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        return hidden_states, residual

    def forward_2(self, attn_output: torch.Tensor, residual: torch.Tensor):
        """Run stage 2, replaying a captured graph when one fits."""
        if self.should_capture_graph:
            actual_tokens = attn_output.shape[0]
            graph_data = self._select_graph(self.graph_runners_fwd2,
                                            actual_tokens)
            if graph_data is not None:
                (graph, static_attn_output_view, static_residual_view,
                 static_hs_out, static_residual_out) = graph_data
                static_attn_output_view[:actual_tokens].copy_(attn_output)
                static_residual_view[:actual_tokens].copy_(residual)
                graph.replay()
                return (static_hs_out[:actual_tokens],
                        static_residual_out[:actual_tokens])
        # Fallback to eager execution
        return self._forward_2_impl(attn_output, residual)

    # Separate implementation logic from graph handling
    def _forward_3_impl(self, hidden_states: torch.Tensor,
                        residual: torch.Tensor):
        """Stage 3: dense part of the feed-forward plus the residual.

        For MoE layers this is the shared expert, and the router logits are
        computed as well; the routed experts are added later by ``forward``.
        Returns ``(ffn_output + residual, router_logits_or_None)``.
        """
        if self.use_moe:
            ffn_output = self.share_expert(hidden_states)
            router_logits, _ = self.moe.gate(hidden_states)
        else:
            ffn_output = self.mlp(hidden_states)
            router_logits = None
        # Base output before potential MoE addition
        return ffn_output + residual, router_logits

    def forward_3(self, hidden_states: torch.Tensor, residual: torch.Tensor):
        """Run stage 3, replaying a captured graph when one fits."""
        if self.should_capture_graph:
            actual_tokens = hidden_states.shape[0]
            graph_data = self._select_graph(self.graph_runners_fwd3,
                                            actual_tokens)
            if graph_data is not None:
                (graph, static_hs_view, static_residual_view,
                 static_ffn_output, static_router_logits) = graph_data
                static_hs_view[:actual_tokens].copy_(hidden_states)
                static_residual_view[:actual_tokens].copy_(residual)
                graph.replay()
                router_logits = None
                if static_router_logits is not None:
                    router_logits = static_router_logits[:actual_tokens]
                return static_ffn_output[:actual_tokens], router_logits
        # Fallback to eager execution
        return self._forward_3_impl(hidden_states, residual)

    def forward(self, positions: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the layer, using the staged path when graph capture is enabled."""
        if not self.should_capture_graph:
            return self.forward_old(positions, hidden_states)

        residual = hidden_states
        q, k, v = self.forward_1(positions, hidden_states)
        attn_output = self.self_attn.attn(q, k, v)
        hidden_states, residual = self.forward_2(attn_output, residual)
        ffn_output_plus_residual, router_logits = self.forward_3(
            hidden_states, residual)
        if not self.use_moe:
            return ffn_output_plus_residual

        moe_output = self.moe.experts(hidden_states, router_logits)
        # FusedMoEBlock builds its experts with reduce_results=False and
        # performs the tensor-parallel all-reduce in its own forward().
        # Calling .experts directly bypasses that, so repeat the reduce here
        # to stay consistent with forward_old.
        if self.use_fused_moe and self.moe.tp_size > 1:
            moe_output = tensor_model_parallel_all_reduce(moe_output)
        return ffn_output_plus_residual + moe_output

    def forward_old(self, positions: torch.Tensor,
                    hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the whole layer eagerly, without the staged graph path."""
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states += residual
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        if self.use_moe:
            share_output = self.share_expert(hidden_states)
            moe_output = self.moe(hidden_states)
            ffn_output = share_output + moe_output
        else:
            ffn_output = self.mlp(hidden_states)
        hidden_states = ffn_output + residual
        return hidden_states
class Step2MiniModel(nn.Module):
    """Token embedding, the stack of Step2-mini decoder layers and the final
    norm, laid out for pipeline parallelism."""

    def __init__(self,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 use_fused_moe: bool = False) -> None:
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        cache_cfg = vllm_config.cache_config
        quant_cfg = vllm_config.quant_config

        self.vocab_size = hf_config.vocab_size
        self.config = hf_config
        self.use_fused_moe = use_fused_moe

        # The embedding lives on the first rank, and also on the last rank
        # when the word embeddings are tied.
        owns_embedding = get_pp_group().is_first_rank or (
            hf_config.tie_word_embeddings and get_pp_group().is_last_rank)
        if owns_embedding:
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                hf_config.hidden_size,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        def build_layer(prefix):
            # The parameter is named ``prefix`` to match the lambda this
            # replaces, in case make_layers passes it by keyword.
            return Step2MiniDecoderLayer(config=vllm_config.model_config,
                                         cache_config=cache_cfg,
                                         quant_config=quant_cfg,
                                         use_fused_moe=self.use_fused_moe,
                                         prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(hf_config.hidden_size,
                                eps=hf_config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    hf_config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run this rank's slice of the decoder stack.

        The first pipeline rank starts from ``inputs_embeds`` (or the
        embedding of ``input_ids``); other ranks start from
        ``intermediate_tensors["hidden_states"]``. Ranks other than the last
        return an ``IntermediateTensors``; the last returns the normed
        hidden states.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        elif inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = inputs_embeds

        for layer_idx in range(self.start_layer, self.end_layer):
            hidden_states = self.layers[layer_idx](positions, hidden_states)

        if get_pp_group().is_last_rank:
            return self.norm(hidden_states)
        return IntermediateTensors({
            "hidden_states": hidden_states,
        })
@support_torch_compile
class Step3FlashModelFusedMoE(Step2MiniModel):
    """``Step2MiniModel`` constructed with ``use_fused_moe=True``."""

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        # Same construction as the base model; only the fused-MoE flag is
        # pinned to True.
        super().__init__(vllm_config, prefix, use_fused_moe=True)
class Step2MiniPretrainedModel(nn.Module, SupportsPP):
    """Shared checkpoint-loading logic for the Step2-mini model heads.

    This class only reads ``self.config``, ``self.model`` and (when
    ``self.model.use_fused_moe`` is truthy) ``self.vllm_config``; subclasses
    are responsible for setting them.
    """

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this module's parameters.

        Each ``(name, tensor)`` pair is tried against four routes, in order:

        1. ``gate_proj`` / ``up_proj`` shards of a merged ``gate_up_proj``
           (never used for names matching a fused-MoE expert weight),
        2. fused-MoE expert weights, loaded one expert slice at a time,
        3. ``q_proj`` / ``k_proj`` / ``v_proj`` tensors, copied directly into
           the matching region of ``qkv_proj`` along its output dimension,
        4. everything else, via the parameter's ``weight_loader`` attribute
           or ``default_weight_loader``.

        Args:
            weights: iterable of ``(checkpoint_name, tensor)`` pairs.

        Raises:
            RuntimeError: if any parameter of this module was left unloaded.
        """
        q_width = self.config.share_q_dim
        kv_width = self.config.head_dim
        qkv_width = q_width + kv_width * 2
        # Offsets are expressed in units of 1/qkv_width of the fused output
        # dimension and scaled with integer arithmetic below. The previous
        # float ratios (``int(ratio * dim)``) could truncate one element
        # short whenever the ratio was not exactly representable.
        qkv_params_mapping = [
            # (param_name, shard_name, start_offset, end_offset)
            (".qkv_proj", ".q_proj", 0, q_width),
            (".qkv_proj", ".k_proj", q_width, q_width + kv_width),
            (".qkv_proj", ".v_proj", q_width + kv_width, qkv_width),
        ]
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()

        if self.model.use_fused_moe:
            quant_config = self.vllm_config.quant_config
            if (quant_config is not None
                    and quant_config.get_name() == "groupwise_quant"):
                expert_params_mapping = [
                    # (param_name, checkpoint_name, shard_id)
                    (".moe.experts.w13_weight", ".moe.gate_proj.qweight",
                     "w1"),
                    (".moe.experts.w13_weight", ".moe.up_proj.qweight", "w3"),
                    (".moe.experts.w2_weight", ".moe.down_proj.qweight",
                     "w2"),
                    (".moe.experts.w13_weight_scale", ".moe.gate_proj.scales",
                     "w1"),
                    (".moe.experts.w13_weight_scale", ".moe.up_proj.scales",
                     "w3"),
                    (".moe.experts.w2_weight_scale", ".moe.down_proj.scales",
                     "w2"),
                ]
            else:
                expert_params_mapping = [
                    (".moe.experts.w13_weight", ".moe.gate_proj.weight",
                     "w1"),
                    (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
                    (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"),
                ]
        else:
            expert_params_mapping = []

        # Checkpoint names claimed by the expert route; they must not be
        # captured by the generic gate/up stacking route.
        disable_moe_stacked_params = [
            weight_name for _, weight_name, _ in expert_params_mapping
        ]

        for name, loaded_weight in weights:
            # Route 1: gate_proj / up_proj -> gate_up_proj.
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if any(disabled in name
                       for disabled in disable_moe_stacked_params):
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                # Route 2: fused-MoE expert weights.
                for (param_name, weight_name,
                     shard_id) in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    # The leading dimension of the checkpoint tensor is
                    # iterated as the expert index.
                    for expert_id in range(loaded_weight.shape[0]):
                        weight_loader(param,
                                      loaded_weight[expert_id],
                                      name,
                                      shard_id=shard_id,
                                      expert_id=expert_id)
                    loaded_params.add(name)
                    break
                else:
                    # Route 3: q/k/v slices of the fused qkv_proj.
                    for (param_name, weight_name, start_offset,
                         end_offset) in qkv_params_mapping:
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        dim = param.shape[param.output_dim]
                        begin_idx = int(start_offset * dim // qkv_width)
                        stop_idx = int(end_offset * dim // qkv_width)
                        param_slice = param.narrow(param.output_dim,
                                                   begin_idx,
                                                   stop_idx - begin_idx)
                        param_slice.copy_(loaded_weight)
                        loaded_params.add(name)
                        break
                    else:
                        # Route 4: plain parameters.
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        weight_loader = getattr(param, "weight_loader",
                                                default_weight_loader)
                        weight_loader(param, loaded_weight)
                        loaded_params.add(name)

        params_need_to_load = set(params_dict)
        if params_need_to_load != loaded_params:
            # min() keeps the reported example stable between runs.
            param_name_example = min(params_need_to_load - loaded_params)
            raise RuntimeError(
                f"Some parameters like {param_name_example} are not in the checkpoint and will falsely use random initialization"
            )
class Step2MiniForCausalLM(Step2MiniPretrainedModel):
    """Step2-mini decoder with a language-model head, logits processor and sampler."""

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config
        self.config = hf_config
        self.vllm_config = vllm_config

        # FIXME: hack for step3 flash model
        if hf_config.num_hidden_layers == 42:
            backbone_cls = Step2MiniModel
        else:
            backbone_cls = Step3FlashModelFusedMoE
        self.model = backbone_cls(vllm_config=vllm_config, prefix=prefix)

        if not get_pp_group().is_last_rank:
            # Only the last pipeline rank owns the LM head.
            self.lm_head = PPMissingLayer()
        else:
            self.unpadded_vocab_size = hf_config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            if lora_config:
                padding_size = lora_config.lora_vocab_padding_size
            else:
                padding_size = DEFAULT_VOCAB_PADDING_SIZE
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                hf_config.hidden_size,
                org_num_embeddings=hf_config.vocab_size,
                padding_size=padding_size,
            )
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    hf_config.vocab_size,
                                                    need_fp32_logits=False)
            self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None):
        """Delegate to ``self.model`` and return whatever it returns."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
            self, hidden_states: torch.Tensor,
            sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Apply ``self.logits_processor`` with ``self.lm_head`` to the hidden states."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run ``self.sampler`` on the logits and return its output."""
        return self.sampler(logits, sampling_metadata)
class Step2MiniForSequenceClassification(Step2MiniPretrainedModel):
    """Step2-mini decoder with a linear scoring head and a pooler."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        # The inherited load_weights reads self.vllm_config on the fused-MoE
        # path; store it here as Step2MiniForCausalLM does.
        self.vllm_config = vllm_config
        self.model = Step2MiniModel(vllm_config, prefix=prefix)

        if get_pp_group().is_last_rank:
            self.score = ReplicatedLinear(self.config.hidden_size,
                                          self.config.num_labels,
                                          bias=False)
            pooler_config = vllm_config.model_config.pooler_config
            self._pooler = Pooler.from_config_with_defaults(
                pooler_config,
                pooling_type=PoolingType.ALL,
                normalize=False,
                softmax=False)
        else:
            # Register the head as missing so that load_weights skips
            # checkpoint ``score.*`` tensors on this rank rather than
            # failing the parameter lookup.
            self.score = PPMissingLayer()
            self._pooler = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Delegate to ``self.model`` and return its output unchanged.

        NOTE(review): on non-last pipeline ranks the wrapped model returns an
        ``IntermediateTensors`` rather than a tensor.
        """
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Score the hidden states with ``self.score`` and pool the result."""
        # self.score returns a pair; only the first element is pooled.
        logits, _ = self.score(hidden_states)
        return self._pooler(logits, pooling_metadata)

    def sequence_flops(self, input_length, context_length):
        """Return the parent's FLOPs estimate plus the scoring-head term.

        NOTE(review): this relies on a ``sequence_flops`` implementation
        further up the MRO; none is visible on Step2MiniPretrainedModel in
        this part of the file, so confirm one exists.
        """
        output_flops = (1 * self.config.hidden_size * self.config.num_labels *
                        2.0 / 1e12)
        return super().sequence_flops(input_length,
                                      context_length) + output_flops
\ No newline at end of file
vllm/model_executor/models/step_encoder.py
0 → 100644
View file @
583034f1
# SPDX-License-Identifier: Apache-2.0
import
math
from
typing
import
Iterable
,
Optional
,
Tuple
import
torch
import
torchvision
#from optimus import flash_attn_func
from
torch
import
nn
from
torch.nn
import
functional
as
F
from
torchvision.transforms.functional
import
InterpolationMode
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.layernorm
import
OptimusLayerNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.transformers_utils.configs
import
CLIPVisionConfig
def get_abs_pos(abs_pos, tgt_size):
    """Resize a class-token-plus-grid position embedding to a new grid size.

    ``abs_pos`` is squeezed on dim 0; its first row is treated as the class
    token and the remaining rows as a square grid. ``tgt_size`` is the
    target number of grid positions (its integer square root is the target
    side length).

    If the source and target side lengths match, ``abs_pos`` itself is
    returned. Otherwise the grid is resized in float32 with antialiased
    bicubic interpolation, cast back to the input dtype, and returned with
    the class token re-attached, shaped ``(1, side * side + 1, dim)``.
    """
    embed_dim = abs_pos.size(-1)
    flat_pos = abs_pos.squeeze(0)
    src_side = int(math.sqrt(flat_pos.shape[0] - 1))
    tgt_side = int(math.sqrt(tgt_size))

    if src_side == tgt_side:
        return abs_pos

    cls_token = flat_pos[:1]
    # (N, dim) -> (1, dim, side, side), the layout F.interpolate expects.
    grid = flat_pos[1:].view(1, src_side, src_side, embed_dim)
    grid = grid.permute(0, 3, 1, 2).contiguous().to(torch.float32)
    resized = F.interpolate(
        grid,
        size=(tgt_side, tgt_side),
        mode='bicubic',
        antialias=True,
        align_corners=False,
    ).to(abs_pos.dtype)
    # Back to one row per grid position.
    resized = resized.permute(0, 2, 3, 1).view(tgt_side * tgt_side, embed_dim)
    combined = torch.cat([cls_token, resized], dim=0)
    return combined.view(1, tgt_side * tgt_side + 1, embed_dim)
class StepCLIPVisionEmbeddings(nn.Module):
    """Patch, class-token and position embeddings for the Step CLIP encoder."""

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(1, self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=True,
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2
        self.pad_tp_size = 4  # hard code for padding
        # To load the pretrained weights, we still use P+1 as the seqlen
        num_positions = self.num_patches + 1
        self.position_embedding = torch.nn.Embedding(num_positions,
                                                     self.embed_dim)
        self.register_buffer("position_ids",
                             torch.arange(num_positions).expand((1, -1)),
                             persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed ``pixel_values`` into a token sequence.

        The output sequence is ``pad_tp_size - 1`` copies of the
        position-augmented class token, followed by that class token and
        then the patch tokens.
        """
        batch_size = pixel_values.shape[0]
        # Conv output is [*, width, grid, grid]; flatten the grid and move
        # it to the sequence axis.
        patch_embeds = self.patch_embedding(pixel_values)
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        tokens = torch.cat([class_embeds, patch_embeds], dim=1)

        # Position table resized to the actual number of patches.
        pos_embeds = get_abs_pos(self.position_embedding(self.position_ids),
                                 patch_embeds.size(1))
        tokens = tokens + pos_embeds

        # pad
        cls_padding = tokens[:, 0, :].unsqueeze(1).repeat(
            1, self.pad_tp_size - 1, 1)
        return torch.cat([cls_padding, tokens], dim=1)
class StepCLIPAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self,
                 config,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 need_dp: bool = False):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.total_num_heads
        self.scale = self.head_dim**-0.5

        if need_dp:
            # Replicated projections: this instance handles every head.
            self.num_heads = self.total_num_heads
            self.qkv_proj = ReplicatedLinear(
                self.embed_dim,
                self.embed_dim * 3,
                bias=True,
                quant_config=quant_config,
                prefix=prefix,
            )
            self.out_proj = ReplicatedLinear(
                self.embed_dim,
                self.embed_dim,
                bias=True,
                quant_config=quant_config,
                prefix=prefix,
            )
        else:
            # Tensor-parallel projections: heads are divided by the TP size.
            tp_size = get_tensor_model_parallel_world_size()
            assert self.total_num_heads % tp_size == 0
            self.num_heads = self.total_num_heads // tp_size
            self.qkv_proj = QKVParallelLinear(self.embed_dim,
                                              self.head_dim,
                                              self.total_num_heads,
                                              bias=True,
                                              quant_config=quant_config,
                                              prefix=prefix)
            self.out_proj = RowParallelLinear(self.embed_dim,
                                              self.embed_dim,
                                              bias=True,
                                              quant_config=quant_config,
                                              prefix=prefix)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape ``tensor`` to ``(bsz, num_heads, seq_len, head_dim)``, contiguous."""
        per_head = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual=None,
        layernorm=None,
    ):
        """Input shape: Batch x Time x Channel"""
        if layernorm is not None:
            hidden_states = layernorm(hidden_states)
        bsz, tgt_len, _ = hidden_states.size()

        # Project once, then split the last dimension into q, k and v.
        qkv, _ = self.qkv_proj(hidden_states)
        head_shape = (bsz, tgt_len, self.num_heads, self.head_dim)
        q, k, v = (part.view(*head_shape).transpose(1, 2)
                   for part in qkv.chunk(chunks=3, dim=-1))

        # Non-causal attention through PyTorch SDPA.
        attn_output = F.scaled_dot_product_attention(q,
                                                     k,
                                                     v,
                                                     scale=self.scale,
                                                     is_causal=False)
        attn_output = attn_output.transpose(1, 2).reshape(
            bsz, tgt_len, self.num_heads * self.head_dim)

        # The residual is handed to out_proj as a keyword argument.
        attn_output, _ = self.out_proj(attn_output, residual=residual)
        return attn_output
class StepCLIPMLP(nn.Module):
    """Two-layer feed-forward block: ``fc1`` -> activation -> ``fc2``."""

    def __init__(self,
                 config,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 need_dp: bool = False):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        # Both modes pass identical arguments; only the linear classes differ.
        if need_dp:
            fc1_cls = fc2_cls = ReplicatedLinear
        else:
            fc1_cls, fc2_cls = ColumnParallelLinear, RowParallelLinear

        self.fc1 = fc1_cls(config.hidden_size,
                           config.intermediate_size,
                           bias=True,
                           quant_config=quant_config,
                           prefix=prefix)
        self.fc2 = fc2_cls(config.intermediate_size,
                           config.hidden_size,
                           bias=True,
                           quant_config=quant_config,
                           prefix=prefix)

    def forward(self,
                hidden_states: torch.Tensor,
                residual=None,
                layernorm=None) -> torch.Tensor:
        """Apply the optional ``layernorm``, then fc1, activation and fc2.

        ``residual`` is passed to ``fc2`` as a keyword argument.
        """
        if layernorm is not None:
            hidden_states = layernorm(hidden_states)
        projected, _ = self.fc1(hidden_states)
        activated = self.activation_fn(projected)
        output, _ = self.fc2(activated, residual=residual)
        return output
class StepCLIPEncoderLayer(nn.Module):
    """One encoder layer: attention and MLP, each normalized then added back."""

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 need_dp: bool = False):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = StepCLIPAttention(config,
                                           quant_config,
                                           prefix=f"{prefix}.self_attn",
                                           need_dp=need_dp)
        self.layer_norm1 = OptimusLayerNorm(self.embed_dim,
                                            eps=config.layer_norm_eps)
        self.mlp = StepCLIPMLP(config,
                               quant_config,
                               prefix=f"{prefix}.mlp",
                               need_dp=need_dp)
        self.layer_norm2 = OptimusLayerNorm(self.embed_dim,
                                            eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.FloatTensor:
        """Return ``h + layer_norm2(mlp(h))`` with ``h = x + layer_norm1(attn(x))``."""
        # The norms are applied to each sub-layer's output, not its input.
        attn_out = self.self_attn(hidden_states=hidden_states,
                                  residual=None,
                                  layernorm=None)
        after_attn = hidden_states + self.layer_norm1(attn_out)
        mlp_out = self.layer_norm2(self.mlp(after_attn))
        return after_attn + mlp_out
class StepCLIPEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``StepCLIPEncoderLayer`` modules.

    Args:
        config: vision config; ``num_hidden_layers`` sets the stack depth.
        quant_config: passed through to every layer.
        prefix: name prefix; layer ``i`` gets ``f"{prefix}.layers.{i}"``.
        need_dp: passed through to every layer.
    """

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 need_dp: bool = False):
        super().__init__()
        self.config = config
        encoder_layers = []
        for layer_idx in range(config.num_hidden_layers):
            encoder_layers.append(
                StepCLIPEncoderLayer(config,
                                     quant_config,
                                     prefix=f"{prefix}.layers.{layer_idx}",
                                     need_dp=need_dp))
        self.layers = nn.ModuleList(encoder_layers)

    def forward(
        self,
        inputs_embeds,
    ):
        """Feed ``inputs_embeds`` through every layer in order."""
        current = inputs_embeds
        for layer in self.layers:
            current = layer(current)
        return current
class StepCLIPVisionTransformer(nn.Module):
    """Embeddings followed by the encoder stack."""

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 need_dp: bool = False):
        super().__init__()
        self.config = config
        self.image_size = config.image_size
        # Bicubic, antialiased resize to (image_size, image_size). It is
        # stored here but not called from this class's forward.
        target_hw = (self.image_size, self.image_size)
        self.vision_model_preprocessor = torchvision.transforms.Resize(
            target_hw,
            interpolation=InterpolationMode.BICUBIC,
            antialias=True)
        self.embeddings = StepCLIPVisionEmbeddings(config)
        self.transformer = StepCLIPEncoder(config,
                                           quant_config,
                                           prefix=f"{prefix}.transformer",
                                           need_dp=need_dp)

    def forward(
        self,
        pixel_values: torch.Tensor,
    ):
        """Return ``(encoder_output, None)`` for ``pixel_values``."""
        embedded = self.embeddings(pixel_values)
        encoded = self.transformer(inputs_embeds=embedded)
        return encoded, None
class StepCLIPVisionModel(nn.Module):
    """Wrapper around ``StepCLIPVisionTransformer`` with checkpoint loading."""

    # A checkpoint entry is loaded only if its name contains one of these.
    _PARAMS_KEYS_TO_SELECT = ["vision_model"]

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 need_dp: bool = False):
        super().__init__()
        # FIXME(ys): step encoder does not support quantization
        quant_config = None
        self.vision_model = StepCLIPVisionTransformer(
            config,
            quant_config,
            prefix=f"{prefix}.vision_model",
            need_dp=need_dp)

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding module of the vision transformer."""
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
    ):
        """Delegate to ``self.vision_model``."""
        return self.vision_model(pixel_values=pixel_values)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load the selected checkpoint tensors into this module.

        Entries whose name contains an ignored key, or none of
        ``_PARAMS_KEYS_TO_SELECT``, are skipped. A leading
        ``model.vision_tower.vision_tower.`` or ``model.vision_tower.`` is
        stripped. ``q_proj`` / ``k_proj`` / ``v_proj`` entries are loaded as
        shards of ``qkv_proj``; all others go through the parameter's
        ``weight_loader`` or ``default_weight_loader``.

        Raises:
            RuntimeError: if any parameter of this module was left unloaded.
        """
        ignored_keys = [
            "text_model", "logit_scale",
            "vision_model.embeddings.position_ids",
            "visual_projection.weight", "text_projection.weight"
        ]
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()

        for name, loaded_weight in weights:
            if any(key in name for key in ignored_keys):
                continue
            if not any(key in name for key in self._PARAMS_KEYS_TO_SELECT):
                continue

            if name.startswith("model.vision_tower.vision_tower"):
                name = name.replace("model.vision_tower.vision_tower.", "")
            elif name.startswith("model.vision_tower"):
                name = name.replace("model.vision_tower.", "")

            # A shard matches only when its name is a whole dotted component.
            name_parts = name.split(".")
            shard_entry = next(
                (entry
                 for entry in stacked_params_mapping if entry[1] in name_parts),
                None)

            if shard_entry is None:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                param_name, weight_name, shard_id = shard_entry
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)

        params_need_to_load = set(params_dict.keys())
        if params_need_to_load != loaded_params:
            param_name_example = list(params_need_to_load - loaded_params)[0]
            raise RuntimeError(
                f"Some parameters like {param_name_example} are not in the checkpoint and will falsely use random initialization"
            )
class StepCLIPVisionModelWithPostprocess(StepCLIPVisionModel):
    """``StepCLIPVisionModel`` followed by two strided-conv downsamplers."""

    # Checkpoint names must contain one of these to be loaded.
    _PARAMS_KEYS_TO_SELECT = ["vision_model", "vit_downsampler"]

    def __init__(self, config: CLIPVisionConfig, need_dp: bool = True):
        super().__init__(config, need_dp=need_dp)
        self.config = config
        self.vit_downsampler = nn.Conv2d(self.config.hidden_size,
                                         self.config.output_hidden_size,
                                         kernel_size=2,
                                         stride=2)
        self.vit_downsampler2 = nn.Conv2d(
            self.config.output_hidden_size,
            self.config.output_hidden_size * 2,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x: torch.Tensor):
        """Encode ``x`` and downsample the patch tokens.

        The leading class-token slots are dropped, the remaining tokens are
        laid out as a square grid, passed through both convolutions, and
        flattened back to ``(batch, tokens, output_hidden_size * 2)``.
        """
        # The embeddings emit pad_tp_size class-token slots ahead of the
        # patch tokens; read the count from there instead of repeating the
        # literal 4 so the two stay in sync.
        num_cls_slots = self.vision_model.embeddings.pad_tp_size
        x = super().forward(x)[0][:, num_cls_slots:]
        B, P = x.shape[:2]
        HW = int(math.sqrt(P))
        # (B, P, C) -> (B, C, HW, HW) for the convolutions.
        x = x.permute(0, 2, 1).view(B, self.config.hidden_size, HW, HW)
        x = self.vit_downsampler(x)
        x = self.vit_downsampler2(x)
        # Flatten the spatial grid back into a token axis.
        x = x.view(B, self.config.output_hidden_size * 2, -1).permute(0, 2, 1)
        return x
\ No newline at end of file
vllm/transformers_utils/config.py
View file @
583034f1
...
@@ -41,6 +41,15 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
...
@@ -41,6 +41,15 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
OvisConfig
,
RWConfig
,
OvisConfig
,
RWConfig
,
Step3TextConfig
,
Step3VLConfig
,
Step3TextConfig
,
Step3VLConfig
,
SkyworkR1VChatConfig
,
SolarConfig
,
SkyworkR1VChatConfig
,
SolarConfig
,
MMGPTStep1Config
,
MMGPTStep1ConfigV2
,
MPTConfig
,
NemotronConfig
,
NVLM_D_Config
,
RWConfig
,
SkyworkR1VChatConfig
,
SolarConfig
,
Step1AudioConfig
,
Step1Config
,
Step1oConfig
,
Step2Config
,
Step2MiniConfig
,
Step3vConfig
,
StepAudioQwen2Config
,
Telechat2Config
,
UltravoxConfig
)
Telechat2Config
,
UltravoxConfig
)
# yapf: enable
# yapf: enable
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.transformers_utils.utils
import
check_gguf_file
...
@@ -75,6 +84,20 @@ _CONFIG_REGISTRY_OVERRIDE_HF: dict[str, type[PretrainedConfig]] = {
...
@@ -75,6 +84,20 @@ _CONFIG_REGISTRY_OVERRIDE_HF: dict[str, type[PretrainedConfig]] = {
"mllama"
:
MllamaConfig
"mllama"
:
MllamaConfig
}
}
# Config classes for the Step model family, keyed by the string used to look
# them up. This dict is unpacked into _CONFIG_REGISTRY below.
_CUSTOM_CONFIG_STEP = {
    "step1": Step1Config,
    "step2": Step2Config,
    "step2_mini": Step2MiniConfig,
    "mmgpt_step1": MMGPTStep1Config,
    "mmgpt_step1_v2": MMGPTStep1ConfigV2,
    #"mmgpt_qwen2": MMGPTQwen2Config,
    #"mmgpt_qwen2_v2": MMGPTQwen2ConfigV2,
    "step1o": Step1oConfig,
    "step1_audio": Step1AudioConfig,
    "step_audio_qwen2": StepAudioQwen2Config,
    "step3v": Step3vConfig,
}
_CONFIG_REGISTRY
:
dict
[
str
,
type
[
PretrainedConfig
]]
=
{
_CONFIG_REGISTRY
:
dict
[
str
,
type
[
PretrainedConfig
]]
=
{
"chatglm"
:
ChatGLMConfig
,
"chatglm"
:
ChatGLMConfig
,
"cohere2"
:
Cohere2Config
,
"cohere2"
:
Cohere2Config
,
...
@@ -100,7 +123,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
...
@@ -100,7 +123,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
"ultravox"
:
UltravoxConfig
,
"ultravox"
:
UltravoxConfig
,
"step3_vl"
:
Step3VLConfig
,
"step3_vl"
:
Step3VLConfig
,
"step3_text"
:
Step3TextConfig
,
"step3_text"
:
Step3TextConfig
,
**
_CONFIG_REGISTRY_OVERRIDE_HF
**
_CONFIG_REGISTRY_OVERRIDE_HF
,
**
_CUSTOM_CONFIG_STEP
}
}
_CONFIG_ATTRS_MAPPING
:
dict
[
str
,
str
]
=
{
_CONFIG_ATTRS_MAPPING
:
dict
[
str
,
str
]
=
{
...
...
vllm/transformers_utils/configs/__init__.py
View file @
583034f1
...
@@ -32,6 +32,18 @@ from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig,
...
@@ -32,6 +32,18 @@ from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig,
Step3VisionEncoderConfig
,
Step3VisionEncoderConfig
,
Step3VLConfig
)
Step3VLConfig
)
from
vllm.transformers_utils.configs.ultravox
import
UltravoxConfig
from
vllm.transformers_utils.configs.ultravox
import
UltravoxConfig
from
vllm.transformers_utils.configs.mmgpt
import
(
CLIPVisionConfig
,
MMGPTQwen2Config
,
MMGPTQwen2ConfigV2
,
MMGPTStep1Config
,
MMGPTStep1ConfigV2
,
SamViTConfig
,
Step1oConfig
,
Step3vConfig
)
from
vllm.transformers_utils.configs.step
import
(
Step1Config
,
Step2Config
,
Step2MiniConfig
)
from
vllm.transformers_utils.configs.step1f
import
(
Step1AudioConfig
,
Step1fAudioEncoderConfig
,
StepAudioQwen2Config
)
__all__
=
[
__all__
=
[
"ChatGLMConfig"
,
"ChatGLMConfig"
,
...
@@ -62,4 +74,21 @@ __all__ = [
...
@@ -62,4 +74,21 @@ __all__ = [
"Step3VLConfig"
,
"Step3VLConfig"
,
"Step3VisionEncoderConfig"
,
"Step3VisionEncoderConfig"
,
"Step3TextConfig"
,
"Step3TextConfig"
,
"Step1Config"
,
"Step2Config"
,
"Step2MiniConfig"
,
"CLIPVisionConfig"
,
"MMGPTBaiChuanConfig"
,
"MMGPTLlamaConfig"
,
"MMGPTLlamaConfigV2"
,
"MMGPTQwen2Config"
,
"MMGPTQwen2ConfigV2"
,
"MMGPTStep1Config"
,
"MMGPTStep1ConfigV2"
,
"Step3vConfig"
,
"SamViTConfig"
,
"Step1oConfig"
,
"Step1AudioConfig"
,
"Step1fAudioEncoderConfig"
,
"StepAudioQwen2Config"
,
]
]
vllm/transformers_utils/configs/mmgpt.py
0 → 100644
View file @
583034f1
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
List
,
Optional
,
Union
from
transformers
import
Qwen2Config
from
transformers.configuration_utils
import
PretrainedConfig
from
vllm.transformers_utils.configs.step
import
Step1Config
,
Step2MiniConfig
class CLIPVisionConfig(PretrainedConfig):
    """Settings container registered under ``clip_vision_model``.

    Each named keyword is stored unchanged as an attribute of the same
    name. Any other keyword is passed on to ``PretrainedConfig``.
    """

    model_type = "clip_vision_model"

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        projection_dim=512,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        image_size=224,
        patch_size=32,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        **kwargs,
    ):
        # Generic options are consumed by the base class first.
        super().__init__(**kwargs)

        # Encoder dimensions.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Input image layout.
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size

        # Normalisation, activation and initialisation settings.
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
class SamViTConfig(PretrainedConfig):
    """Settings container registered under ``sam_vit_model``.

    Each named keyword is stored unchanged as an attribute of the same
    name. Any other keyword is passed on to ``PretrainedConfig``.
    """

    model_type = "sam_vit_model"

    def __init__(
        self,
        depth=24,
        embed_dim=1024,
        image_size=1280,
        mlp_ratio=4,
        num_heads=16,
        patch_size=16,
        qkv_bias=True,
        use_abs_pos=True,
        use_rel_pos=True,
        global_attn_indexes=(5, 11, 17, 23),
        window_size=14,
        out_channels=256,
        layer_norm_eps=1e-6,
        **kwargs,
    ):
        # Generic options are consumed by the base class first.
        super().__init__(**kwargs)

        # Backbone size.
        self.depth = depth
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.patch_size = patch_size

        # Attention options.
        self.qkv_bias = qkv_bias
        self.use_abs_pos = use_abs_pos
        self.use_rel_pos = use_rel_pos
        self.global_attn_indexes = global_attn_indexes
        self.window_size = window_size

        # Output head and normalisation.
        self.out_channels = out_channels
        self.layer_norm_eps = layer_norm_eps
class MMGPTStep1Config(Step1Config):
    """Multimodal Step1 configuration registered as ``mmgpt_step1``.

    The language-model dimensions are forwarded to the parent
    ``Step1Config`` initializer, which stores them and validates them.
    Previously the parent ran with its own defaults, so an
    ``alibi_slopes`` list passed through ``kwargs`` was checked against
    the default head count instead of ``num_attention_heads``.

    Besides the flat attributes, two nested configs are built:

    * ``vision_tower_config``: a ``CLIPVisionConfig`` built from the given
      mapping, or ``None`` when no mapping is supplied.
    * ``text_config``: a ``Step1Config`` mirroring the language-model
      dimensions, tagged with the ``Step1ForCausalLM`` architecture.

    ``bos_token_id`` defaults to ``1`` and ``eos_token_id`` to ``[2, 3]``
    when left as ``None``.
    """

    # for step1.5
    model_type = "mmgpt_step1"

    def __init__(
        self,
        hidden_size: int = 5120,
        intermediate_size: int = 13312,
        num_attention_heads: int = 40,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 48,
        max_seq_len: int = 4096,
        vocab_size: int = 65536,
        rms_norm_eps: float = 1e-5,
        use_im_start_end=True,
        vision_select_layer=-2,
        image_token_len=None,
        projector_stride=1,
        vision_tower_config=None,
        image_token_id=13,
        image_seq_length=400,
        bos_token_id: Optional[Union[List[int], int]] = None,
        eos_token_id: Optional[Union[List[int], int]] = None,
        **kwargs,
    ) -> None:
        # The parent stores these fields itself, so they are not
        # re-assigned below. Passing them here makes the parent's checks
        # use the real values rather than its defaults.
        super().__init__(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_attention_groups=num_attention_groups,
            num_hidden_layers=num_hidden_layers,
            max_seq_len=max_seq_len,
            vocab_size=vocab_size,
            rms_norm_eps=rms_norm_eps,
            bos_token_id=1 if bos_token_id is None else bos_token_id,
            eos_token_id=[2, 3] if eos_token_id is None else eos_token_id,
            **kwargs,
        )

        # Multimodal options.
        self.use_im_start_end = use_im_start_end
        self.vision_select_layer = vision_select_layer
        self.image_token_len = image_token_len
        self.projector_stride = projector_stride
        self.image_token_id = image_token_id
        self.image_seq_length = image_seq_length

        self.vision_tower_config = (
            CLIPVisionConfig(**vision_tower_config)
            if vision_tower_config is not None else None)

        # NOTE(review): the "bfloat16" fallback only applies if the base
        # class leaves ``torch_dtype`` unset entirely - confirm against
        # the installed transformers version.
        self.text_config = Step1Config(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_attention_groups=num_attention_groups,
            num_hidden_layers=num_hidden_layers,
            max_seq_len=max_seq_len,
            vocab_size=vocab_size,
            rms_norm_eps=rms_norm_eps,
            architectures=["Step1ForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )
class MMGPTStep1ConfigV2(Step1Config):
    """Multimodal Step1 configuration registered as ``mmgpt_step1_v2``.

    The language-model dimensions are forwarded to the parent
    ``Step1Config`` initializer, which stores them and validates them.
    Previously the parent ran with its own defaults, so an
    ``alibi_slopes`` list passed through ``kwargs`` was checked against
    the default head count instead of ``num_attention_heads``.

    Nested configs built here:

    * ``vision_tower_config``: ``CLIPVisionConfig`` or ``None``.
    * ``sam_model_config``: ``SamViTConfig`` or ``None``.
    * ``text_config``: a ``Step1Config`` mirroring the language-model
      dimensions, tagged with the ``Step1ForCausalLM`` architecture.

    ``bos_token_id`` defaults to ``1`` and ``eos_token_id`` to ``[2, 3]``
    when left as ``None``.
    """

    # for step1.5c/step1u, models with both vit and sam encoders
    model_type = "mmgpt_step1_v2"

    def __init__(
        self,
        hidden_size: int = 5120,
        intermediate_size: int = 13312,
        num_attention_heads: int = 40,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 48,
        max_seq_len: int = 4096,
        vocab_size: int = 65536,
        rms_norm_eps: float = 1e-5,
        use_im_start_end=True,
        vision_select_layer=-1,
        image_token_len=None,
        understand_projector_stride=1,
        vit_scale=1.0,
        projector_bias=True,
        vision_tower_config=None,
        sam_model_config=None,
        image_token_id=13,
        bos_token_id: Optional[Union[List[int], int]] = None,
        eos_token_id: Optional[Union[List[int], int]] = None,
        **kwargs,
    ) -> None:
        # The parent stores these fields itself, so they are not
        # re-assigned below. Passing them here makes the parent's checks
        # use the real values rather than its defaults.
        super().__init__(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_attention_groups=num_attention_groups,
            num_hidden_layers=num_hidden_layers,
            max_seq_len=max_seq_len,
            vocab_size=vocab_size,
            rms_norm_eps=rms_norm_eps,
            bos_token_id=1 if bos_token_id is None else bos_token_id,
            eos_token_id=[2, 3] if eos_token_id is None else eos_token_id,
            **kwargs,
        )

        # Multimodal options.
        self.use_im_start_end = use_im_start_end
        self.vision_select_layer = vision_select_layer
        self.image_token_len = image_token_len
        self.image_token_id = image_token_id
        self.understand_projector_stride = understand_projector_stride
        self.vit_scale = vit_scale
        self.projector_bias = projector_bias

        # Optional encoder sub-configs, built only when a mapping is given.
        self.vision_tower_config = (
            CLIPVisionConfig(**vision_tower_config)
            if vision_tower_config is not None else None)
        self.sam_model_config = (
            SamViTConfig(**sam_model_config)
            if sam_model_config is not None else None)

        # NOTE(review): the "bfloat16" fallback only applies if the base
        # class leaves ``torch_dtype`` unset entirely - confirm against
        # the installed transformers version.
        self.text_config = Step1Config(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_attention_groups=num_attention_groups,
            num_hidden_layers=num_hidden_layers,
            max_seq_len=max_seq_len,
            vocab_size=vocab_size,
            rms_norm_eps=rms_norm_eps,
            architectures=["Step1ForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )
class Step1oConfig(Step1Config):
    """Multimodal Step1 configuration registered as ``step1o``.

    The language-model dimensions are forwarded to the parent
    ``Step1Config`` initializer, which stores them and validates them.
    Previously the parent ran with its own defaults, so an
    ``alibi_slopes`` list passed through ``kwargs`` was checked against
    the default head count instead of ``num_attention_heads``.

    ``patch_token_len`` falls back to ``image_token_len`` when it is
    ``None``. ``vision_tower_config`` becomes a ``CLIPVisionConfig`` or
    ``None``. ``text_config`` is a ``Step1Config`` mirroring the
    language-model dimensions, tagged with ``Step1ForCausalLM``.

    ``bos_token_id`` defaults to ``1`` and ``eos_token_id`` to ``[2, 3]``
    when left as ``None``.
    """

    # for step1o
    model_type = "step1o"

    def __init__(
        self,
        hidden_size: int = 5120,
        intermediate_size: int = 13312,
        num_attention_heads: int = 40,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 48,
        max_seq_len: int = 4096,
        vocab_size: int = 65536,
        rms_norm_eps: float = 1e-5,
        use_im_start_end=True,
        vision_select_layer=-1,
        image_token_len=None,
        image_token_id=13,
        understand_projector_stride=1,
        vit_scale=1.0,
        projector_bias=True,
        patch_token_len=None,
        vision_tower_config=None,
        bos_token_id: Optional[Union[List[int], int]] = None,
        eos_token_id: Optional[Union[List[int], int]] = None,
        **kwargs,
    ) -> None:
        # The parent stores these fields itself, so they are not
        # re-assigned below. Passing them here makes the parent's checks
        # use the real values rather than its defaults.
        super().__init__(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_attention_groups=num_attention_groups,
            num_hidden_layers=num_hidden_layers,
            max_seq_len=max_seq_len,
            vocab_size=vocab_size,
            rms_norm_eps=rms_norm_eps,
            bos_token_id=1 if bos_token_id is None else bos_token_id,
            eos_token_id=[2, 3] if eos_token_id is None else eos_token_id,
            **kwargs,
        )

        # Multimodal options.
        self.use_im_start_end = use_im_start_end
        self.vision_select_layer = vision_select_layer
        self.image_token_len = image_token_len
        self.image_token_id = image_token_id
        self.understand_projector_stride = understand_projector_stride
        self.vit_scale = vit_scale
        self.projector_bias = projector_bias
        # An unspecified patch length reuses the image token length.
        self.patch_token_len = (patch_token_len if patch_token_len
                                is not None else self.image_token_len)

        self.vision_tower_config = (
            CLIPVisionConfig(**vision_tower_config)
            if vision_tower_config is not None else None)

        # NOTE(review): the "bfloat16" fallback only applies if the base
        # class leaves ``torch_dtype`` unset entirely - confirm against
        # the installed transformers version.
        self.text_config = Step1Config(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_attention_groups=num_attention_groups,
            num_hidden_layers=num_hidden_layers,
            max_seq_len=max_seq_len,
            vocab_size=vocab_size,
            rms_norm_eps=rms_norm_eps,
            architectures=["Step1ForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )
class MMGPTQwen2Config(PretrainedConfig):
    """Multimodal Qwen2 configuration registered as ``mmgpt_qwen2``.

    The EOS ids ``151643`` and ``151646`` are always part of
    ``eos_token_id``. A caller-supplied list is merged with them through a
    ``set``; a caller-supplied scalar is appended after them.

    Nested configs built here:

    * ``vision_tower_config``: ``CLIPVisionConfig`` or ``None``.
    * ``sam_model_config``: ``SamViTConfig`` or ``None``.
    * ``text_config``: a ``Qwen2Config`` mirroring the language-model
      fields, tagged with the ``Qwen2ForCausalLM`` architecture.
    """

    # for step1.5t
    model_type = "mmgpt_qwen2"

    def __init__(
        self,
        vocab_size=64012,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=48,
        num_attention_heads=32,
        num_attention_groups=4,
        num_key_value_heads=4,
        hidden_act="silu",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        rope_theta=1000000.0,
        rope_scaling=None,
        use_im_start_end=True,
        vision_select_layer=-1,
        image_token_len=None,
        image_token_id=151656,
        understand_projector_stride=1,
        vit_scale=1.0,
        projector_bias=True,
        pad_token_id=-1,
        vision_tower_config=None,
        sam_model_config=None,
        eos_token_id=None,
        **kwargs,
    ) -> None:
        # Combine the built-in EOS ids with whatever the caller passed.
        builtin_eos = [151643, 151646]
        if eos_token_id is None:
            merged_eos = builtin_eos
        elif isinstance(eos_token_id, list):
            merged_eos = list(set(builtin_eos + eos_token_id))
        else:
            merged_eos = builtin_eos + [eos_token_id]

        super().__init__(eos_token_id=merged_eos, **kwargs)

        # Language-model fields.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        # Multimodal options.
        self.use_im_start_end = use_im_start_end
        self.vision_select_layer = vision_select_layer
        self.image_token_len = image_token_len
        self.image_token_id = image_token_id
        self.understand_projector_stride = understand_projector_stride
        self.vit_scale = vit_scale
        self.projector_bias = projector_bias
        self.pad_token_id = pad_token_id

        # Optional encoder sub-configs, built only when a mapping is given.
        if vision_tower_config is not None:
            self.vision_tower_config = CLIPVisionConfig(**vision_tower_config)
        else:
            self.vision_tower_config = None
        if sam_model_config is not None:
            self.sam_model_config = SamViTConfig(**sam_model_config)
        else:
            self.sam_model_config = None

        self.text_config = Qwen2Config(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            architectures=["Qwen2ForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )
class MMGPTQwen2ConfigV2(MMGPTQwen2Config):
    """Variant of ``MMGPTQwen2Config`` registered as ``mmgpt_qwen2_v2``.

    It differs from the parent in the default ``image_token_id``
    (``151675``) and in the EOS ids it always includes (``151643``,
    ``151645``, ``151665``).

    All attribute handling and sub-config construction is delegated to
    ``MMGPTQwen2Config.__init__``, called with the real arguments. The
    earlier version ran the parent with defaults (building a throwaway
    ``Qwen2Config``) and then re-assigned every field by hand. The final
    state is the same.
    """

    model_type = "mmgpt_qwen2_v2"

    def __init__(
        self,
        vocab_size=64012,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=48,
        num_attention_heads=32,
        num_attention_groups=4,
        num_key_value_heads=4,
        hidden_act="silu",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        rope_theta=1000000.0,
        rope_scaling=None,
        use_im_start_end=True,
        vision_select_layer=-1,
        image_token_len=None,
        image_token_id=151675,
        understand_projector_stride=1,
        vit_scale=1.0,
        projector_bias=True,
        pad_token_id=-1,
        vision_tower_config=None,
        sam_model_config=None,
        eos_token_id=None,
        **kwargs,
    ) -> None:
        # Combine this variant's built-in EOS ids with the caller's.
        builtin_eos = [151643, 151645, 151665]
        if eos_token_id is None:
            merged_eos = builtin_eos
        elif isinstance(eos_token_id, list):
            merged_eos = list(set(builtin_eos + eos_token_id))
        else:
            merged_eos = builtin_eos + [eos_token_id]

        # NOTE(review): the parent merges its own built-in ids (151643,
        # 151646) into this list again, so 151646 always ends up in
        # ``eos_token_id``. That matches the previous behaviour and is
        # kept here - confirm it is intended for the v2 vocabulary.
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_attention_groups=num_attention_groups,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            use_im_start_end=use_im_start_end,
            vision_select_layer=vision_select_layer,
            image_token_len=image_token_len,
            image_token_id=image_token_id,
            understand_projector_stride=understand_projector_stride,
            vit_scale=vit_scale,
            projector_bias=projector_bias,
            pad_token_id=pad_token_id,
            vision_tower_config=vision_tower_config,
            sam_model_config=sam_model_config,
            eos_token_id=merged_eos,
            **kwargs,
        )
class Step3vConfig(Step1Config):
    """Multimodal configuration registered as ``step3v``.

    Every named argument is stored as an attribute of the same name. This
    now includes ``moe_layers_enum``, which was previously passed only to
    ``text_config`` and was missing from this object.

    ``patch_token_len`` falls back to ``image_token_len`` when it is
    ``None``. ``vision_tower_config`` becomes a ``CLIPVisionConfig`` or
    ``None``. ``text_config`` is a ``Step2MiniConfig`` built from the
    language-model and MoE arguments, tagged with the
    ``Step2MiniForCausalLM`` architecture.

    ``bos_token_id`` defaults to ``0`` and ``eos_token_id`` to
    ``[1, 128805]`` when left as ``None``.
    """

    model_type = "step3v"

    def __init__(
        self,
        hidden_size: int = 5120,
        intermediate_size: int = 13312,
        num_attention_heads: int = 40,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 48,
        max_seq_len: int = 4096,
        vocab_size: int = 65536,
        rms_norm_eps: float = 1e-5,
        # 2 means 50% layers use MoE, interleaved with normal non-MoE
        # layers.
        moe_every_n_layer: int = 2,
        use_moe: bool = False,
        moe_intermediate_size: int = 10240,
        moe_num_experts: int = 16,
        moe_top_k: int = 4,
        max_pos_interp_ratio: float = 1,
        alibi_slopes: Optional[List[float]] = None,
        moe_layer_offset: int = 0,
        moe_dynamic_exp_p: float = 1.0,
        rope_theta: float = 500000,
        rope_scaling: Optional[dict[str, Any]] = None,
        head_dim: Optional[int] = None,
        max_position_embedding: int = 16384,
        share_expert_dim: Optional[int] = None,
        allgather_dtype: Optional[str] = None,
        share_q_dim: Optional[int] = None,
        norm_expert_weight: bool = True,
        moe_layers_enum: Optional[str] = None,
        use_im_start_end: bool = True,
        vision_select_layer: int = -1,
        image_token_len: Optional[int] = None,
        image_token_id: int = 128001,
        understand_projector_stride: int = 1,
        vit_scale: float = 1.0,
        projector_bias: bool = True,
        patch_token_len: Optional[int] = None,
        vision_tower_config: Optional[dict[str, Any]] = None,
        bos_token_id: Optional[Union[List[int], int]] = None,
        eos_token_id: Optional[Union[List[int], int]] = None,
        **kwargs,
    ) -> None:
        super().__init__(
            bos_token_id=0 if bos_token_id is None else bos_token_id,
            eos_token_id=([1, 128805]
                          if eos_token_id is None else eos_token_id),
            **kwargs,
        )

        # Language-model dimensions.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_hidden_layers = num_hidden_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps

        # Mixture-of-experts and positional options.
        self.moe_every_n_layer = moe_every_n_layer
        self.use_moe = use_moe
        self.moe_intermediate_size = moe_intermediate_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.max_pos_interp_ratio = max_pos_interp_ratio
        self.alibi_slopes = alibi_slopes
        self.moe_layer_offset = moe_layer_offset
        self.moe_dynamic_exp_p = moe_dynamic_exp_p
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.head_dim = head_dim
        self.max_position_embedding = max_position_embedding
        # NOTE(review): stored exactly as passed (possibly None), while
        # ``text_config`` applies StepConfig's fallback - confirm which
        # one consumers read.
        self.share_expert_dim = share_expert_dim
        self.allgather_dtype = allgather_dtype
        self.share_q_dim = share_q_dim
        self.norm_expert_weight = norm_expert_weight
        # Naming this parameter keeps it out of ``kwargs``, so the base
        # class never sees it. Store it here to keep it on the config.
        self.moe_layers_enum = moe_layers_enum

        # Multimodal options.
        self.use_im_start_end = use_im_start_end
        self.vision_select_layer = vision_select_layer
        self.image_token_len = image_token_len
        self.image_token_id = image_token_id
        self.understand_projector_stride = understand_projector_stride
        self.vit_scale = vit_scale
        self.projector_bias = projector_bias
        # An unspecified patch length reuses the image token length.
        self.patch_token_len = (patch_token_len if patch_token_len
                                is not None else self.image_token_len)

        self.vision_tower_config = (
            CLIPVisionConfig(**vision_tower_config)
            if vision_tower_config is not None else None)

        # NOTE(review): bos/eos ids are not forwarded, so ``text_config``
        # keeps StepConfig's own defaults - confirm this is intended.
        self.text_config = Step2MiniConfig(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_attention_groups=num_attention_groups,
            num_hidden_layers=num_hidden_layers,
            max_seq_len=max_seq_len,
            vocab_size=vocab_size,
            rms_norm_eps=rms_norm_eps,
            moe_every_n_layer=moe_every_n_layer,
            use_moe=use_moe,
            moe_intermediate_size=moe_intermediate_size,
            moe_num_experts=moe_num_experts,
            moe_top_k=moe_top_k,
            max_pos_interp_ratio=max_pos_interp_ratio,
            alibi_slopes=alibi_slopes,
            moe_layer_offset=moe_layer_offset,
            moe_dynamic_exp_p=moe_dynamic_exp_p,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            head_dim=head_dim,
            max_position_embedding=max_position_embedding,
            share_expert_dim=share_expert_dim,
            allgather_dtype=allgather_dtype,
            share_q_dim=share_q_dim,
            norm_expert_weight=norm_expert_weight,
            moe_layers_enum=moe_layers_enum,
            architectures=["Step2MiniForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )
\ No newline at end of file
vllm/transformers_utils/configs/step.py
0 → 100644
View file @
583034f1
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
from
transformers
import
PretrainedConfig
class StepConfig(PretrainedConfig):
    """Base configuration for the Step text models (``model_type="step"``).

    Named arguments are stored as attributes before the base-class
    initializer runs. ``share_expert_dim`` falls back to
    ``moe_intermediate_size * moe_top_k`` when it is ``None``.
    ``bos_token_id`` defaults to ``1`` and ``eos_token_id`` to ``[2, 3]``.

    Raises:
        ValueError: if ``alibi_slopes`` is given and its length differs
            from ``num_attention_heads``.
    """

    model_type = "step"

    def __init__(
        self,
        hidden_size: int = 5120,
        intermediate_size: int = 13312,
        num_attention_heads: int = 40,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 48,
        max_seq_len: int = 4096,
        vocab_size: int = 65536,
        rms_norm_eps: float = 1e-5,
        # 2 means 50% layers use MoE, interleaved with normal non-MoE
        # layers.
        moe_every_n_layer: int = 2,
        use_moe: bool = False,
        moe_intermediate_size: int = 10240,
        moe_num_experts: int = 16,
        moe_top_k: int = 4,
        max_pos_interp_ratio: float = 1,
        alibi_slopes: Optional[List[float]] = None,
        moe_layer_offset: int = 0,
        moe_dynamic_exp_p: float = 1.0,
        rope_theta: float = 500000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        head_dim: Optional[int] = None,
        max_position_embedding: int = 16384,
        share_expert_dim: Optional[int] = None,
        allgather_dtype: Optional[str] = None,
        share_q_dim: Optional[int] = None,
        norm_expert_weight: bool = True,
        bos_token_id: Optional[Union[List[int], int]] = None,
        eos_token_id: Optional[Union[List[int], int]] = None,
        **kwargs,
    ) -> None:
        # Language-model dimensions.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_hidden_layers = num_hidden_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps

        # Mixture-of-experts and positional-bias options.
        self.use_moe = use_moe
        self.moe_intermediate_size = moe_intermediate_size
        self.moe_every_n_layer = moe_every_n_layer
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.max_pos_interp_ratio = max_pos_interp_ratio
        self.alibi_slopes = alibi_slopes
        self.moe_layer_offset = moe_layer_offset
        self.moe_dynamic_exp_p = moe_dynamic_exp_p

        # for step2 mini
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.head_dim = head_dim
        self.max_position_embedding = max_position_embedding
        self.share_expert_dim = (
            self.moe_intermediate_size * self.moe_top_k
            if share_expert_dim is None else share_expert_dim)
        self.share_q_dim = share_q_dim
        self.norm_expert_weight = norm_expert_weight
        self.allgather_dtype = allgather_dtype

        # Validate before handing the remaining options to the base class.
        self._verify_slopes()

        resolved_bos = 1 if bos_token_id is None else bos_token_id
        resolved_eos = [2, 3] if eos_token_id is None else eos_token_id
        super().__init__(bos_token_id=resolved_bos,
                         eos_token_id=resolved_eos,
                         **kwargs)

    def _verify_slopes(self):
        """Check that ``alibi_slopes`` has one entry per attention head."""
        slopes = self.alibi_slopes
        if slopes is not None and len(slopes) != self.num_attention_heads:
            raise ValueError(
                f"Number of alibi_slopes ({len(slopes)}) does not match "
                f"num_attention_heads ({self.num_attention_heads})")
class Step1Config(StepConfig):
    """``StepConfig`` registered under the ``step1`` model type."""

    model_type = "step1"
class Step2Config(StepConfig):
    """``StepConfig`` registered as ``step2`` with one additional flag.

    ``use_offline_input_scales`` is stored on the instance before the
    remaining keywords are handed to ``StepConfig.__init__``.
    """

    model_type = "step2"

    def __init__(self, use_offline_input_scales: bool = True, **kwargs):
        # Record the step2-only flag, then let the parent do the rest.
        self.use_offline_input_scales = use_offline_input_scales
        super().__init__(**kwargs)
class Step2MiniConfig(StepConfig):
    """``StepConfig`` registered under the ``step2_mini`` model type."""

    model_type = "step2_mini"
\ No newline at end of file
vllm/transformers_utils/configs/step1f.py
0 → 100644
View file @
583034f1
# SPDX-License-Identifier: Apache-2.0
from
transformers
import
Qwen2Config
from
transformers.configuration_utils
import
PretrainedConfig
from
vllm.transformers_utils.configs.step
import
Step1Config
class Step1fAudioEncoderConfig(PretrainedConfig):
    """Settings container registered under ``stepasr_encoder``.

    Each named keyword is stored unchanged as an attribute of the same
    name. Any other keyword is passed on to ``PretrainedConfig``.
    """

    model_type = "stepasr_encoder"

    def __init__(
        self,
        n_mels: int = 128,
        n_audio_ctx: int = 1500,
        n_audio_state: int = 1280,
        n_audio_head: int = 20,
        n_audio_layer: int = 32,
        n_codebook_size: int = 4096,
        llm_dim: int = 3072,
        kernel_size: int = 3,
        adapter_stride: int = 2,
        adapter_state: int = 2048,
        **kwargs,
    ) -> None:
        # Generic options are consumed by the base class first.
        super().__init__(**kwargs)

        # Encoder fields.
        self.n_mels = n_mels
        self.n_audio_ctx = n_audio_ctx
        self.n_audio_state = n_audio_state
        self.n_audio_head = n_audio_head
        self.n_audio_layer = n_audio_layer
        self.n_codebook_size = n_codebook_size

        # Remaining fields (llm_dim, kernel and adapter settings).
        self.llm_dim = llm_dim
        self.kernel_size = kernel_size
        self.adapter_stride = adapter_stride
        self.adapter_state = adapter_state
class Step1AudioConfig(PretrainedConfig):
    """Audio + Step1 configuration registered as ``step1_audio``.

    The EOS ids ``2`` and ``3`` are always part of ``eos_token_id``. A
    caller-supplied list is merged with them through a ``set``; a
    caller-supplied scalar is appended after them.

    ``audio_encoder_config`` becomes a ``Step1fAudioEncoderConfig`` or
    ``None``. ``text_config`` is a ``Step1Config`` mirroring the
    language-model dimensions, tagged with ``Step1ForCausalLM``.
    """

    # for step1.5t
    model_type = "step1_audio"

    def __init__(
        self,
        hidden_size: int = 5120,
        intermediate_size: int = 13312,
        num_attention_heads: int = 40,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 48,
        max_seq_len: int = 4096,
        vocab_size: int = 65536,
        rms_norm_eps: float = 1e-5,
        audio_token_id: int = 29,
        eos_token_id=None,
        audio_encoder_config=None,
        **kwargs,
    ) -> None:
        # Combine the built-in EOS ids with whatever the caller passed.
        builtin_eos = [2, 3]
        if eos_token_id is None:
            merged_eos = builtin_eos
        elif isinstance(eos_token_id, list):
            merged_eos = list(set(builtin_eos + eos_token_id))
        else:
            merged_eos = builtin_eos + [eos_token_id]

        super().__init__(eos_token_id=merged_eos, **kwargs)

        # Language-model dimensions.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.max_seq_len = max_seq_len
        self.rms_norm_eps = rms_norm_eps

        # Audio options.
        self.audio_token_id = audio_token_id
        if audio_encoder_config is not None:
            self.audio_encoder_config = Step1fAudioEncoderConfig(
                **audio_encoder_config)
        else:
            self.audio_encoder_config = None

        self.text_config = Step1Config(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_attention_groups=num_attention_groups,
            num_hidden_layers=num_hidden_layers,
            max_seq_len=max_seq_len,
            vocab_size=vocab_size,
            rms_norm_eps=rms_norm_eps,
            architectures=["Step1ForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )
class StepAudioQwen2Config(PretrainedConfig):
    """Audio + Qwen2 configuration registered as ``step_audio_qwen2``.

    The EOS ids ``151643``, ``151645`` and ``151665`` are always part of
    ``eos_token_id``. A caller-supplied list is merged with them through a
    ``set``; a caller-supplied scalar is appended after them.

    ``audio_encoder_config`` becomes a ``Step1fAudioEncoderConfig`` or
    ``None``. ``text_config`` is a ``Qwen2Config`` mirroring the
    language-model fields, tagged with ``Qwen2ForCausalLM``.

    Raises:
        ValueError: if ``num_attention_groups`` differs from
            ``num_key_value_heads``. This used to be an ``assert``, which
            is skipped when Python runs with ``-O``.
    """

    model_type = "step_audio_qwen2"

    def __init__(self,
                 vocab_size=64012,
                 hidden_size=4096,
                 intermediate_size=11008,
                 num_hidden_layers=48,
                 num_attention_heads=32,
                 num_attention_groups=4,
                 num_key_value_heads=4,
                 hidden_act="silu",
                 max_position_embeddings=8192,
                 initializer_range=0.02,
                 rms_norm_eps=1e-6,
                 rope_theta=1000000.0,
                 rope_scaling=None,
                 audio_token_id=151690,
                 eos_token_id=None,
                 audio_encoder_config=None,
                 **kwargs):
        # Combine the built-in EOS ids with whatever the caller passed.
        if eos_token_id is not None:
            if isinstance(eos_token_id, list):
                eos_token_id = list(
                    set([151643, 151645, 151665] + eos_token_id))
            else:
                eos_token_id = [151643, 151645, 151665, eos_token_id]
        else:
            eos_token_id = [151643, 151645, 151665]

        super().__init__(eos_token_id=eos_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_key_value_heads = num_key_value_heads
        # Raise explicitly so the check also runs under ``python -O``.
        if self.num_attention_groups != self.num_key_value_heads:
            raise ValueError(
                "num_attention_groups must be equal to num_key_value_heads")
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.audio_encoder_config = (
            Step1fAudioEncoderConfig(**audio_encoder_config)
            if audio_encoder_config is not None else None)
        self.audio_token_id = audio_token_id

        self.text_config = Qwen2Config(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            architectures=["Qwen2ForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )
\ No newline at end of file
vllm/transformers_utils/detokenizer_utils.py
View file @
583034f1
...
@@ -4,6 +4,8 @@
...
@@ -4,6 +4,8 @@
from
typing
import
Optional
from
typing
import
Optional
from
.tokenizer
import
AnyTokenizer
from
.tokenizer
import
AnyTokenizer
# from vllm.transformers_utils.tokenizers.sentencepiece_tokenizer import (
# SentencePieceTokenizer)
def
_replace_none_with_empty
(
tokens
:
list
[
Optional
[
str
]]):
def
_replace_none_with_empty
(
tokens
:
list
[
Optional
[
str
]]):
...
@@ -171,6 +173,13 @@ def detokenize_incrementally(
...
@@ -171,6 +173,13 @@ def detokenize_incrementally(
# The prefix text is necessary only to defeat cleanup algorithms in
# The prefix text is necessary only to defeat cleanup algorithms in
# the decode which decide to add a space or not depending on the
# the decode which decide to add a space or not depending on the
# surrounding ids.
# surrounding ids.
# FIXME(ys): for step1 sentencepiece tokenizer, we need to handle the special tokens in convert_tokens_to_string
# if isinstance(tokenizer, SentencePieceTokenizer):
# prefix_text = tokenizer.convert_tokens_to_string(
# output_tokens[prefix_offset:read_offset], skip_special_tokens=skip_special_tokens)
# new_text = tokenizer.convert_tokens_to_string(
# output_tokens[prefix_offset:], skip_special_tokens=skip_special_tokens)
if
tokenizer
.
is_fast
or
not
tokenizer
.
get_added_vocab
():
if
tokenizer
.
is_fast
or
not
tokenizer
.
get_added_vocab
():
prefix_text
=
tokenizer
.
convert_tokens_to_string
(
prefix_text
=
tokenizer
.
convert_tokens_to_string
(
output_tokens
[
prefix_offset
:
read_offset
])
output_tokens
[
prefix_offset
:
read_offset
])
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment