Commit 583034f1 authored by zhuwenwen's avatar zhuwenwen
Browse files

[models] support step3v

parent 0adf9cda
......@@ -3418,6 +3418,8 @@ def _get_and_verify_max_len(
possible_keys = [
# OPT
"max_position_embeddings",
# step3
"max_position_embedding",
# GPT-2
"n_positions",
# MPT
......@@ -3490,8 +3492,14 @@ def _get_and_verify_max_len(
# No need to consider "type" key because of patch_rope_scaling when
# loading HF config
rope_type = rope_scaling["rope_type"]
if rope_type == "ntk_bypart":
derived_max_model_len = min(
derived_max_model_len,
rope_scaling["real_length"] * rope_scaling["scaling_factor"]
) if "real_length" in rope_scaling and "scaling_factor" in rope_scaling else derived_max_model_len
if rope_type not in ("su", "longrope", "llama3"):
elif rope_type not in ("su", "longrope", "llama3"):
if disable_sliding_window:
# TODO(robertgshaw): Find a model that supports rope_scaling
# with sliding window to see if this case should be allowed.
......@@ -3548,6 +3556,8 @@ def _get_and_verify_max_len(
logger.warning(
"%s Make sure the value is correct and within the "
"model context size.", msg)
if getattr(hf_config, "max_position_embedding", None) is not None: # step3/3v
hf_config.max_position_embedding = max_model_len
else:
raise ValueError(
f"{msg} To allow overriding this maximum, set "
......
......@@ -36,4 +36,7 @@ __all__ = [
"xLAMToolParser",
"MinimaxToolParser",
"Glm4MoeModelToolParser",
"Step1p5vMini2ToolParser",
"Step1p5vMini2MsToolParser",
"Step3ToolParser",
]
......@@ -3,6 +3,7 @@
"""Custom activation functions."""
import math
from typing import Optional
import optimus
import torch
import torch.nn as nn
......@@ -53,6 +54,14 @@ class FatreluAndMul(CustomOp):
return out
class OptimusSiluAndMul(nn.Module):
    """Activation module backed by the Optimus ``SiluDot_forward`` custom op."""

    def forward(self,
                x: torch.Tensor,
                output: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the Optimus kernel on ``x``.

        Args:
            x: Input tensor, handed to the kernel unchanged.
            output: Optional tensor forwarded as the kernel's ``out`` argument.

        Returns:
            Whatever ``torch.ops.Optimus.SiluDot_forward`` returns.
        """
        result = torch.ops.Optimus.SiluDot_forward(x, out=output)
        return result
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
"""An activation function for SwiGLU.
......
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom normalization layers."""
from typing import Optional, Union, Tuple
import optimus # noqa F401
import torch
import torch.nn as nn
......@@ -298,6 +299,49 @@ class RMSNorm(CustomOp):
return s
class OptimusRMSNorm(nn.Module):
def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
output: Optional[torch.Tensor] = None,
fp16_out: bool = False) -> torch.Tensor:
if residual is not None:
assert output is None
from vllm import _custom_ops as ops
assert not fp16_out
ops.fused_add_rms_norm(
x,
residual,
self.weight.data,
self.variance_epsilon,
)
return x, residual
else:
if fp16_out:
if output is None:
output = torch.empty_like(x).half()
else:
output = output.half()
# return torch.ops.Optimus.rms_norm(x,
# self.weight,
# self.variance_epsilon,
# out=output)
return torch.nn.functional.rms_norm(x,
self.weight,
self.variance_epsilon,
out=output)
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
"""RMS normalization for Gemma.
......@@ -363,3 +407,35 @@ class GemmaRMSNorm(CustomOp):
self.forward_static)
self._is_compiled = True
return self.forward_native(x, residual)
class OptimusLayerNorm(nn.Module):
    """Layer normalization with learnable weight and bias.

    Args:
        hidden_size: Size of the last (normalized) dimension.
        eps: Epsilon added to the variance for numerical stability.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self,
                x: torch.Tensor,
                residual: Optional[torch.Tensor] = None,
                output: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Normalize ``x`` over the dimensions covered by ``self.weight``.

        Args:
            x: Input tensor.
            residual: Must be ``None``; a residual path is not implemented.
            output: Optional buffer that receives the result.

        Returns:
            The normalized tensor (``output`` itself when one was supplied).
        """
        assert residual is None
        normed = torch.nn.functional.layer_norm(
            x,
            self.weight.shape,  # normalized_shape must match the weight shape
            self.weight,
            self.bias,
            eps=self.variance_epsilon,
        )
        if output is None:
            return normed
        # F.layer_norm has no `out=` keyword; fill the caller's buffer here so
        # it is not silently ignored.
        output.copy_(normed)
        return output
......@@ -3,7 +3,7 @@
import itertools
from abc import abstractmethod
from typing import Any, Literal, Optional, Union
from typing import Any, Literal, Optional, Union, List
import vllm.envs as envs
import torch
import torch.nn as nn
......@@ -269,6 +269,40 @@ class UnquantizedLinearMethod(LinearMethodBase):
return dispatch_unquantized_gemm()(x, layer.weight, bias)
class UnquantizedMoELinearMethod(LinearMethodBase):
    """Linear method for MoE layers whose expert weights are not quantized."""

    def __init__(self):
        self.quant_config = None

    def create_weights(self,
                       layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int],
                       input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype,
                       num_experts: Optional[int] = None,
                       **extra_weight_attrs):
        """Allocate one stacked weight and register it on ``layer`` as ``weight``.

        The tensor has shape ``(num_experts, sum(output_partition_sizes),
        input_size_per_partition)`` and is allocated on the current CUDA
        device. ``input_size`` and ``output_size`` are not used here.
        """
        stacked = torch.empty(num_experts,
                              sum(output_partition_sizes),
                              input_size_per_partition,
                              device=torch.cuda.current_device(),
                              dtype=params_dtype)
        weight = Parameter(stacked, requires_grad=False)
        # Dimension 0 is the expert axis, so output/input sit at dims 1/2.
        set_weight_attrs(weight, {"input_dim": 2, "output_dim": 1})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Not implemented for this method; always raises ``NotImplementedError``."""
        raise NotImplementedError
class LinearBase(torch.nn.Module):
"""Base linear layer.
......@@ -783,6 +817,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight)
if len(param.data_container) == 2:
self.qweight = param.materialize_nested()
return
param_data = param.data
......@@ -986,6 +1022,175 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset=shard_offset,
shard_size=shard_size)
class MergedColumnParallelMoELinear(MergedColumnParallelLinear):
    """Merged column-parallel linear layer holding one weight per expert.

    Args:
        num_experts: Number of experts, forwarded to ``create_weights``.
        input_size: Input dimension of the layer.
        output_sizes: Output dimension of each merged projection; every entry
            must be divisible by the tensor-parallel world size.
        params_dtype: Parameter dtype; defaults to ``torch.get_default_dtype()``.
        quant_config: Quantization config, or ``None`` for unquantized weights.
        prefix: Layer name prefix handed to ``quant_config.get_quant_method``.
    """

    def __init__(self,
                 num_experts: int,
                 input_size: int,
                 output_sizes: List[int],
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        # Only nn.Module is initialised; the parent linear constructors are
        # intentionally not called.
        torch.nn.Module.__init__(self)

        self.num_experts = num_experts
        self.output_sizes = output_sizes
        self.input_size = input_size
        self.output_size = sum(output_sizes)

        tp_size = get_tensor_model_parallel_world_size()
        assert all(size % tp_size == 0 for size in output_sizes)
        self.output_size_per_partition = divide(self.output_size, tp_size)
        self.output_partition_sizes = [
            divide(size, tp_size) for size in self.output_sizes
        ]
        self.gather_output = False

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        if quant_config is None:
            self.quant_method = UnquantizedMoELinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)
        # FIXME(ys): hack for moe
        if isinstance(self.quant_method, UnquantizedLinearMethod):
            self.quant_method = UnquantizedMoELinearMethod()
        assert self.quant_method is not None

        self.quant_method.create_weights(self,
                                         self.input_size,
                                         self.output_partition_sizes,
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         self.num_experts,
                                         weight_loader=self.weight_loader)
        self.register_parameter("bias", None)

    def forward(self,
                input_,
                output: Optional[torch.Tensor] = None,
                expert_idx: int = -1):
        """Apply the quantized projection of one expert.

        Returns ``None`` when the layer is unquantized: in that case the
        computation is expected to happen outside this module.
        """
        if isinstance(self.quant_method, UnquantizedMoELinearMethod):
            # use optimus moe_ffn outside
            return

        bias = None
        assert self.quant_method is not None
        output = self.quant_method.apply(self,
                                         input_,
                                         bias,
                                         expert_idx=expert_idx,
                                         output=output)
        return output
class QKVReplicatedLinear(ReplicatedLinear):
    """Fused Q/K/V projection sized from the total head counts.

    The output dimension is ``(num_heads + 2 * num_kv_heads) * head_size``; no
    tensor-parallel division is applied in this class.

    Args:
        hidden_size: Input dimension.
        head_size: Dimension of a single attention head.
        total_num_heads: Number of query heads.
        total_num_kv_heads: Number of key/value heads; falls back to
            ``total_num_heads`` when ``None`` (or 0).
        bias: Whether to allocate a bias parameter.
        skip_bias_add: Stored on the instance; not used in this class.
        params_dtype: Parameter dtype; defaults to ``torch.get_default_dtype()``.
        quant_config: Quantization config, or ``None`` for unquantized weights.
        prefix: Layer name prefix handed to ``quant_config.get_quant_method``.
        return_bias: Stored on the instance; not used in this class.
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 return_bias: bool = True):
        # Only nn.Module is initialised; the ReplicatedLinear constructor is
        # not called.
        nn.Module.__init__(self)
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.num_heads = total_num_heads
        self.num_kv_heads = total_num_kv_heads if total_num_kv_heads else total_num_heads

        self.input_size = self.hidden_size
        self.output_size = (self.num_heads +
                            2 * self.num_kv_heads) * self.head_size
        self.skip_bias_add = skip_bias_add
        self.return_bias = return_bias
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if quant_config is None:
            self.quant_method: Optional[
                QuantizeMethodBase] = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None) -> None:
        """Copy a checkpoint tensor into ``param``.

        Args:
            param: Destination parameter.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"`` to fill one slice of
                the fused parameter, or ``None`` when ``loaded_weight``
                already has the full fused shape.
        """
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)

        if loaded_shard_id is None:
            # Loaded weight is already packed.
            assert param_data.shape == loaded_weight.shape
            param_data.copy_(loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]
        if output_dim is not None:
            # The fused layout along the output axis is [q | k | v].
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads +
                                self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
            if not envs.VLLM_USE_NN or is_quantization:
                param_data = param_data.narrow(output_dim, shard_offset,
                                               shard_size)
            else:
                # Presumably VLLM_USE_NN stores unquantized weights transposed,
                # so the slice is taken on the other axis of a 2-D weight.
                # NOTE(review): for a 1-D bias (output_dim == 0) this narrows
                # dim 1 — confirm biases never reach this branch.
                param_data = param_data.narrow(int(not(output_dim)), shard_offset,
                                               shard_size)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVReplicatedLinear, assume the weight is the same "
                    "for all partitions.")
        # NOTE(review): the nesting of the lines below was reconstructed from
        # a flattened diff; they are assumed to run for both branches above.
        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
class QKVParallelLinear(ColumnParallelLinear):
"""Linear layers for the attention's QKV transformation.
......@@ -1185,6 +1390,8 @@ class QKVParallelLinear(ColumnParallelLinear):
param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight)
if len(param.data_container) == 3:
self.qweight = param.materialize_nested()
return
param_data = param.data
......@@ -1495,7 +1702,7 @@ class RowParallelLinear(LinearBase):
def forward(
self, input_,
use_fused_silu_mul_quant: Optional[bool] = False
use_fused_silu_mul_quant: Optional[bool] = False,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.input_is_parallel:
input_parallel = input_
......@@ -1757,4 +1964,63 @@ class QKVCrossParallelLinear(LinearBase):
s += f", bias={self.bias is not None}"
s += f", tp_size={get_tensor_model_parallel_world_size()}"
s += ", gather_output=False"
return s
\ No newline at end of file
return s
class RowParallelMoELinear(RowParallelLinear):
    """Row-parallel linear layer holding one weight per expert.

    The input dimension is divided by the tensor-parallel world size and
    ``reduce_results`` is fixed to ``False``.
    """

    def __init__(self,
                 num_experts: int,
                 input_size: int,
                 output_size: int,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        # Only nn.Module is initialised; the parent linear constructors are
        # intentionally not called.
        torch.nn.Module.__init__(self)

        self.num_experts = num_experts
        self.input_size = input_size
        self.output_size = output_size
        self.reduce_results = False
        self.params_dtype = (torch.get_default_dtype()
                             if params_dtype is None else params_dtype)

        method: Optional[QuantizeMethodBase]
        if quant_config is None:
            method = UnquantizedMoELinearMethod()
        else:
            method = quant_config.get_quant_method(self, prefix=prefix)
        self.quant_method = method
        # FIXME(ys): hack for moe
        if isinstance(self.quant_method, UnquantizedLinearMethod):
            self.quant_method = UnquantizedMoELinearMethod()

        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)

        assert self.quant_method is not None
        self.quant_method.create_weights(
            self,
            self.input_size_per_partition,
            [self.output_size],
            self.input_size,
            self.output_size,
            self.params_dtype,
            self.num_experts,
            weight_loader=self.weight_loader,
        )
        self.register_parameter("bias", None)

    def forward(  # type: ignore[override]
            self,
            input_,
            residual=None,
            expert_idx: int = -1,
            output: Optional[torch.Tensor] = None):
        """Apply the quantized projection of one expert.

        ``residual`` is accepted but not used. Returns ``None`` when the layer
        is unquantized: the computation then happens outside this module.
        """
        if isinstance(self.quant_method, UnquantizedMoELinearMethod):
            # use optimus moe_ffn outside
            return None

        assert self.quant_method is not None
        return self.quant_method.apply(self,
                                       input_,
                                       None,
                                       expert_idx=expert_idx,
                                       output=output)
\ No newline at end of file
......@@ -36,7 +36,8 @@ class LogitsProcessor(nn.Module):
org_vocab_size: Optional[int] = None,
scale: float = 1.0,
logits_as_input: bool = False,
soft_cap: Optional[float] = None) -> None:
soft_cap: Optional[float] = None,
need_fp32_logits: bool = False) -> None:
"""
Args:
scale: A scaling factor to apply to the logits.
......@@ -52,6 +53,7 @@ class LogitsProcessor(nn.Module):
self.soft_cap = soft_cap
# Whether to use gather or all-gather to gather the logits.
self.use_all_gather = current_platform.use_all_gather()
self.need_fp32_logits = need_fp32_logits
def forward(
self,
......@@ -106,9 +108,13 @@ class LogitsProcessor(nn.Module):
embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
# Get the logits for the next tokens.
logits = lm_head.quant_method.apply(lm_head,
hidden_states,
bias=embedding_bias)
if self.need_fp32_logits:
logits = torch.ops.OptimusMoe.matmul_fp32(hidden_states,
lm_head.weight.t())
else:
logits = lm_head.quant_method.apply(lm_head,
hidden_states,
bias=embedding_bias)
# Gather logits for TP
logits = self._gather_logits(logits)
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.optimus_moe import ( # noqa: F401
optimus_moe_int8)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import direct_register_custom_op
class GroupwiseQuantConfig(QuantizationConfig):
    """Config class for Groupwise Quantization.

    Args:
        weight_bits: Bit width of the quantized weights.
        group_size: Number of input channels sharing one scale.
        symmetric: Stored on the instance; not read in this class.
        bf16_blocks: Block indices whose linear layers stay unquantized.
        int8_blocks: Stored on the instance; not read in this class.
        weight_dtype: Explicit weight dtype name. When omitted it is derived
            from ``weight_bits`` (8 -> "int8", 6 -> "fp6", 4 -> "int4").
        extra_quant_configs: Per-layer overrides; iterated as a sequence of
            dicts with "block_indices", "target_modules", "weight_bit",
            "group_size" and "weight_type" keys.
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        symmetric: bool = False,
        bf16_blocks: Optional[list] = None,
        int8_blocks: Optional[list] = None,
        weight_dtype: Optional[
            str] = None,  # Literal["int8", "fp8_e4m3", "fp6", "int4"] = "int8",
        # FIXME: hack for mixed precision quantization
        extra_quant_configs: Optional[dict] = None,
    ) -> None:
        # NOTE(review): the base-class __init__ is not called here — confirm
        # QuantizationConfig does not rely on state it would set up.
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.symmetric = symmetric
        self.bf16_blocks = bf16_blocks or []
        self.int8_blocks = int8_blocks or []
        self.extra_quant_configs = extra_quant_configs

        # Only 4-bit weights are packed (into 32-bit words).
        self.pack_factor = 32 // self.weight_bits if self.weight_bits == 4 else 1

        if weight_dtype:
            self.weight_dtype = weight_dtype
            return

        # fp8e4m3 must be requested explicitly; 8 bits defaults to int8.
        default_dtypes = {8: "int8", 6: "fp6", 4: "int4"}
        if weight_bits not in default_dtypes:
            raise ValueError(f"Unsupported weight bits: {weight_bits}")
        self.weight_dtype = default_dtypes[weight_bits]

    def __repr__(self) -> str:
        return (f"GroupwiseQuantConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size})")

    @classmethod
    def get_name(cls) -> str:
        return "groupwise_quant"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        # The Groupwise Quant kernel only supports Ampere or newer GPUs.
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quant_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GroupwiseQuantConfig":
        """Build a config from a parsed quantization-config dict."""
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        optional_fields = {
            "bf16_blocks": config.get("bf16_blocks", []),
            "int8_blocks": config.get("int8_blocks", []),
            "weight_dtype": config.get("weight_type"),
            "extra_quant_configs": config.get("extra_quant_configs", {}),
        }
        return cls(weight_bits, group_size, **optional_fields)

    def get_quant_method(self,
                         layer: torch.nn.Module,
                         prefix: str = "") -> Optional["QuantizeMethodBase"]:
        """Pick the quantization method for ``layer``.

        Linear layers are resolved from the block index parsed out of
        ``prefix`` (the text after "layers.") and the last dotted component of
        ``prefix``. FusedMoE layers only support 8-bit weights. Any other
        layer type yields ``None``.
        """
        if isinstance(layer, LinearBase):
            block_index = int(prefix.split("layers.")[1].split(".")[0])
            layer_name = prefix.split(".")[-1]

            if block_index in self.bf16_blocks:
                return UnquantizedLinearMethod()

            if not self.extra_quant_configs:
                # Compatible with old config
                return GroupwiseQuantLinearMethod(self)

            for override in self.extra_quant_configs:
                in_block = block_index in override["block_indices"]
                if in_block and layer_name in override["target_modules"]:
                    layer_config = GroupwiseQuantConfig(
                        weight_bits=override["weight_bit"],
                        group_size=override["group_size"] or self.group_size,
                        weight_dtype=override["weight_type"],
                        extra_quant_configs=self.extra_quant_configs)
                    return GroupwiseQuantLinearMethod(layer_config)
            # no specific config for this layer means no quantization
            return UnquantizedLinearMethod()

        if isinstance(layer, FusedMoE):
            # For MoE layers, only support 8-bit quantization
            if self.weight_bits != 8:
                raise ValueError(f"Unsupported weight bits for MoE: {self.weight_bits}. Only 8-bit is supported.")
            return GroupwiseInt8MoeMethod(self)

        return None
class GroupwiseQuantLinearMethod(LinearMethodBase):
    """Linear method for GroupwiseQuant.

    Handles 4-bit (int32-packed), 8-bit (int8) and 6-bit (fp6) weights, both
    for plain linear layers and for layers that stack one weight per expert.

    Args:
        quant_config: The groupwise_quant quantization config.
    """

    def __init__(self, quant_config: GroupwiseQuantConfig) -> None:
        self.quant_config = quant_config
        self.sm = torch.cuda.get_device_capability()

    def create_weights(self,
                       layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int],
                       input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype,
                       num_experts: Optional[int] = None,
                       **extra_weight_attrs):
        """Allocate and register the quantized parameters on ``layer``.

        Registers ``qweight`` and ``scales`` (plus ``zeros`` for 4-bit). When
        ``num_experts`` is truthy every tensor gets a leading expert
        dimension and the ``input_dim``/``output_dim`` attributes shift by
        one. ``input_size`` and ``output_size`` are not used here.

        Raises:
            NotImplementedError: If ``weight_bits`` is not 4, 6 or 8.
        """
        cfg = self.quant_config
        assert input_size_per_partition % cfg.group_size == 0
        output_size_per_partition = sum(output_partition_sizes)
        assert output_size_per_partition % cfg.pack_factor == 0

        # `process_weights_after_loading` and `apply` read `layer.num_experts`
        # unconditionally, so layers without it get a `None` placeholder.
        if not any("num_experts" in name for name in dir(layer)):
            layer.register_parameter("num_experts", None)

        bits = cfg.weight_bits
        if bits not in (4, 6, 8):
            raise NotImplementedError

        lead = [num_experts] if num_experts else []
        row = len(lead)  # first non-expert dimension
        col = row + 1
        scale_shape = lead + [
            input_size_per_partition // cfg.group_size,
            output_size_per_partition,
        ]
        scale_dims = {"input_dim": row, "output_dim": col}

        def frozen(tensor: torch.Tensor) -> Parameter:
            return Parameter(tensor, requires_grad=False)

        if bits == 4:
            # Output channels are packed `pack_factor` per int32 word.
            weight_shape = lead + [
                input_size_per_partition,
                output_size_per_partition // cfg.pack_factor,
            ]
            named = [
                ("qweight",
                 frozen(torch.empty(*weight_shape, dtype=torch.int32)), {
                     "input_dim": row,
                     "output_dim": col,
                     "packed_dim": col,
                     "pack_factor": cfg.pack_factor,
                 }),
                ("scales",
                 frozen(torch.empty(*scale_shape, dtype=params_dtype)),
                 scale_dims),
                ("zeros",
                 frozen(torch.empty(*scale_shape, dtype=params_dtype)),
                 scale_dims),
            ]
        elif bits == 8:
            weight_shape = lead + [
                input_size_per_partition, output_size_per_partition
            ]
            named = [
                ("qweight",
                 frozen(
                     torch.empty(*weight_shape,
                                 device="cuda",
                                 dtype=torch.int8)),
                 {"input_dim": row, "output_dim": col}),
                ("scales",
                 frozen(
                     torch.empty(*scale_shape,
                                 device="cuda",
                                 dtype=params_dtype)), scale_dims),
            ]
        else:  # bits == 6
            # fp6 weights are laid out (output, input), unlike the other
            # formats.
            weight_shape = lead + [
                output_size_per_partition, input_size_per_partition
            ]
            named = [
                ("qweight",
                 frozen(
                     torch.zeros(
                         *weight_shape,
                         # hack: fp6 weight is stored in float16, keep it on
                         # the CPU to avoid cuda oom
                         device="cpu",
                         dtype=torch.float16)),
                 {"input_dim": col, "output_dim": row}),
                ("scales",
                 frozen(
                     torch.empty(*scale_shape,
                                 device="cuda",
                                 dtype=torch.float16)), scale_dims),
            ]

        for _, param, dims in named:
            set_weight_attrs(param, dims)
        for name, param, _ in named:
            layer.register_parameter(name, param)
            set_weight_attrs(param, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded weights into the layout the kernels consume.

        Does nothing for layers without a ``qweight`` attribute.

        Raises:
            NotImplementedError: If ``weight_bits`` is not 4, 6 or 8.
        """
        if not hasattr(layer, "qweight"):
            return
        bits = self.quant_config.weight_bits
        if bits == 4:
            self._process_int4(layer)
        elif bits == 8:
            self._process_int8(layer)
        elif bits == 6:
            self._process_fp6(layer)
        else:
            raise NotImplementedError

    def _process_int4(self, layer: torch.nn.Module) -> None:
        """Repack int4 weights in place and replace scales/zeros by ``qscales``."""
        num_experts = layer.num_experts
        qweight = layer.qweight
        zeros = layer.zeros
        scales = layer.scales
        group_size = self.quant_config.group_size
        if num_experts:
            per_expert = []
            for i in range(num_experts):
                repacked, expert_scales = (
                    torch.ops.Optimus.GemmInt4GroupQuantWeight(
                        qweight[i], zeros[i], scales[i] + zeros[i],
                        group_size))
                qweight[i].copy_(repacked)
                per_expert.append(expert_scales)
            qscales = torch.stack(per_expert)
        else:
            repacked, qscales = torch.ops.Optimus.GemmInt4GroupQuantWeight(
                qweight, zeros, scales + zeros, group_size)
            qweight.copy_(repacked)
        layer.register_parameter("qscales",
                                 Parameter(qscales, requires_grad=False))
        # The raw zeros/scales are folded into `qscales` above.
        layer._parameters.pop("zeros")
        layer._parameters.pop("scales")

    def _process_int8(self, layer: torch.nn.Module) -> None:
        """Run the Optimus int8 weight preprocessing in place."""
        num_experts = layer.num_experts
        qweight = layer.qweight
        if num_experts:
            for i in range(num_experts):
                qweight[i].copy_(
                    torch.ops.Optimus.FpAIntBPreprocessWeightGPU(
                        qweight[i].t().contiguous(), torch.int8))
        else:
            qweight.copy_(
                torch.ops.Optimus.FpAIntBPreprocessWeightGPU(
                    qweight.t().contiguous(), torch.int8))

    def _process_fp6(self, layer: torch.nn.Module) -> None:
        """Replace the float16 staging weight by the packed fp6 weight on CUDA."""
        num_experts = layer.num_experts
        qweight = layer.qweight
        layer._parameters.pop("qweight")
        assert qweight.shape[-1] % 8 == 0
        if num_experts:
            # 6 bits per element -> 6/8 bytes along the last dimension.
            packed = torch.empty(qweight.shape[0],
                                 qweight.shape[1],
                                 qweight.shape[2] * 6 // 8,
                                 dtype=torch.uint8,
                                 device="cuda")
            for i in range(num_experts):
                packed[i] = torch.ops.Optimus.fp6_preprocess_weight(
                    qweight[i].cpu()).cuda()
        else:
            # The kernel output is used directly; no staging buffer is needed.
            packed = torch.ops.Optimus.fp6_preprocess_weight(
                qweight.cpu()).cuda()
        layer.register_parameter("qweight",
                                 Parameter(packed, requires_grad=False))

    @staticmethod
    def _pick_expert(layer: torch.nn.Module, expert_idx: Optional[int],
                     *tensors: torch.Tensor):
        """Return ``tensors`` indexed by ``expert_idx`` for per-expert layers."""
        if layer.num_experts:
            assert expert_idx is not None, "expert_idx is None"
            return tuple(t[expert_idx] for t in tensors)
        return tensors

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None,
              residual: Optional[torch.Tensor] = None,
              output: Optional[torch.Tensor] = None,
              expert_idx: Optional[int] = None) -> torch.Tensor:
        """Run the quantized GEMM for ``layer``.

        Args:
            layer: Layer holding the processed quantized parameters.
            x: Input activations.
            bias: Optional bias.
            residual: Optional tensor added to the result.
            output: Optional tensor forwarded as the kernel's ``out`` argument.
            expert_idx: Expert to use; required when ``layer.num_experts`` is
                truthy.

        Raises:
            NotImplementedError: If ``weight_bits`` is not 4, 6 or 8.
        """
        bits = self.quant_config.weight_bits
        if bits == 4:
            qweight, qscales = self._pick_expert(layer, expert_idx,
                                                 layer.qweight, layer.qscales)
            # The wrapper's signature is (x, qweight, qscales, bias, out); it
            # supplies the kernel's extra `None` argument itself, so passing
            # one here would bind `out` twice.
            out = torch.ops.vllm.optimus_gemm_int4_group(x,
                                                         qweight,
                                                         qscales,
                                                         bias,
                                                         out=output)
            if residual is not None:
                out += residual
            return out

        if bits == 8:
            qweight, scales = self._pick_expert(layer, expert_idx,
                                                layer.qweight, layer.scales)
            if residual is not None:
                # The residual goes through the kernel's bias/residual slot
                # and is also the output buffer; the bias is added afterwards.
                assert output is None or output is residual
                out = torch.ops.vllm.optimus_fp_aintb_gemm(
                    x,
                    qweight,
                    torch.int8,  # weight dtype argument
                    scales,
                    residual,
                    out=residual)
                if bias is not None:
                    out += bias
            else:
                out = torch.ops.vllm.optimus_fp_aintb_gemm(
                    x,
                    qweight,
                    torch.int8,  # weight dtype argument
                    scales,
                    bias,
                    out=output)
            return out

        if bits == 6:
            qweight, scales = self._pick_expert(layer, expert_idx,
                                                layer.qweight, layer.scales)
            if x.dtype != torch.bfloat16:
                if output is None:
                    output = torch.empty(x.shape[0],
                                         qweight.shape[0],
                                         device=x.device,
                                         dtype=torch.bfloat16)
                else:
                    output = output.to(torch.bfloat16)
            out = torch.ops.vllm.optimus_fp6_linear(
                x,
                qweight,
                scales,
                4,  # fp6_format_code
                out=output)
            if bias is not None:
                out += bias
            if residual is not None:
                out += residual
            return out

        raise NotImplementedError
class GroupwiseInt8MoeMethod(FusedMoEMethodBase):
    """MoE method for Groupwise INT8 quantization.

    Args:
        quant_config: The groupwise quantization config.
    """

    def __init__(self, quant_config: GroupwiseQuantConfig):
        self.quant_config = quant_config
        assert self.quant_config.weight_bits == 8, "Only 8-bit quantization is supported for GroupwiseInt8MoeMethod"

    def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int,
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register the int8 expert weights and their groupwise scales.

        Four parameters are registered on ``layer``: ``w13_weight``,
        ``w2_weight``, ``w13_weight_scale`` and ``w2_weight_scale``. Each has
        the expert axis first, the input axis at dim 1 and the output axis at
        dim 2.
        """
        extra_weight_attrs.update({
            "is_transposed": True,
            "quant_method": FusedMoeWeightScaleSupported.GROUP.value,
        })

        group_size = self.quant_config.group_size
        fused_width = 2 * intermediate_size_per_partition
        specs = (
            # INT8 weights
            ("w13_weight", (num_experts, hidden_size, fused_width),
             torch.int8),
            ("w2_weight",
             (num_experts, intermediate_size_per_partition, hidden_size),
             torch.int8),
            # Scales for groupwise quantization
            ("w13_weight_scale",
             (num_experts, hidden_size // group_size, fused_width),
             params_dtype),
            ("w2_weight_scale",
             (num_experts, intermediate_size_per_partition // group_size,
              hidden_size), params_dtype),
        )
        for name, shape, dtype in specs:
            tensor = torch.nn.Parameter(torch.empty(*shape, dtype=dtype),
                                        requires_grad=False)
            layer.register_parameter(name, tensor)
            set_weight_attrs(tensor, {"input_dim": 1, "output_dim": 2})
            set_weight_attrs(tensor, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Run the Optimus int8 preprocessing on every expert weight in place."""
        # Same preprocessing as the 8-bit case of GroupwiseQuantLinearMethod.
        expert_count = layer.w13_weight.shape[0]
        for expert in range(expert_count):
            # w13 (gate and up combined) first, then w2 (down).
            for stacked in (layer.w13_weight, layer.w2_weight):
                stacked[expert].copy_(
                    torch.ops.Optimus.FpAIntBPreprocessWeightGPU(
                        stacked[expert].t().contiguous(), torch.int8))

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:
        """Dispatch to the ``optimus_moe_int8`` custom op.

        Only ``x``, ``router_logits``, ``top_k``, ``renormalize``,
        ``global_num_experts`` and ``activation`` are forwarded; the remaining
        routing arguments are accepted but not used.
        """
        moe_kernel = torch.ops.vllm.optimus_moe_int8
        return moe_kernel(
            hidden_states=x,
            router_logits=router_logits,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            top_k=top_k,
            global_num_experts=global_num_experts,
            norm_expert_weight=renormalize,
            activation=activation,
        )
# Wrapper and Fake Functions for Optimus::GemmInt4Group
def optimus_gemm_int4_group(x: torch.Tensor, qweight: torch.Tensor,
                            qscales: torch.Tensor,
                            bias: Optional[torch.Tensor],
                            out: Optional[torch.Tensor]) -> torch.Tensor:
    """Call ``Optimus.GemmInt4Group``, fixing its fifth positional argument to ``None``."""
    result = torch.ops.Optimus.GemmInt4Group(x,
                                             qweight,
                                             qscales,
                                             bias,
                                             None,
                                             out=out)
    return result
def optimus_gemm_int4_group_fake(x: torch.Tensor, qweight: torch.Tensor,
                                 qscales: torch.Tensor,
                                 bias: Optional[torch.Tensor],
                                 out: Optional[torch.Tensor]) -> torch.Tensor:
    """Fake implementation of ``optimus_gemm_int4_group``.

    Returns a new uninitialized tensor shaped like ``x`` with its last
    dimension replaced by ``qscales.shape[-1]``, using ``x``'s dtype and
    device. ``out`` is never written or returned.
    """
    batch_dims = list(x.shape[:-1])
    return torch.empty(batch_dims + [qscales.shape[-1]],
                       dtype=x.dtype,
                       device=x.device)
# Wrapper and Fake Functions for Optimus::FpAIntBGemm
def optimus_fp_aintb_gemm(x: torch.Tensor, qweight: torch.Tensor,
                          dtype_arg: torch.dtype, scales: torch.Tensor,
                          bias_or_residual: Optional[torch.Tensor],
                          out: Optional[torch.Tensor]) -> torch.Tensor:
    """Forward the arguments to ``torch.ops.Optimus.FpAIntBGemm``.

    The op's sixth positional argument is fixed to the string
    ``"identity"`` by this wrapper.
    """
    gemm_op = torch.ops.Optimus.FpAIntBGemm
    return gemm_op(x,
                   qweight,
                   dtype_arg,
                   scales,
                   bias_or_residual,
                   "identity",
                   out=out)
def optimus_fp_aintb_gemm_fake(x: torch.Tensor, qweight: torch.Tensor,
                               dtype_arg: torch.dtype, scales: torch.Tensor,
                               bias_or_residual: Optional[torch.Tensor],
                               out: Optional[torch.Tensor]) -> torch.Tensor:
    """Fake (shape-only) implementation of ``optimus_fp_aintb_gemm``.

    Returns an uninitialised tensor that keeps every leading dimension of
    ``x`` and takes its last dimension from ``qweight.shape[-1]``, with the
    dtype and device of ``x``. ``out`` does not influence the result; the
    previous ``if out is not None`` branch duplicated the fall-through path
    and has been removed.
    """
    output_shape = [*x.shape[:-1], qweight.shape[-1]]
    return torch.empty(output_shape, dtype=x.dtype, device=x.device)
# Wrapper and Fake Functions for Optimus::fp6_linear
def optimus_fp6_linear(x: torch.Tensor, qweight: torch.Tensor,
                       scales: torch.Tensor, fp6_format_code: int,
                       out: Optional[torch.Tensor]) -> torch.Tensor:
    """Forward the arguments to ``torch.ops.Optimus.fp6_linear``."""
    linear_op = torch.ops.Optimus.fp6_linear
    return linear_op(x, qweight, scales, fp6_format_code, out=out)
def optimus_fp6_linear_fake(x: torch.Tensor, qweight: torch.Tensor,
                            scales: torch.Tensor, fp6_format_code: int,
                            out: Optional[torch.Tensor]) -> torch.Tensor:
    """Fake (shape-only) implementation of ``optimus_fp6_linear``.

    Returns an uninitialised bfloat16 tensor on ``x``'s device that keeps
    every leading dimension of ``x`` and takes its last dimension from
    ``scales.shape[-1]``.

    The earlier dtype selection ("use ``x.dtype`` unless it is not
    bfloat16, then use bfloat16") always evaluated to bfloat16, and the
    ``if out is not None`` branch duplicated the fall-through path; both
    are now stated directly.
    """
    output_shape = [*x.shape[:-1], scales.shape[-1]]
    return torch.empty(output_shape, dtype=torch.bfloat16, device=x.device)
# Register each Optimus wrapper together with its fake implementation and
# the arguments the real op may write into.
for _op_name, _op_func, _mutated_args, _fake_impl in (
    ("optimus_gemm_int4_group", optimus_gemm_int4_group, ["out"],
     optimus_gemm_int4_group_fake),
    ("optimus_fp_aintb_gemm", optimus_fp_aintb_gemm,
     ["out", "bias_or_residual"], optimus_fp_aintb_gemm_fake),
    ("optimus_fp6_linear", optimus_fp6_linear, ["out"],
     optimus_fp6_linear_fake),
):
    direct_register_custom_op(
        op_name=_op_name,
        op_func=_op_func,
        mutates_args=_mutated_args,
        fake_impl=_fake_impl,
    )
\ No newline at end of file
import torch
@torch.jit.script
def cal_scale(amax, fp_max, scale):
    """Derive a power-of-two quantisation scale from an absolute maximum.

    The exponent is ``floor(log2(fp_max / amax))``. The scale is ``2**exp``
    when the exponent is non-negative and ``1 / 2**|exp|`` otherwise. Where
    ``amax`` is not strictly positive or not finite, ``scale`` is used in
    place of the power of two.

    Returns ``(scale, 1 / scale)``.
    """
    exponent = torch.floor(torch.log2(fp_max / amax))
    pow2 = torch.round(torch.pow(2, torch.abs(exponent)))
    # Fall back to the supplied scale for zero, negative, NaN or inf amax.
    amax_usable = torch.logical_and(amax > 0.0, torch.isfinite(amax))
    pow2 = torch.where(amax_usable, pow2, scale)
    new_scale = torch.where(exponent < 0, 1 / pow2, pow2)
    return new_scale, torch.reciprocal(new_scale)
# Cache of objects created through `singleton`-decorated classes.
instances = {}


def singleton(cls):
    """Class decorator that caches instances per constructor arguments.

    The decorated name becomes a factory function. Calling it again with
    the same arguments returns the cached object. Calling it with different
    arguments builds and caches a separate object; previously the first
    object was returned regardless of the arguments, so for example a
    helper built for one device was silently reused for another.

    If the arguments are not hashable, the cache falls back to a single
    instance per class (the earlier behaviour).
    """

    def get_instance(*args, **kwargs):
        key = (cls, args, tuple(sorted(kwargs.items())))
        try:
            hash(key)
        except TypeError:
            # Unhashable arguments cannot be part of a dict key.
            key = cls
        # `instances` is looked up at call time so that a rebinding done by
        # a reset of the module-level cache is honoured.
        if key not in instances:
            instances[key] = cls(*args, **kwargs)
        return instances[key]

    return get_instance
def reset_singleton():
    """Discard all cached instances by rebinding `instances` to a new dict."""
    global instances
    instances = dict()
@singleton
class QuantFp8:
    """Per-tensor FP8 (e4m3fn) quantisation helper with power-of-two scales."""

    def __init__(self, device):
        # 448.0 is used as the maximum magnitude when deriving scales.
        self.fp_max = torch.tensor([448.0], device=device)
        self.device = device
        # Scale used by `cal_scale` when amax is not usable.
        self.scale = torch.tensor([1.0], device=self.device)

    @staticmethod
    def quantize_v1(weight, bits):
        """Quantise `weight` using plain torch ops only.

        Returns ``(qweight, scale_inv)`` where ``qweight`` is
        ``weight * scale`` cast to ``float8_e4m3fn``.

        Raises:
            ValueError: if `bits` is not 8.
        """
        if bits != 8:
            raise ValueError(f"Unsupported bit width: {bits}")

        target = weight.device
        amax = weight.abs().max()
        fp_max = torch.tensor([448.0]).to(target)
        fallback = torch.tensor([1.0]).to(target)

        exponent = torch.floor(torch.log2(fp_max / amax))
        pow2 = torch.round(torch.pow(2, torch.abs(exponent)))
        # Keep the neutral scale for non-positive or non-finite amax.
        pow2 = torch.where(amax > 0.0, pow2, fallback)
        pow2 = torch.where(torch.isfinite(amax), pow2, fallback)
        scale = torch.where(exponent < 0, 1 / pow2, pow2)

        qweight = (weight.to(torch.float32) * scale).to(torch.float8_e4m3fn)
        return qweight, torch.reciprocal(scale)

    def quantize(self, weight, bits, weight_scale, use_offline_input_scales):
        """Quantise `weight` with the ``OptimusFp8`` custom ops.

        When `weight_scale` is given and `use_offline_input_scales` is
        true, that scale is used as-is; otherwise the scale is derived from
        the tensor's absolute maximum via `cal_scale`.

        Returns ``(qweight, scale_inv)``.

        Raises:
            ValueError: if `bits` is not 8.
        """
        if bits != 8:
            raise ValueError(f"Unsupported bit width: {bits}")

        amax = torch.empty(1, dtype=torch.float32, device=self.device)
        scale = torch.tensor([1.0], device=self.device)
        # The op writes its result into `amax`.
        torch.ops.OptimusFp8.abs_max_nan_to_inf(weight, amax)

        if weight_scale is not None and use_offline_input_scales:
            scale, scale_inv = weight_scale, torch.reciprocal(weight_scale)
        else:
            scale, scale_inv = cal_scale(amax, self.fp_max, scale)

        qweight = torch.ops.OptimusFp8.quantize(weight, scale, None,
                                                torch.float8_e4m3fn)
        return qweight, scale_inv

    def get_quant_scale(self, tensor):
        """Return the scale `cal_scale` derives from `tensor`'s abs-max."""
        amax = torch.empty(1, dtype=torch.float32, device=tensor.device)
        torch.ops.OptimusFp8.abs_max_nan_to_inf(tensor, amax)
        derived_scale, _ = cal_scale(amax, self.fp_max, self.scale)
        return derived_scale
def quantize(weight, bits, weight_scale=None, use_offline_input_scales=True):
    """Quantise `weight` through the `QuantFp8` helper for its device.

    Returns whatever `QuantFp8.quantize` returns for the same arguments.
    """
    quantizer = QuantFp8(weight.device)
    return quantizer.quantize(
        weight,
        bits,
        weight_scale,
        use_offline_input_scales,
    )
def dequant(weight, weight_scales):
    """Dequantise `weight` to bfloat16 via ``torch.ops.OptimusFp8.dequantize``."""
    target_dtype = torch.bfloat16
    return torch.ops.OptimusFp8.dequantize(weight, weight_scales, target_dtype)
def experts_dequant(weights, weight_scales):
    """Dequantise a stack of expert weights one expert at a time.

    Index ``i`` along the first dimension of `weights` is dequantised with
    ``weight_scales[i]``. The result is a bfloat16 tensor with the same
    shape and device as `weights`.
    """
    dequantized = torch.empty(*weights.shape,
                              device=weights.device,
                              dtype=torch.bfloat16)
    num_experts = weights.shape[0]
    for expert_idx in range(num_experts):
        dequantized[expert_idx] = dequant(weights[expert_idx],
                                          weight_scales[expert_idx])
    return dequantized
def experts_quantize(weight, bits):
    """Quantise a stack of expert weights to FP8, one expert at a time.

    Returns ``(qweights, scales)``: an ``float8_e4m3fn`` tensor shaped like
    `weight`, and a float32 tensor holding one scale per index of the first
    dimension (the second value returned by `quantize` for that expert).

    Raises:
        ValueError: if `bits` is not 8.
    """
    if bits != 8:
        raise ValueError(f"Unsupported bit width: {bits}")

    num_experts = weight.shape[0]
    qweight_experts = torch.empty(*weight.shape,
                                  dtype=torch.float8_e4m3fn,
                                  device=weight.device)
    scales = torch.empty(num_experts,
                         dtype=torch.float32,
                         device=weight.device)
    for expert_idx in range(num_experts):
        qweight_experts[expert_idx], scales[expert_idx] = quantize(
            weight[expert_idx], bits)
    return qweight_experts, scales
def dynamic_fp8_pertensor_quantize(tensor):
    """Return the per-tensor FP8 scale computed by `QuantFp8.get_quant_scale`."""
    return QuantFp8(tensor.device).get_quant_scale(tensor)
\ No newline at end of file
......@@ -797,3 +797,10 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
# If there were no matches, return the untouched param name
return name
def fp8_input_scales_loader(path: str):
    """Yield ``(name, slice)`` for every key of the file opened at `path`.

    The file is opened with ``safe_open(..., framework="pt")`` and stays
    open while the generator is being consumed. Each value is what
    ``get_slice`` returns for that key.
    """
    with safe_open(path, framework="pt") as handle:
        for key in handle.keys():  # noqa: SIM118
            yield key, handle.get_slice(key)
# SPDX-License-Identifier: Apache-2.0
import os
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from itertools import product
from math import ceil, sqrt
from typing import Any, List, Literal, Optional, Tuple, TypedDict, Union
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from transformers import BatchFeature, PretrainedConfig, TensorType
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import Pooler, PoolerOutput, PoolingType
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.step_encoder import StepCLIPVisionModel
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.step_image_preprocessor import StepPreprocessor
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
SentencePieceTokenizer)
from .interfaces import SupportsMultiModal, SupportsPP
from .interfaces_base import VllmModelForPooling
from .utils import (flatten_bn, init_vllm_registered_model,
is_pp_missing_parameter, maybe_prefix,
merge_multimodal_embeddings)
def _env_flag(name: str) -> bool:
    """Return True when env var `name` is "true" or "1" (case-insensitive)."""
    return os.getenv(name, "false").lower() in ["true", "1"]


# Whether images without an explicit "high"/"low" detail hint are patched.
DEFAULT_HIGH_RESOLUTION = _env_flag("VLLM_DEFAULT_HIGH_RESOLUTION")
# Whether the vision model input is sharded across tensor-parallel ranks.
VISION_MODEL_USE_DP = _env_flag("VLLM_VISION_MODEL_USE_DP")

print(f"DEFAULT_HIGH_RESOLUTION: {DEFAULT_HIGH_RESOLUTION}")
print(f"VISION_MODEL_USE_DP: {VISION_MODEL_USE_DP}")
class MMStep1oImagePixelInputs(TypedDict):
    """Image input given as raw pixels: full images plus optional patches."""

    type: Literal["pixel_values"]

    # (batch_size * num_images, num_channels, height, width)
    pixel_values: torch.Tensor

    # (batch_size * num_patches, num_channels, patch_size, patch_size);
    # None when no patches are present.
    patch_pixel_values: Optional[torch.Tensor]

    # Patch counts, as a flat list of ints.
    num_patches: List[int]
class MMStep1oImageEmbeddingInputs(TypedDict):
    """Image input given as precomputed embeddings."""

    type: Literal["image_embeds"]

    # (batch_size * num_images * image_feature_size, hidden_size)
    image_embeds: torch.Tensor
# Either form of image input accepted by the model.
MMStep1oImageInputs = Union[MMStep1oImagePixelInputs,
                            MMStep1oImageEmbeddingInputs]

# (image, patches, newline mask). The mask holds one bool per patch and is
# None when there are no patches; it was previously declared as list[int].
ImageWithPatches = Tuple[Image.Image, list[Image.Image], list[bool] | None]
class ImagePatcher:
    """Cuts an image into a grid of square windows.

    Calling an instance returns the (possibly padded and resized) image,
    the list of cropped patches, and a per-patch flag marking the last
    patch of every grid row except the final row.
    """

    def determine_window_size(self, long: int, short: int) -> int:
        """Pick the window edge from the long and short image sides.

        Returns 0 when the image should not be patched.
        """
        aspect = long / short
        if long > 728:
            return min(short, 504) if aspect > 4 else 504
        return short if aspect > 1.5 else 0

    def slide_window(
        self,
        width: int,
        height: int,
        sizes: list[tuple[int, int]],
        steps: list[tuple[int, int]],
        img_rate_thr: float = 0.6,
    ) -> Tuple[List[Tuple[int, int, int, int]], Tuple[int, int]]:
        """Enumerate window boxes over a ``width`` x ``height`` area.

        Returns ``(boxes, (x_num, y_num))``. Each box is
        ``(x, y, box_width, box_height)``; boxes are ordered row by row.
        The grid counts are those of the last ``(size, step)`` pair.
        """
        assert 1 >= img_rate_thr >= 0, "The `in_rate_thr` should lie in 0~1"

        per_scale = []
        for size, step in zip(sizes, steps):
            win_w, win_h = size
            stride_w, stride_h = step

            if width <= win_w:
                x_num = 1
            else:
                x_num = ceil((width - win_w) / stride_w + 1)
            x_start = [stride_w * col for col in range(x_num)]
            # Pull the last column back so it does not overrun the width.
            if len(x_start) > 1 and x_start[-1] + win_w > width:
                x_start[-1] = width - win_w

            if height <= win_h:
                y_num = 1
            else:
                y_num = ceil((height - win_h) / stride_h + 1)
            y_start = [stride_h * row for row in range(y_num)]
            if len(y_start) > 1 and y_start[-1] + win_h > height:
                y_start[-1] = height - win_h

            # Row-major traversal, stored as (x, y) top-left corners.
            corners = np.array(
                [(x, y) for y, x in product(y_start, x_start)], dtype=int)
            per_scale.append(np.concatenate([corners, corners + size],
                                            axis=1))

        all_boxes = np.concatenate(per_scale, axis=0)
        boxes = [(int(x0), int(y0), int(x1 - x0), int(y1 - y0))
                 for x0, y0, x1, y1 in all_boxes]
        return boxes, (x_num, y_num)

    def square_pad(self, img: Image.Image) -> Image.Image:
        """Pad `img` with zeros to a square, keeping it at the top-left."""
        w, h = img.size
        if w == h:
            return img
        edge = max(w, h)
        canvas = Image.new(img.mode, (edge, edge), 0)
        canvas.paste(img, (0, 0))
        return canvas

    def get_image_size_for_padding(self, img_width: int,
                                   img_height: int) -> Tuple[int, int]:
        """Return a square size for thin images, else the size unchanged.

        An image counts as thin when its short side is below 32 and its
        aspect ratio is above 4 or below 1/4.
        """
        ratio = img_width / img_height
        is_tiny = min(img_height, img_width) < 32
        if is_tiny and (ratio > 4 or ratio < 1 / 4):
            edge = max(img_height, img_width)
            return edge, edge
        return img_width, img_height

    def get_image_size_for_preprocess(self, img_width: int,
                                      img_height: int) -> Tuple[int, int]:
        """Scale the size down so that the longer side is at most 3024."""
        longest = max(img_height, img_width)
        if longest <= 3024:
            return img_width, img_height
        scale_factor = 3024 / longest
        return int(img_width * scale_factor), int(img_height * scale_factor)

    def get_image_size_for_crop(self, img_width: int, img_height: int,
                                window_size: int):
        """Snap each side to a multiple of `window_size`.

        A side shorter than one window is kept. Otherwise it is rounded up
        when the fractional number of windows exceeds 0.2, else rounded
        down.
        """

        def snap(extent: int) -> int:
            ratio = extent / window_size
            if ratio < 1:
                return extent
            fraction = ratio - extent // window_size
            tiles = int(ratio) + 1 if fraction > 0.2 else int(ratio)
            return window_size * tiles

        return int(snap(img_width)), int(snap(img_height))

    def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int):
        """Crop the ``tw`` x ``th`` region whose top-left is column j, row i."""
        return img.crop((j, i, j + tw, i + th))

    def get_num_patches(self, img_width: int,
                        img_height: int) -> Tuple[int, int]:
        """Return ``(patch_count, newline_count)`` for an image of this size.

        Applies the same size pipeline as ``__call__`` without touching
        pixel data.
        """
        img_width, img_height = self.get_image_size_for_padding(
            img_width, img_height)
        img_width, img_height = self.get_image_size_for_preprocess(
            img_width, img_height)
        window_size = self.determine_window_size(max(img_height, img_width),
                                                 min(img_height, img_width))
        if window_size == 0:
            return 0, 0

        img_width, img_height = self.get_image_size_for_crop(
            img_width, img_height, window_size)
        window = (window_size, window_size)
        center_list, (x_num, y_num) = self.slide_window(
            img_width, img_height, [window], [window])

        total = len(center_list)
        full_rows = (total - 1) // x_num + 1
        # No newline is counted after a complete final row.
        if total > 0 and total % x_num == 0:
            full_rows -= 1
        return total, full_rows

    def __call__(
        self, img: Image.Image
    ) -> Tuple[Image.Image, List[Image.Image], List[bool] | None]:
        """Return ``(image, patches, newline_mask)`` for `img`.

        ``patches`` is empty and ``newline_mask`` is None when the image is
        not patched.
        """
        img_width, img_height = img.size
        padded_size = self.get_image_size_for_padding(img_width, img_height)
        if padded_size != (img_width, img_height):
            img = self.square_pad(img)
            img_width, img_height = img.size

        new_img_width, new_img_height = self.get_image_size_for_preprocess(
            img_width, img_height)
        img = img.resize((new_img_width, new_img_height),
                         Image.Resampling.BILINEAR)

        window_size = self.determine_window_size(
            max(new_img_height, new_img_width),
            min(new_img_height, new_img_width))
        if window_size == 0:
            return img, [], None

        new_img_width, new_img_height = self.get_image_size_for_crop(
            new_img_width, new_img_height, window_size)
        # The comparison uses the size measured before the resize above.
        if (new_img_width, new_img_height) != (img_width, img_height):
            img_for_crop = img.resize((new_img_width, new_img_height),
                                      Image.Resampling.BILINEAR)
        else:
            img_for_crop = img

        window = (window_size, window_size)
        center_list, (x_num, y_num) = self.slide_window(
            new_img_width, new_img_height, [window], [window])

        patches = []
        row_ends = []
        for patch_id, (x, y, patch_w, patch_h) in enumerate(center_list):
            patches.append(
                self.patch_crop(img_for_crop, y, x, patch_h, patch_w))
            if (patch_id + 1) % x_num == 0:
                row_ends.append(patch_id)

        # The very last patch never gets a trailing newline.
        if row_ends and row_ends[-1] == len(patches) - 1:
            row_ends.pop()

        if len(patches) > 0:
            newline_mask = [idx in row_ends for idx in range(len(patches))]
        else:
            newline_mask = None
        return img, patches, newline_mask
class Step1oProcessor:
    """Turns prompt text and PIL images into token ids and pixel tensors.

    Each image contributes ``<im_start>`` + image tokens + ``<im_end>``,
    preceded, when the image is patched, by one
    ``<patch_start>`` + image tokens + ``<patch_end>`` group per patch and
    ``<patch_newline>`` markers where the patch mask is set.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: AnyTokenizer,
    ) -> None:
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer

        self.image_size = 728
        self.patch_size = 504
        self.image_preprocessor = StepPreprocessor(self.image_size, "bilinear",
                                                   self.patch_size)

        # Number of image tokens emitted per full image and per patch.
        self.num_image_feature_size = 169
        self.num_patch_feature_size = 81
        self.image_token = "<im_patch>"
        self.image_feature_placeholder = (self.image_token *
                                          self.num_image_feature_size)
        self.patch_feature_placeholder = (self.image_token *
                                          self.num_patch_feature_size)

        self.patcher = ImagePatcher()

    @property
    def image_token_id(self) -> int:
        """Vocabulary id of the image placeholder token."""
        return self.tokenizer.get_vocab()[self.image_token]

    @staticmethod
    def _use_high_resolution(detail: Optional[str]) -> bool:
        """Resolve a detail hint: "high"/"low" win, anything else is default."""
        if detail == "high":
            return True
        if detail == "low":
            return False
        return DEFAULT_HIGH_RESOLUTION

    def get_num_image_tokens(self,
                             img_width: int,
                             img_height: int,
                             detail: Optional[str] = "auto") -> int:
        """Return the number of prompt tokens one image expands to.

        Counts the image tokens plus two delimiters for the full image and
        for every patch, plus one token per patch newline.
        """
        if self._use_high_resolution(detail):
            num_patches, num_newlines = self.patcher.get_num_patches(
                img_width, img_height)
        else:
            num_patches = 0
            num_newlines = 0

        per_patch = self.num_patch_feature_size + 2
        per_image = self.num_image_feature_size + 2
        return num_patches * per_patch + per_image + num_newlines

    def _split_images(self,
                      images: list[Image.Image]) -> list[ImageWithPatches]:
        """Patch every image whose detail hint resolves to high resolution.

        The hint is read from ``img.info["detail"]``. Images that are not
        patched are returned as ``(img, [], None)``.
        """
        result = []
        for img in images:
            if self._use_high_resolution(img.info.get("detail", None)):
                result.append(self.patcher(img))
            else:
                result.append((img, [], None))
        return result

    def _convert_images_to_pixel_values(
        self,
        images: list[Image.Image],
        is_patch: bool = False,
    ) -> list[torch.Tensor]:
        """Run the image preprocessor on each image and collect pixel values."""
        preprocess = self.image_preprocessor.preprocess
        return [
            preprocess(img, is_patch=is_patch)["pixel_values"]
            for img in images
        ]

    def _get_patch_repl(
        self,
        num_patches: int,
        patch_newline_mask: list[bool] | None,
    ) -> Tuple[str, list[int]]:
        """Build the text and token ids for `num_patches` patch groups.

        Token ids are looked up once per call instead of once per patch
        (``image_token_id`` rebuilds the vocabulary on every access), and
        the text is assembled with a single join.
        """
        if num_patches <= 0:
            return "", []

        # Same check the per-patch loop used to perform on every iteration.
        assert len(patch_newline_mask) == num_patches

        to_id = self.tokenizer.convert_tokens_to_ids
        patch_text = (
            f"<patch_start>{self.patch_feature_placeholder}<patch_end>")
        patch_ids = ([to_id("<patch_start>")] +
                     [self.image_token_id] * self.num_patch_feature_size +
                     [to_id("<patch_end>")])
        # Resolved lazily so the token is only looked up when it is emitted.
        newline_id = None

        text_parts = []
        token_ids = []
        for i in range(num_patches):
            text_parts.append(patch_text)
            token_ids.extend(patch_ids)
            if patch_newline_mask and patch_newline_mask[i]:
                if newline_id is None:
                    newline_id = to_id("<patch_newline>")
                text_parts.append("<patch_newline>")
                token_ids.append(newline_id)
        return "".join(text_parts), token_ids

    def _get_image_repl(
        self,
        num_images: int,
    ) -> Tuple[str, list[int]]:
        """Build the text and token ids for `num_images` full-image groups."""
        to_id = self.tokenizer.convert_tokens_to_ids
        text = f"<im_start>{self.image_feature_placeholder}<im_end>"
        token_ids = ([to_id("<im_start>")] +
                     [self.image_token_id] * self.num_image_feature_size +
                     [to_id("<im_end>")])
        return text * num_images, token_ids * num_images

    def _get_image_repl_features(
        self,
        num_images: int,
        num_patches: int,
        patch_new_line_idx: Optional[list[bool]],
    ) -> Tuple[str, list[int]]:
        """Return patch groups followed by image groups, as text and ids."""
        if num_patches > 0:
            patch_repl, patch_repl_ids = self._get_patch_repl(
                num_patches, patch_new_line_idx)
        else:
            patch_repl, patch_repl_ids = "", []
        image_repl, image_repl_ids = self._get_image_repl(num_images)
        return patch_repl + image_repl, patch_repl_ids + image_repl_ids

    def replace_placeholder(self, text: str, placeholder: str,
                            repls: list[str]) -> str:
        """Replace the i-th occurrence of `placeholder` with ``repls[i]``.

        Raises:
            ValueError: if the number of occurrences differs from
                ``len(repls)``.
        """
        parts = text.split(placeholder)
        if len(parts) - 1 != len(repls):
            raise ValueError(
                "The number of placeholders does not match the number of replacements."
            )
        pieces = [parts[0]]
        for repl, following in zip(repls, parts[1:]):
            pieces.append(repl)
            pieces.append(following)
        return "".join(pieces)

    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = None,
        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        """Tokenize `text` and preprocess `images` into one ``BatchFeature``.

        With a ``SentencePieceTokenizer`` and at least one image, the
        returned ``input_ids`` contain only the image replacement ids; the
        prompt text is not tokenized in that case.
        """
        if text is None:
            text = []
        elif not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        elif not isinstance(images, list):
            images = [images]

        uses_sentencepiece = isinstance(self.tokenizer, SentencePieceTokenizer)

        if len(images) == 0:
            image_inputs = {}
            if uses_sentencepiece:
                assert len(text) == 1
                # step-tokenizer does not support text input for special
                # tokens, so the prompt is encoded directly.
                text_inputs = {
                    "input_ids":
                    torch.tensor([
                        self.tokenizer.encode(text[0],
                                              add_special_tokens=True)
                    ],
                                 dtype=torch.long)
                }
            else:
                text_inputs = self.tokenizer(text)
        else:
            pixel_values_lst = []
            patch_pixel_values_lst = []
            patch_newline_mask_lst = []
            image_repl_str_lst = []
            image_repl_ids_lst = []
            num_patches = []

            for raw_img, img_patches, patch_newline_mask in (
                    self._split_images(images)):
                pixel_values_lst.extend(
                    self._convert_images_to_pixel_values([raw_img]))
                if len(img_patches) > 0:
                    patch_pixel_values_lst.extend(
                        self._convert_images_to_pixel_values(img_patches,
                                                             is_patch=True))
                num_patches.append(len(img_patches))

                image_repl_str, image_repl_ids = (
                    self._get_image_repl_features(1, len(img_patches),
                                                  patch_newline_mask))
                image_repl_str_lst.append(image_repl_str)
                image_repl_ids_lst.extend(image_repl_ids)

                if patch_newline_mask is not None:
                    patch_newline_mask_lst.extend(patch_newline_mask)

            image_inputs = {
                "pixel_values": torch.cat(pixel_values_lst),
                "num_patches": num_patches,
            }
            if patch_pixel_values_lst:
                image_inputs["patch_pixel_values"] = torch.cat(
                    patch_pixel_values_lst)
            if patch_newline_mask_lst:
                image_inputs["patch_newline_mask"] = torch.tensor(
                    patch_newline_mask_lst, dtype=torch.bool)

            if uses_sentencepiece:
                # step-tokenizer does not support text input for special
                # tokens, so the replacement ids are used as-is.
                text_inputs = {
                    "input_ids":
                    torch.tensor(image_repl_ids_lst,
                                 dtype=torch.long).unsqueeze(0)
                }
            else:
                text = [
                    self.replace_placeholder(t, self.image_token,
                                             image_repl_str_lst) for t in text
                ]
                text_inputs = self.tokenizer(text)

        return BatchFeature(
            {
                **text_inputs,
                **image_inputs,
            },
            tensor_type=return_tensors,
        )
class Step1oProcessingInfo(BaseProcessingInfo):
    """Processing metadata for step1o: processor factory and token counts."""

    def get_hf_processor(self, **kwargs: object) -> Step1oProcessor:
        """Build a `Step1oProcessor` from the HF config and tokenizer.

        Extra keyword arguments are accepted and ignored, because the
        multimodal processor forwards its ``hf_processor_mm_kwargs`` here
        and `Step1oProcessor` has no configurable options.
        """
        return Step1oProcessor(
            self.get_hf_config(),
            self.get_tokenizer(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Only images are supported, with no per-prompt limit."""
        return {"image": None}

    def get_max_image_tokens(self) -> int:
        """Token count for the image size returned by
        `get_image_size_with_most_features`, using the default detail."""
        largest = self.get_image_size_with_most_features()
        return self.get_hf_processor().get_num_image_tokens(
            largest.width, largest.height)

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the per-image token maximum; both arguments are unused."""
        return {"image": self.get_max_image_tokens()}

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the fixed 728 x 728 size used for sizing estimates."""
        return ImageSize(728, 728)

    def get_num_mm_tokens(self, mm_data: MultiModalDataDict) -> int:
        """Sum the prompt tokens needed by every image in `mm_data`.

        Each image's detail hint is read from ``img.info["detail"]``.

        Raises:
            ValueError: if `mm_data` has any key other than ``"image"``.
        """
        if len(mm_data) != 1 or "image" not in mm_data:
            raise ValueError(
                "mm_data could only contain one key 'image' for step1o")

        image_data = mm_data["image"]
        if not isinstance(image_data, (list, tuple)):
            image_data = [image_data]

        # Build the processor once rather than once per image.
        hf_processor = self.get_hf_processor()
        return sum(
            hf_processor.get_num_image_tokens(
                img.width, img.height, detail=img.info.get("detail", None))
            for img in image_data)
class Step1oDummyInputsBuilder(BaseDummyInputsBuilder[Step1oProcessingInfo]):
    """Builds placeholder prompts and images for the requested image count."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<im_patch>`` placeholder per requested image."""
        image_count = mm_counts.get("image", 0)
        return "<im_patch>" * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images sized by `get_image_size_with_most_features`.

        `seq_len` is not used.
        """
        dummy_width, dummy_height = (
            self.info.get_image_size_with_most_features())
        dummy_images = self._get_dummy_images(
            width=dummy_width,
            height=dummy_height,
            num_images=mm_counts.get("image", 0),
        )
        return {"image": dummy_images}
class Step1oMultiModalProcessor(BaseMultiModalProcessor[Step1oProcessingInfo]):
    """Expands image placeholder tokens and declares the field layout."""

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Replace each image placeholder token with its full token run.

        The replacement for image ``i`` is built from its patch count and,
        when it has patches, its ``patch_newline_mask``.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        placeholder_id = processor.image_token_id
        patches_per_image = out_mm_kwargs["num_patches"].tolist()

        def get_replacement_step1o(item_idx: int):
            item = out_mm_kwargs.get_item("image", item_idx)
            patch_count = patches_per_image[item_idx]
            if patch_count > 0:
                newline_mask = item["patch_newline_mask"].data.tolist()
                _, repl_ids = processor._get_image_repl_features(
                    1, patch_count, newline_mask)
            else:
                _, repl_ids = processor._get_image_repl_features(1, 0, None)
            return PromptUpdateDetails.select_token_id(
                seq=repl_ids,
                embed_token_id=placeholder_id,
            )

        replacement = PromptReplacement(
            modality="image",
            target=[placeholder_id],
            replacement=get_replacement_step1o,
        )
        return [replacement]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output is split across images.

        ``pixel_values`` and ``num_patches`` are batched per image; the
        patch tensors are flat and sliced by ``num_patches``.
        """
        patch_sizes = hf_inputs.get("num_patches", torch.empty(0))
        per_image = MultiModalFieldConfig.batched
        by_patch_count = MultiModalFieldConfig.flat_from_sizes
        return {
            "pixel_values": per_image("image"),
            "patch_pixel_values": by_patch_count("image", patch_sizes),
            "num_patches": per_image("image"),
            "patch_newline_mask": by_patch_count("image", patch_sizes),
        }
@MULTIMODAL_REGISTRY.register_processor(Step1oMultiModalProcessor,
info=Step1oProcessingInfo,
dummy_inputs=Step1oDummyInputsBuilder)
class MMGPTStep1oForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
    """Build the vision tower, the two downsamplers, the projector and the
    language model from ``vllm_config.model_config.hf_config``."""
    super().__init__()
    model_config = vllm_config.model_config
    hf_config = model_config.hf_config
    self.config = hf_config
    self.multimodal_config = model_config.multimodal_config

    vision_cfg = hf_config.vision_tower_config
    vision_out_dim = vision_cfg.output_hidden_size

    self.vision_model = StepCLIPVisionModel(
        vision_cfg,
        None,
        prefix=maybe_prefix(prefix, "vision_model"),
        need_dp=VISION_MODEL_USE_DP,
    )
    self.vit_downsampler = nn.Conv2d(
        vision_cfg.hidden_size,
        vision_out_dim,
        kernel_size=2,
        stride=hf_config.understand_projector_stride,
    )
    self.vit_downsampler2 = nn.Conv2d(
        vision_out_dim,
        vision_out_dim * 2,
        kernel_size=3,
        stride=2,
        padding=1,
    )
    # Maps downsampled vision features to the language model width.
    self.vit_large_projector = nn.Linear(
        vision_out_dim * 2,
        hf_config.hidden_size,
        bias=hf_config.projector_bias,
    )
    self.language_model = init_vllm_registered_model(
        vllm_config=vllm_config,
        hf_config=hf_config.text_config,
        prefix=maybe_prefix(prefix, "language_model"),
    )

    self.make_empty_intermediate_tensors = (
        self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
    """The language model's sampler when it has one, else `get_sampler()`."""
    language_model = self.language_model
    if not hasattr(language_model, "sampler"):
        return get_sampler()
    return language_model.sampler
@property
def device(self):
    """Device of the first parameter yielded by ``self.parameters()``."""
    first_param = next(self.parameters())
    return first_param.device
@property
def dtype(self):
    """Dtype of the first parameter yielded by ``self.parameters()``."""
    first_param = next(self.parameters())
    return first_param.dtype
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[MMStep1oImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
patch_pixel_values = kwargs.pop("patch_pixel_values", None)
num_patches = kwargs.pop("num_patches", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
pixel_values = flatten_bn(pixel_values, concat=True)
if pixel_values.dim() >= 3:
pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
if patch_pixel_values is not None:
patch_pixel_values = flatten_bn(patch_pixel_values,
concat=True)
patch_pixel_values = patch_pixel_values.view(
-1, *patch_pixel_values.shape[-3:])
# Handle empty patch_pixel_values by setting to None
if patch_pixel_values.shape[0] == 0:
patch_pixel_values = None
num_patches = flatten_bn(num_patches, concat=True).tolist()
return MMStep1oImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values.to(self.dtype).to(self.device),
patch_pixel_values=patch_pixel_values.to(self.dtype).to(
self.device) if patch_pixel_values is not None else None,
num_patches=num_patches,
)
if image_embeds is not None:
if image_embeds.dim() == 2 or image_embeds.dim() >= 3:
image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
else:
raise ValueError(f"Unexpected shape for image_embeds: {image_embeds.shape}")
return MMStep1oImageEmbeddingInputs(
type="image_embeds",
image_embeds=image_embeds.to(self.dtype).to(self.device),
)
return None
def _process_image_features(self,
image_features: torch.Tensor) -> torch.Tensor:
B, P = image_features.shape[:2]
HW = int(sqrt(P))
image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
image_features = self.vit_downsampler(image_features)
image_features = self.vit_downsampler2(image_features)
n_dim = image_features.size(1)
image_features = image_features.view(B, n_dim, -1).permute(0, 2, 1)
image_features = self.vit_large_projector(image_features)
return image_features
def _get_vision_model_output(self, input_tensor: torch.Tensor) -> torch.Tensor:
    """Run the vision model and drop the first four positions of dim 1.

    When ``VISION_MODEL_USE_DP`` is set and the tensor-parallel world size
    is above one, the batch is split into equal chunks, each rank encodes
    its own chunk, and the results are all-gathered and trimmed to the
    original batch size.
    """
    if not (VISION_MODEL_USE_DP
            and get_tensor_model_parallel_world_size() > 1):
        return self.vision_model(input_tensor)[0][:, 4:]

    tp_rank = get_tensor_model_parallel_rank()
    tp_size = get_tensor_model_parallel_world_size()
    batch_size = input_tensor.shape[0]

    # Ceil division: every rank gets a chunk of the same length.
    chunk_size = (batch_size + tp_size - 1) // tp_size
    start_idx = tp_rank * chunk_size
    end_idx = min(start_idx + chunk_size, batch_size)
    valid_rows = end_idx - start_idx

    # Rows beyond `valid_rows` stay uninitialised; their outputs are cut
    # off by the final slice.
    local_input_tensor = torch.empty(chunk_size,
                                     *input_tensor.shape[1:],
                                     dtype=input_tensor.dtype,
                                     device=input_tensor.device)
    if valid_rows > 0:
        local_input_tensor[:valid_rows].copy_(
            input_tensor[start_idx:end_idx])

    local_features = self.vision_model(local_input_tensor)[0][:, 4:]
    total_features = tensor_model_parallel_all_gather(
        local_features.contiguous(), dim=0)
    return total_features[:batch_size]
def _process_image_input(
self,
image_input: MMStep1oImageInputs) -> tuple[torch.Tensor, ...]:
if image_input["type"] == "image_embeds":
image_features = image_input["image_embeds"]
else:
image_features = self._get_vision_model_output(image_input["pixel_values"])
patch_image_features = self._get_vision_model_output(
image_input["patch_pixel_values"]) if image_input["patch_pixel_values"] is not None else None
num_patches = image_input["num_patches"]
image_features = self._process_image_features(image_features)
patch_image_features = self._process_image_features(
patch_image_features) if patch_image_features is not None else None
merged_image_features = []
cur_patch_idx = 0
for i, num_patch in enumerate(num_patches):
cur_feature = []
if num_patch > 0:
patch_slice = patch_image_features[
cur_patch_idx:cur_patch_idx + num_patch]
cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
cur_feature.append(image_features[i].view(
-1, image_features.shape[-1]))
cur_patch_idx += num_patch
merged_image_features.append(
torch.cat(cur_feature) if len(cur_feature) > 1 else cur_feature[0])
return merged_image_features
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
    """Return image embeddings for `kwargs`, or None without image input."""
    image_input = self._parse_and_validate_image_input(**kwargs)
    if image_input is not None:
        return self._process_image_input(image_input)
    return None
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    vision_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
    """Embed `input_ids`, merging in `vision_embeddings` when given.

    Without vision embeddings every id is embedded by the language model.
    With them, only ids other than ``config.image_token_id`` are embedded;
    the remaining rows are filled by `merge_multimodal_embeddings`.
    """
    embed_tokens = self.language_model.model.get_input_embeddings
    if vision_embeddings is None:
        return embed_tokens(input_ids)

    image_token_id = self.config.image_token_id
    text_mask = input_ids != image_token_id
    text_embeds = embed_tokens(input_ids[text_mask])

    # Rows at image positions stay uninitialised until the merge below.
    inputs_embeds = torch.empty(input_ids.shape[0],
                                text_embeds.shape[-1],
                                dtype=text_embeds.dtype,
                                device=text_embeds.device)
    inputs_embeds[text_mask] = text_embeds
    return merge_multimodal_embeddings(input_ids, inputs_embeds,
                                       vision_embeddings, image_token_id)
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run the language model, building input embeddings when needed.

    If `intermediate_tensors` is given, `inputs_embeds` is discarded. If
    neither is given, embeddings are built from `input_ids` and any image
    inputs in `kwargs`, and the language model receives ``input_ids=None``.
    """
    if intermediate_tensors is not None:
        inputs_embeds = None
    elif inputs_embeds is None:
        # always pass the input via `inputs_embeds`
        # to make sure the computation graph is consistent
        vision_embeddings = self.get_multimodal_embeddings(**kwargs)
        inputs_embeds = self.get_input_embeddings(input_ids,
                                                  vision_embeddings)
        input_ids = None

    return self.language_model(
        input_ids,
        positions,
        intermediate_tensors,
        inputs_embeds=inputs_embeds,
    )
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
    """Delegate logit computation to the wrapped language model."""
    language_model = self.language_model
    return language_model.compute_logits(hidden_states, sampling_metadata)
def sample(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    """Delegate sampling to the wrapped language model."""
    language_model = self.language_model
    return language_model.sample(logits, sampling_metadata)
def maybe_remap_params(self, name):
    """Map a checkpoint parameter name onto this module's parameter name.

    Prefix rewrites:
      * ``model.``        -> ``language_model.model.``
      * ``lm_head``       -> ``language_model.lm_head``
      * ``vision_model.`` -> ``vision_model.vision_model.``

    Only the leading prefix is rewritten. The previous ``str.replace``
    calls rewrote every occurrence, which corrupted names that contain the
    prefix text again further along (for example a submodule whose name
    ends in ``model``).
    """
    prefix_map = (
        ("model.", "language_model.model."),
        ("lm_head", "language_model.lm_head"),
        ("vision_model.", "vision_model.vision_model."),
    )
    # Checked in sequence, like the original chain of `if` statements.
    for old_prefix, new_prefix in prefix_map:
        if name.startswith(old_prefix):
            name = new_prefix + name[len(old_prefix):]
    return name
def load_weights_1o(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load a checkpoint whose q/k/v and gate/up projections are stored
    separately into the fused ``qkv_proj`` / ``gate_up_proj`` parameters.

    Args:
        weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

    Raises:
        KeyError: If a (remapped) checkpoint name has no matching parameter.
        RuntimeError: If some parameter of this module was never loaded.
    """
    # (fused parameter suffix, checkpoint suffix, shard id)
    fused_shards = [
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params = set()

    for ckpt_name, tensor in weights:
        name = self.maybe_remap_params(ckpt_name)
        # The first mapping whose checkpoint suffix occurs in the name wins.
        shard = next((entry for entry in fused_shards if entry[1] in name),
                     None)
        if shard is not None:
            fused_suffix, ckpt_suffix, shard_id = shard
            name = name.replace(ckpt_suffix, fused_suffix)
            param = params_dict[name]
            load = param.weight_loader
            load(param, tensor, shard_id)
        else:
            param = params_dict[name]
            load = getattr(param, "weight_loader", default_weight_loader)
            load(param, tensor)
        loaded_params.add(name)

    # Every loaded name was looked up in params_dict, so a non-empty
    # difference is the only way the two sets can disagree.
    missing = set(params_dict) - loaded_params
    if missing:
        param_name_example = list(missing)[0]
        raise RuntimeError(
            f"Some parameters like {param_name_example} are not in the checkpoint and will falsely use random initialization"
        )
def load_weights_3v(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load a step3v checkpoint into this module.

    Each checkpoint entry is tried against the following, in order:

    1. dense ``gate_proj`` / ``up_proj`` -> fused ``gate_up_proj`` (skipped
       for names that belong to the fused-MoE expert mapping),
    2. stacked MoE expert tensors, loaded one expert (first axis) at a time,
    3. ``q_proj`` / ``k_proj`` / ``v_proj`` -> copied into the matching
       slice of the fused ``qkv_proj`` parameter,
    4. everything else via the parameter's own ``weight_loader`` (or
       ``default_weight_loader``).

    Args:
        weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

    Raises:
        KeyError: If a (remapped) checkpoint name has no matching parameter.
        RuntimeError: If some parameter of this module was never loaded.
    """
    q_dim = self.config.share_q_dim
    kv_dim = self.config.head_dim
    qkv_total = q_dim + 2 * kv_dim
    # (param_name, shard_name, start offset, end offset); offsets are in
    # units of ``qkv_total`` and are rescaled to the real parameter size with
    # exact integer arithmetic (float fractions + int() can truncate a
    # boundary one row short).
    qkv_params_mapping = [
        (".qkv_proj", ".q_proj", 0, q_dim),
        (".qkv_proj", ".k_proj", q_dim, q_dim + kv_dim),
        (".qkv_proj", ".v_proj", q_dim + kv_dim, qkv_total),
    ]
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params = set()

    # (param_name, checkpoint_name, shard_id) for the fused-MoE experts.
    if self.language_model.model.use_fused_moe:
        quant_config = self.language_model.model.vllm_config.quant_config
        if quant_config is not None and quant_config.get_name() == "groupwise_quant":
            expert_params_mapping = [
                (".moe.experts.w13_weight", ".moe.gate_proj.qweight", "w1"),
                (".moe.experts.w13_weight", ".moe.up_proj.qweight", "w3"),
                (".moe.experts.w2_weight", ".moe.down_proj.qweight", "w2"),
                (".moe.experts.w13_weight_scale", ".moe.gate_proj.scales", "w1"),
                (".moe.experts.w13_weight_scale", ".moe.up_proj.scales", "w3"),
                (".moe.experts.w2_weight_scale", ".moe.down_proj.scales", "w2"),
            ]
        else:
            expert_params_mapping = [
                (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
                (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
                (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"),
            ]
    else:
        expert_params_mapping = []

    # Checkpoint names handled by the expert mapping must not be captured by
    # the dense gate/up stacking rule above it.
    disable_moe_stacked_params = [data[1] for data in expert_params_mapping]

    for name, loaded_weight in weights:
        name = self.maybe_remap_params(name)
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            if any(disable_moe_stacked_param in name
                   for disable_moe_stacked_param in disable_moe_stacked_params):
                continue
            name = name.replace(weight_name, param_name)
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
            break
        else:
            for mapping in expert_params_mapping:
                param_name, weight_name, shard_id = mapping
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                # Skip loading extra bias for GPTQ models.
                if ((name.endswith(".bias") or name.endswith("_bias"))
                        and name not in params_dict):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                # The checkpoint tensor is indexed by expert on its first
                # axis; hand the experts to the loader one by one.
                for expert_id in range(loaded_weight.shape[0]):
                    loaded_weight_expert = loaded_weight[expert_id]
                    weight_loader(param,
                                  loaded_weight_expert,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                loaded_params.add(name)
                break
            else:
                for (param_name, weight_name, start_off,
                     end_off) in qkv_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    dim = param.shape[param.output_dim]
                    begin_idx = start_off * dim // qkv_total
                    stop_idx = end_off * dim // qkv_total
                    param_slice = param.narrow(param.output_dim, begin_idx,
                                               stop_idx - begin_idx)
                    param_slice.copy_(loaded_weight)
                    loaded_params.add(name)
                    break
                else:
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
                    loaded_params.add(name)

    params_need_to_load = set(params_dict)
    if params_need_to_load != loaded_params:
        param_name_example = list(params_need_to_load - loaded_params)[0]
        raise RuntimeError(
            f"Some parameters like {param_name_example} are not in the checkpoint and will falsely use random initialization"
        )
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Dispatch checkpoint loading on ``self.config.model_type``.

    Args:
        weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

    Raises:
        ValueError: If the model type is neither ``step1o`` /
            ``mmgpt_qwen2_v2`` nor ``step3v``.
    """
    model_type = self.config.model_type
    if model_type in ["step1o", "mmgpt_qwen2_v2"]:
        self.load_weights_1o(weights)
    elif model_type == "step3v":
        self.load_weights_3v(weights)
    else:
        # Report the same attribute that was dispatched on.
        raise ValueError(f"Unsupported model type: {model_type}")
class MMGPTStep1oRewardModel(MMGPTStep1oForCausalLM, VllmModelForPooling):
    """Pooling variant of ``MMGPTStep1oForCausalLM`` with a linear score head.

    The language-model ``sampler`` / ``lm_head`` attributes are removed (when
    present) and hidden states are projected by ``self.score`` to
    ``config.num_labels`` outputs.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None, "Pooler config must be provided for classification models"

        # Drop the causal-LM-only pieces from the wrapped language model.
        for attr in ("sampler", "lm_head"):
            if hasattr(self.language_model, attr):
                delattr(self.language_model, attr)

        # Score head: hidden_size -> num_labels (num_labels is read from the
        # top-level config, hidden_size from its text_config).
        self.score = RowParallelLinear(config.text_config.hidden_size,
                                       config.num_labels,
                                       quant_config=quant_config,
                                       input_is_parallel=False,
                                       bias=False,
                                       prefix=maybe_prefix(prefix, "score"))

        # Pooler defaults: ALL pooling, no normalization, no softmax.
        self._pooler = Pooler.from_config_with_defaults(
            pooler_config,
            pooling_type=PoolingType.ALL,
            normalize=False,
            softmax=False,
        )

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Apply the configured pooler to ``hidden_states``."""
        return self._pooler(hidden_states, pooling_metadata)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Return score-head outputs for the base model's hidden states."""
        # Hidden states from the base model's forward pass.
        hidden_states = super().forward(input_ids,
                                        positions,
                                        intermediate_tensors,
                                        inputs_embeds=inputs_embeds,
                                        **kwargs)
        # Project through the score head.
        logits, _ = self.score(hidden_states)
        return logits

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights through the base loader, skipping any ``lm_head``.

        The filter is applied to the *remapped* name so that both
        ``lm_head.*`` and ``language_model.lm_head.*`` checkpoint entries
        are dropped; the raw name is still what gets forwarded.
        """
        weights_iterator = (
            (name, data) for name, data in weights
            if "language_model.lm_head." not in self.maybe_remap_params(name))
        # ``score.*`` names pass through the base loader unchanged.
        super().load_weights(weights_iterator)
\ No newline at end of file
......@@ -134,6 +134,11 @@ _TEXT_GENERATION_MODELS = {
# [Encoder-decoder]
"BartModel": ("bart", "BartForConditionalGeneration"),
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
# step model
"Step1ForCausalLM": ("step1", "Step1ForCausalLM"),
"Step2ForCausalLM": ("step1", "Step1ForCausalLM"),
"Step1MoEForCausalLM": ("step1", "Step1ForCausalLM"),
"Step2MiniForCausalLM": ("step2_mini", "Step2MiniForCausalLM"),
}
_EMBEDDING_MODELS = {
......@@ -174,6 +179,19 @@ _EMBEDDING_MODELS = {
# input and output. I am adding it here because it piggy-backs on embedding
# models for the time being.
"PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
# step model
"Step1ForSequenceClassification": ("step1",
"Step1ForSequenceClassification"),
"Step2ForClassification": ("step1", "Step1ForSequenceClassification"),
"Step2ForSequenceClassification": ("step2",
"Step2ForSequenceClassification"),
"Step2MiniForClassification": ("step2_mini",
"Step2MiniForSequenceClassification"),
"MMGPTQwen2RewardModel": ("mm_step1o", "MMGPTStep1oRewardModel"),
# Technically PrithviGeoSpatialMAE is a model that works on images, both in
# input and output. I am adding it here because it piggy-backs on embedding
# models for the time being.
"PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}
_CROSS_ENCODER_MODELS = {
......@@ -251,6 +269,15 @@ _SPECULATIVE_DECODING_MODELS = {
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
"MedusaModel": ("medusa", "Medusa"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
# step model
"MMGPTStep1ForCausalLMV2": ("mm_step1p5c_1u", "MMGPTStep1ForCausalLMV2"),
"MMGPTStep1ForCausalLMV3": ("mm_step1p5c_1u", "MMGPTStep1ForCausalLMV3"),
"MMGPTStep1ForCausalLMV4": ("mm_step1o", "MMGPTStep1oForCausalLM"),
"MMGPTQwen2ForCausalLM": ("mm_step1p5c_1u", "MMGPTStep1ForCausalLMV3"),
"MMGPTQwen2ForCausalLMV2": ("mm_step1o", "MMGPTStep1oForCausalLM"),
"MMGPTStep3vForCausalLM": ("mm_step1o", "MMGPTStep1oForCausalLM"),
"Step1AudioForCausalLM": ("mm_step_audio", "MMGPTStep1fForCausalLM"),
"StepAudioForCausalLMV2": ("mm_step_audio", "MMGPTStep1fForCausalLM"),
}
_TRANSFORMERS_MODELS = {
......
# SPDX-License-Identifier: Apache-2.0
import math
import os
from typing import Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
# from optimus import moe_expert_histogram as optimus_moe_expert_histogram
# from optimus import moe_gather as optimus_moe_gather
# from optimus import moe_scatter as optimus_moe_scatter
from torch import nn
from vllm.attention import Attention
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import OptimusSiluAndMul, SiluAndMul
from vllm.model_executor.layers.layernorm import OptimusRMSNorm, RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
MergedColumnParallelMoELinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
RowParallelMoELinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.quant_utils import (
dynamic_fp8_pertensor_quantize)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, fp8_input_scales_loader)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
# Module-level tuning knobs read by Step1Model.__init__.
# Sequence parallelism is currently hard-disabled; the FIXME shows the
# environment switch that is meant to replace the constant.
DISABLE_SEQUENCE_PARALLEL = True # FIXME: os.getenv("DISABLE_SEQUENCE_PARALLEL", "0") == "1"
# Defaults to 512 when the env var is unset or "0"; otherwise int(env value).
SEQUENCE_PARALLEL_THRESHOLD = 512 if os.getenv("SEQUENCE_PARALLEL_THRESHOLD", "0") == "0" else int(os.getenv("SEQUENCE_PARALLEL_THRESHOLD"))
# Fraction of the sequence assigned to the first segment in
# Step1Model.forward_overlap_v2 (before rounding up to a multiple of tp_size).
GEMM_COMM_OVERLAP_RATIO = 0.5
# Maximum number of rows passed to a feed-forward block per call in the
# chunked forward paths.
MLP_BATCH_SIZE = 8192
def _get_alibi_slopes(n_heads):
n = 2**math.floor(math.log2(n_heads)) # nearest 2**n to n_heads
m0 = 2.0**(-8.0 / n)
slopes = np.power(m0, np.arange(1, n + 1))
if n < n_heads:
m1 = 2.0**(-4.0 / n)
mm = np.power(m1, np.arange(1, 1 + 2 * (n_heads - n), 2))
slopes = np.concatenate([slopes, mm])
return slopes
def _get_ntk_alibi_slopes(max_pos_interp_ratio, slopes):
if max_pos_interp_ratio == 1.0:
return slopes
smax, smin = slopes.max(), slopes.min()
D0 = np.log2(smax) - np.log2(smin)
W1 = (np.log2(smax) - np.log2(slopes)) / D0
ratios = np.power(max_pos_interp_ratio, W1)
return slopes / (ratios**0.5)
class Step1MoEMLP(nn.Module):
    """Mixture-of-experts feed-forward block built on the Optimus MoE ops.

    A replicated, unquantized ``gate`` linear produces one logit per expert;
    tokens are scattered to experts, run through ``gate_up_proj`` ->
    activation -> ``down_proj``, and gathered back with the expert weights.
    """

    def __init__(self,
                 num_experts: int,
                 top_k: int,
                 top_p: float,
                 hidden_size: int,
                 intermediate_size: int,
                 hidden_act="",
                 quant_config: Optional[QuantizationConfig] = None,
                 norm_expert_weight=True,
                 prefix: str = "",
                 enable_cudagraph: bool = False):
        super().__init__()
        # Router: never quantized (quant_config=None) and not sharded.
        self.gate = ReplicatedLinear(input_size=hidden_size,
                                     output_size=num_experts,
                                     bias=False,
                                     quant_config=None,
                                     prefix=f"{prefix}.gate")
        self.top_k = top_k
        self.top_p = top_p
        self.num_experts = num_experts
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        assert intermediate_size % tensor_model_parallel_world_size == 0
        self.gate_up_proj = MergedColumnParallelMoELinear(
            num_experts,
            hidden_size, [intermediate_size] * 2,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        # The Optimus activation is only selected when the per-rank
        # intermediate size is a multiple of 64.
        if (intermediate_size / tensor_model_parallel_world_size) % 64 == 0:
            self.act_fn = OptimusSiluAndMul()
        else:
            self.act_fn = SiluAndMul()
        self.down_proj = RowParallelMoELinear(num_experts,
                                              intermediate_size,
                                              hidden_size,
                                              quant_config=quant_config,
                                              prefix=f"{prefix}.down_proj")
        self.tp_rank = get_tensor_model_parallel_rank()
        self.quant_config = quant_config
        self.norm_expert_weight = norm_expert_weight
        self.enable_cudagraph = enable_cudagraph
        # When True, forward() computes router logits with
        # OptimusMoe.matmul_fp32 instead of calling self.gate.
        self.need_fp32_gate = False

    def get_expert_output(self, inputs: torch.Tensor,
                          expert_token_cnt: torch.Tensor, token_nums: int):
        """Run the expert FFNs on tokens already grouped by expert.

        Args:
            inputs: Scattered tokens (output of ``OptimusMoe.scatter``).
            expert_token_cnt: Per-expert token counts.
            token_nums: Token count of the pre-scatter input
                (``x_shape[0]`` in ``forward``).

        Returns:
            The expert outputs; on some quantized paths this is ``inputs``
            itself, overwritten in place.
        """
        # Quantized path: both projections expose a quant_config.
        if self.quant_config and getattr(self.gate_up_proj.quant_method,
                                         "quant_config", None) and getattr(
                                             self.down_proj.quant_method,
                                             "quant_config", None):
            # At most 1024 rows with 8-bit weights on both projections:
            # use the grouped Optimus kernels.
            if inputs.size(
                    0
            ) <= 1024 and self.gate_up_proj.quant_method.quant_config.weight_bits == 8 and self.down_proj.quant_method.quant_config.weight_bits == 8:
                if self.enable_cudagraph:
                    tmp = torch.ops.Optimus.MoeFpAIntBGemm(
                        inputs, self.gate_up_proj.qweight,
                        self.gate_up_proj.qweight.dtype,
                        self.gate_up_proj.scales, expert_token_cnt, token_nums,
                        None)
                    tmp = self.act_fn(tmp)
                    tmp = torch.ops.Optimus.MoeFpAIntBGemm(
                        tmp, self.down_proj.qweight,
                        self.down_proj.qweight.dtype, self.down_proj.scales,
                        expert_token_cnt, token_nums, None)
                    return tmp
                else:
                    # Called with out=inputs, i.e. the result is requested
                    # in the input buffer.
                    quant_output_ = torch.ops.OptimusMoe.moe_ffn_quant(
                        inputs,
                        self.gate_up_proj.qweight.dtype,
                        self.gate_up_proj.qweight,
                        self.gate_up_proj.scales,
                        self.down_proj.qweight,
                        self.down_proj.scales,
                        expert_token_cnt,
                        token_nums,
                        out=inputs)
                    return quant_output_
            else:
                # Fallback: loop over experts on the host. The counts are
                # moved to the CPU first.
                expert_token_cnt = expert_token_cnt.to("cpu").tolist()
                start = 0
                end = 0
                # 6-bit weights get a separate bfloat16 output buffer;
                # otherwise results overwrite ``inputs`` slice by slice.
                if getattr(
                        self.gate_up_proj.quant_method, "quant_config", None
                ) and self.gate_up_proj.quant_method.quant_config.weight_bits == 6:
                    output = torch.empty_like(inputs,
                                              dtype=torch.bfloat16,
                                              device=inputs.device)
                else:
                    output = inputs
                # Experts own consecutive row ranges [start, end).
                for i in range(len(expert_token_cnt)):
                    cur_token_cnt = expert_token_cnt[i]
                    if (cur_token_cnt <= 0):
                        continue
                    end += cur_token_cnt
                    tmp = self.gate_up_proj(inputs[start:end], expert_idx=i)
                    tmp = self.act_fn(tmp)
                    tmp = self.down_proj(tmp,
                                         expert_idx=i,
                                         output=output[start:end])
                    start += cur_token_cnt
                return output
        else:
            # Unquantized path: single fused op over all experts.
            moe_output = torch.ops.OptimusMoe.moe_ffn(inputs,
                                                      self.gate_up_proj.weight,
                                                      self.down_proj.weight,
                                                      expert_token_cnt,
                                                      token_nums)
            return moe_output

    def forward(
        self,
        x,
        residual=None,
        layernorm=None,
        disable_allreduce=False,
        user_output=None,
    ):
        """Route ``x`` through the experts and combine the results.

        Args:
            x: Input activations.
            residual: Optional tensor added to the output on TP rank 0 only.
            layernorm: Optional callable applied to ``x`` first.
            disable_allreduce: Skip the tensor-parallel all-reduce.
            user_output: Optional tensor that also receives a copy of the
                result.

        Returns:
            The combined expert output.
        """
        if layernorm is not None:
            # fp16_out is truthy only for 6-bit quantized weights.
            x = layernorm(
                x,
                fp16_out=getattr(self.gate_up_proj.quant_method,
                                 "quant_config", None) and
                self.gate_up_proj.quant_method.quant_config.weight_bits == 6
                if self.gate_up_proj.quant_method else False)
        x_shape = x.shape
        if self.need_fp32_gate:
            if getattr(
                    self.gate_up_proj.quant_method, "quant_config", None
            ) and self.gate_up_proj.quant_method.quant_config.weight_bits == 6:
                logits = torch.ops.OptimusMoe.matmul_fp32(x.to(torch.bfloat16),
                                                          self.gate.weight.t())
            else:
                logits = torch.ops.OptimusMoe.matmul_fp32(x, self.gate.weight.t())
        else:
            logits = self.gate(x)[0]
        # Disabled earlier variant of the routing code below, kept as-is; it
        # calls optimus_moe_* helpers instead of torch.ops.OptimusMoe.*.
        # if self.top_p < 1.0:
        #     top_k_index, expert_weight, scatter_index = torch.ops.OptimusMoe.topk_topp_gating(
        #         logits, self.top_k, self.top_p, self.norm_expert_weight)
        #     expert_token_cnt = optimus_moe_expert_histogram(
        #         top_k_index, self.num_experts)
        #     scatter_index = torch.ops.OptimusMoe.index_compute(
        #         top_k_index, expert_token_cnt, out=scatter_index)
        #     mid_output = optimus_moe_scatter(x, scatter_index)
        #     expert_output = self.get_expert_output(mid_output,
        #                                            expert_token_cnt,
        #                                            x_shape[0])
        #     output = optimus_moe_gather(expert_output, scatter_index,
        #                                 expert_weight)
        # else:
        #     expert_weight, expert_token_cnt, scatter_index = torch.ops.OptimusMoe.gating_histogram_index(
        #         logits, self.top_k, 1.0, self.norm_expert_weight)
        #     mid_output = optimus_moe_scatter(x, scatter_index)
        #     expert_output = self.get_expert_output(mid_output,
        #                                            expert_token_cnt,
        #                                            x_shape[0])
        #     output = optimus_moe_gather(expert_output, scatter_index,
        #                                 expert_weight)
        if self.top_p < 1.0:
            # top_p < 1: gating, histogram and index computation are
            # separate ops.
            top_k_index, expert_weight, scatter_index = torch.ops.OptimusMoe.topk_topp_gating(
                logits, self.top_k, self.top_p, self.norm_expert_weight)
            expert_token_cnt = torch.ops.OptimusMoe.expert_histogram(
                top_k_index, self.num_experts)
            scatter_index = torch.ops.OptimusMoe.index_compute(
                top_k_index, expert_token_cnt, out=scatter_index)
            mid_output = torch.ops.OptimusMoe.scatter(x, scatter_index)
            expert_output = self.get_expert_output(mid_output,
                                                   expert_token_cnt,
                                                   x_shape[0])
            output = torch.ops.OptimusMoe.gather(expert_output, scatter_index,
                                                 expert_weight)
        else:
            # Otherwise a single op returns weights, counts and indices.
            expert_weight, expert_token_cnt, scatter_index = torch.ops.OptimusMoe.gating_histogram_index(
                logits, self.top_k, 1.0, self.norm_expert_weight)
            mid_output = torch.ops.OptimusMoe.scatter(x, scatter_index)
            expert_output = self.get_expert_output(mid_output,
                                                   expert_token_cnt,
                                                   x_shape[0])
            output = torch.ops.OptimusMoe.gather(expert_output, scatter_index,
                                                 expert_weight)
        # Residual is added on rank 0 only - presumably so that the
        # following sum all-reduce counts it exactly once.
        if self.tp_rank == 0 and residual is not None:
            output += residual
        if not disable_allreduce:
            output = tensor_model_parallel_all_reduce(output)
        if user_output is not None:
            user_output.copy_(output)
        return output
class Step1MLP(nn.Module):
    """Dense gated feed-forward block: gate_up_proj -> SiLU-and-mul -> down_proj."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        use_optimus_silu: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Gate and up projections share one merged column-parallel matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = OptimusSiluAndMul() if use_optimus_silu else SiluAndMul()

    def forward(self,
                x,
                residual=None,
                layernorm=None,
                disable_allreduce=False,
                user_output=None):
        """Apply the optional ``layernorm`` and then the feed-forward block.

        ``residual``, ``user_output`` and ``disable_allreduce`` are handed to
        ``down_proj``; the first element of its result is returned.
        """
        if layernorm is not None:
            # fp16_out is requested only for 6-bit quantized weights.
            quant_cfg = getattr(self.gate_up_proj.quant_method, "quant_config",
                                None)
            wants_fp16 = quant_cfg.weight_bits == 6 if quant_cfg else False
            x = layernorm(x, fp16_out=wants_fp16)

        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        out, _ = self.down_proj(activated,
                                residual,
                                output=user_output,
                                disable_allreduce=disable_allreduce)
        return out
class Step1Attention(nn.Module):
    """Tensor-parallel self-attention with ALiBi position biases.

    Args:
        hidden_size: Model width; the head size is
            ``hidden_size // num_heads``.
        num_heads: Total number of query heads (must divide by the TP size).
        num_kv_heads: Total number of key/value heads.
        slopes: Optional explicit ALiBi slopes, one per query head. When
            ``None`` they are generated by ``_get_alibi_slopes``.
        max_pos_interp_ratio: Ratio passed to ``_get_ntk_alibi_slopes``;
            ``1.0`` leaves the slopes unchanged.
        cache_config: Forwarded to ``Attention``.
        quant_config: Forwarded to the linear projections.
        prefix: Parameter-name prefix of this module.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        slopes: Optional[List[float]] = None,
        max_pos_interp_ratio: float = 1.0,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # Create the alibi slopes and slice them to this rank's heads.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        if slopes is None:
            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
            alibi_slopes = _get_ntk_alibi_slopes(max_pos_interp_ratio,
                                                 alibi_slopes)
            alibi_slopes = alibi_slopes[head_start:head_end]
        else:
            assert len(slopes) == self.total_num_heads
            # _get_ntk_alibi_slopes needs an ndarray (it calls .max()/.min()),
            # and it is the *scaled* slopes that must be sliced - slicing the
            # raw ``slopes`` would silently drop the interpolation scaling.
            scaled_slopes = _get_ntk_alibi_slopes(max_pos_interp_ratio,
                                                  np.asarray(slopes))
            alibi_slopes = scaled_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scaling,
                              self.num_kv_heads,
                              alibi_slopes,
                              alibi_sqrt=True,
                              cache_config=cache_config,
                              prefix=f"{prefix}.attn")

    def forward(self,
                positions: torch.Tensor,
                hidden_states: torch.Tensor,
                residual: Optional[torch.Tensor] = None,
                layernorm: Optional[nn.Module] = None,
                disable_allreduce=False,
                user_output=None) -> torch.Tensor:
        """Apply the optional ``layernorm``, attention and output projection.

        ``residual``, ``user_output`` and ``disable_allreduce`` are handed to
        ``o_proj``; the first element of its result is returned.
        ``positions`` is accepted for interface compatibility but unused.
        """
        del positions  # Unused.
        if layernorm:
            # fp16_out is requested only for 6-bit quantized weights.
            quant_cfg = getattr(self.qkv_proj.quant_method, "quant_config",
                                None)
            fp16_out = quant_cfg.weight_bits == 6 if quant_cfg else False
            hidden_states = layernorm(hidden_states, fp16_out=fp16_out)

        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        attn_output = self.attn(q, k, v)
        residual, _ = self.o_proj(attn_output,
                                  residual,
                                  disable_allreduce=disable_allreduce,
                                  output=user_output)
        return residual
class Step1DecoderLayer(nn.Module):
    """One decoder layer: self-attention followed by a dense or MoE FFN.

    Whether the layer is MoE is decided from the layer index parsed out of
    ``prefix`` together with ``use_moe``, ``moe_layer_offset`` and
    ``moe_every_n_layer`` of the HF config.
    """

    def __init__(self,
                 model_config: ModelConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.enable_cudagraph = not model_config.enforce_eager
        hf_config = model_config.hf_config
        self.hidden_size = hf_config.hidden_size

        self.self_attn = Step1Attention(
            hidden_size=self.hidden_size,
            num_heads=hf_config.num_attention_heads,
            num_kv_heads=hf_config.num_attention_groups,
            slopes=hf_config.alibi_slopes,
            max_pos_interp_ratio=hf_config.max_pos_interp_ratio,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )

        # The layer index is the path component right after "layers.".
        layer_idx = int(prefix.split("layers.")[1].split(".")[0])
        self.use_moe = hf_config.use_moe and (
            layer_idx +
            hf_config.moe_layer_offset) % hf_config.moe_every_n_layer == 0

        if self.use_moe:
            self.moe = Step1MoEMLP(
                hf_config.moe_num_experts,
                hf_config.moe_top_k,
                hf_config.moe_dynamic_exp_p,
                hidden_size=self.hidden_size,
                intermediate_size=hf_config.moe_intermediate_size,
                hidden_act="silu",
                quant_config=quant_config,
                prefix=f"{prefix}.moe",
                enable_cudagraph=self.enable_cudagraph,
            )
        else:
            self.mlp = Step1MLP(
                hidden_size=self.hidden_size,
                intermediate_size=hf_config.intermediate_size,
                hidden_act="silu",
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )

        # The Optimus norm is chosen when the hidden size is a multiple of 64.
        if hf_config.hidden_size % 64 == 0:
            norm_cls = OptimusRMSNorm
        else:
            norm_cls = RMSNorm
        self.input_layernorm = norm_cls(hf_config.hidden_size,
                                        eps=hf_config.rms_norm_eps)
        self.post_attention_layernorm = norm_cls(hf_config.hidden_size,
                                                 eps=hf_config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention and then the feed-forward block.

        Each sub-block receives its own input as the residual and the
        matching layer norm to apply first.
        """
        # Self attention.
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            residual=hidden_states,
            layernorm=self.input_layernorm,
        )
        # Feed-forward (MoE or dense).
        ffn = self.moe if self.use_moe else self.mlp
        return ffn(attn_out, attn_out, self.post_attention_layernorm)
# @support_torch_compile
class Step1Model(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
    """Build embeddings, decoder layers and final norm for this PP rank.

    LoRA is not supported (asserted). Modules that this pipeline rank does
    not own are replaced by ``PPMissingLayer``.
    """
    super().__init__()
    hf_config = vllm_config.model_config.hf_config
    cache_config = vllm_config.cache_config
    quant_config = vllm_config.quant_config
    assert vllm_config.lora_config is None

    self.config = hf_config
    self.allgather_dtype = None  # FIXME(ys): disable fp8 allgather
    self.padding_idx = hf_config.pad_token_id
    self.vocab_size = hf_config.vocab_size

    # Embeddings live on the first PP rank, and also on the last one when
    # the word embeddings are tied.
    owns_embeddings = get_pp_group().is_first_rank or (
        hf_config.tie_word_embeddings and get_pp_group().is_last_rank)
    if owns_embeddings:
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            hf_config.hidden_size,
            org_num_embeddings=hf_config.vocab_size,
            quant_config=quant_config,
        )
    else:
        self.embed_tokens = PPMissingLayer()

    def build_layer(prefix):
        return Step1DecoderLayer(model_config=vllm_config.model_config,
                                 cache_config=cache_config,
                                 quant_config=quant_config,
                                 prefix=prefix)

    self.start_layer, self.end_layer, self.layers = make_layers(
        hf_config.num_hidden_layers,
        build_layer,
        prefix=f"{prefix}.layers",
    )

    # The final norm only exists on the last PP rank.
    norm_cls = OptimusRMSNorm if hf_config.hidden_size % 64 == 0 else RMSNorm
    if get_pp_group().is_last_rank:
        self.norm = norm_cls(hf_config.hidden_size, eps=hf_config.rms_norm_eps)
    else:
        self.norm = PPMissingLayer()

    self.make_empty_intermediate_tensors = (
        make_empty_intermediate_tensors_factory(["hidden_states"],
                                                hf_config.hidden_size))

    # Tuning knobs taken from the module-level constants.
    if DISABLE_SEQUENCE_PARALLEL:
        self.sequence_parallel_threshold = None
    else:
        self.sequence_parallel_threshold = SEQUENCE_PARALLEL_THRESHOLD
    self.overlap_ratio = GEMM_COMM_OVERLAP_RATIO
    self.mlp_batch_size = MLP_BATCH_SIZE
    self.tp_size = get_tensor_model_parallel_world_size()
    self.use_moe = hf_config.use_moe
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return ``self.embed_tokens`` applied to ``input_ids``."""
    token_embeddings = self.embed_tokens(input_ids)
    return token_embeddings
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run the decoder stack.

    With more than one pipeline rank everything is delegated to
    ``forward_pp``. Otherwise the embeddings (given or looked up from
    ``input_ids``) go through the MoE or dense hidden-state path depending
    on ``self.use_moe``.
    """
    if get_pp_group().world_size > 1:
        return self.forward_pp(input_ids, positions, intermediate_tensors,
                               inputs_embeds)

    if inputs_embeds is None:
        hidden_states = self.get_input_embeddings(input_ids)
    else:
        hidden_states = inputs_embeds

    run_layers = (self.forward_hidden_states_moe
                  if self.use_moe else self.forward_hidden_states)
    return run_layers(hidden_states, positions)
def forward_pp(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Pipeline-parallel forward over layers ``[start_layer, end_layer)``.

    The first rank starts from embeddings, later ranks from
    ``intermediate_tensors["hidden_states"]``. Every rank but the last
    returns ``IntermediateTensors``; the last returns the normed states.
    """
    if get_pp_group().is_first_rank:
        if inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = inputs_embeds
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]

    for layer_idx in range(self.start_layer, self.end_layer):
        hidden_states = self.layers[layer_idx](positions, hidden_states)

    if get_pp_group().is_last_rank:
        return self.norm(hidden_states)
    return IntermediateTensors({
        "hidden_states": hidden_states,
    })
def forward_hidden_states_moe(
    self,
    hidden_states: torch.Tensor,
    positions: torch.Tensor,
) -> torch.Tensor:
    """Run all layers for a MoE model and return the normed hidden states.

    When tensor parallelism is active and the number of rows exceeds
    ``self.sequence_parallel_threshold``, one of the chunked paths is used:
    ``forward_overlap_v2`` for ``tp_size > 8``, else ``forward_split_ffn``.
    Otherwise the layers run one after another, followed by ``self.norm``.
    """
    num_tokens = hidden_states.shape[0]
    threshold = self.sequence_parallel_threshold
    use_chunked_path = (self.tp_size > 1 and threshold is not None
                        and threshold < num_tokens)

    if not use_chunked_path:
        for layer in self.layers:
            hidden_states = layer(
                positions,
                hidden_states,
            )
        return self.norm(hidden_states)

    if self.tp_size > 8:
        return self.forward_overlap_v2(hidden_states, positions)
    # TODO(xwx): overlap mlp layer of MoE model
    return self.forward_split_ffn(hidden_states, positions)
def forward_overlap_v2(
    self,
    hidden_states: torch.Tensor,
    positions: torch.Tensor,
) -> torch.Tensor:
    """Layer loop that splits the rows into two segments on two CUDA streams.

    Rows ``[:dim_0]`` are processed on the current stream and rows
    ``[dim_0:]`` on a side stream ``s1``, with per-segment all-reduces, so
    that compute on one segment can overlap communication on the other
    (see ``GEMM_COMM_OVERLAP_RATIO``). The input is zero-padded to a
    multiple of the TP size and the padding is stripped from the result.

    ``positions`` is unused.
    """
    del positions
    S = hidden_states.shape[0]
    tp_size = get_tensor_model_parallel_world_size()
    rank = get_tensor_model_parallel_rank()
    if S % tp_size != 0:
        # pad to multiple of tp_size with 0
        pad_len = tp_size - S % tp_size
        hidden_states = torch.cat([
            hidden_states,
            torch.zeros(pad_len,
                        hidden_states.shape[1],
                        dtype=hidden_states.dtype,
                        device=hidden_states.device)
        ])
        S = hidden_states.shape[0]
    else:
        pad_len = 0
    assert S % tp_size == 0
    hidden_states = hidden_states.view(S, -1)
    # Size of the first segment.
    dim_0 = int((S * self.overlap_ratio + tp_size - 1) // tp_size *
                tp_size)  # round up to multiple of tp_size
    # One scratch allocation sized from intermediate_size / tp_size per row;
    # qkv_buffer below is a view into its leading part.
    buffer = torch.empty(S * int(self.config.intermediate_size / tp_size),
                         dtype=hidden_states.dtype,
                         device=hidden_states.device)
    # NOTE(review): mlp_buffer is never read below, only deleted.
    mlp_buffer = buffer.view(S, -1)
    if tp_size > 8:
        assert tp_size % 8 == 0, f"tp_size should be an integer multiple of 8,but cur tp_size={tp_size}"
        kv_repeat = tp_size // 8
    else:
        kv_repeat = 1
    qkv_buffer = buffer[:S * int(
        (self.config.num_attention_heads +
         self.config.num_attention_groups * kv_repeat * 2) // tp_size *
        (self.config.hidden_size //
         self.config.num_attention_heads))].view(S, -1)
    # Each rank keeps the residual for S // tp_size rows only: chunk_size_0
    # rows from the first segment and chunk_size_1 from the second.
    chunk_size = S // tp_size
    residual = torch.empty(chunk_size,
                           self.config.hidden_size,
                           dtype=hidden_states.dtype,
                           device=hidden_states.device)
    chunk_size_0 = dim_0 // tp_size
    chunk_size_1 = chunk_size - chunk_size_0
    residual_intersect_0 = residual[:chunk_size_0]
    residual_intersect_1 = residual[chunk_size_0:]
    # This rank's row slice inside each segment (slices of hidden_states).
    hidden_states_intersect_0 = hidden_states[rank *
                                              chunk_size_0:(rank + 1) *
                                              chunk_size_0]
    hidden_states_intersect_1 = hidden_states[dim_0 +
                                              rank * chunk_size_1:dim_0 +
                                              (rank + 1) * chunk_size_1]
    s1 = torch.cuda.Stream(device=residual.device)
    for i in range(len(self.layers)):
        layer = self.layers[i]
        ffn = layer.moe if layer.use_moe else layer.mlp
        # Attention Forward
        # Segment 0 on the current stream: save residual, norm and QKV
        # written to the given output slices.
        residual_intersect_0.copy_(hidden_states_intersect_0)
        layer.input_layernorm(hidden_states[:dim_0],
                              output=hidden_states[:dim_0])
        layer.self_attn.qkv_proj(hidden_states[:dim_0],
                                 output=qkv_buffer[:dim_0])
        # Segment 1 on the side stream.
        with torch.cuda.stream(s1):
            residual_intersect_1.copy_(hidden_states_intersect_1)
            layer.input_layernorm(hidden_states[dim_0:],
                                  output=hidden_states[dim_0:])
            layer.self_attn.qkv_proj(hidden_states[dim_0:],
                                     output=qkv_buffer[dim_0:])
        # Attention runs over all rows, so both segments must be ready.
        torch.cuda.current_stream().wait_stream(s1)
        q, k, v = qkv_buffer.view(S, -1).split([
            layer.self_attn.q_size, layer.self_attn.kv_size,
            layer.self_attn.kv_size
        ],
                                               dim=-1)
        if pad_len > 0:
            # Padding rows are excluded from attention and re-appended as
            # zeros afterwards.
            attn_output = layer.self_attn.attn(q[:-pad_len], k[:-pad_len], v[:-pad_len])
            attn_output = torch.cat([attn_output, torch.zeros(pad_len, attn_output.shape[1], dtype=attn_output.dtype, device=attn_output.device)], dim=0)
        else:
            attn_output = layer.self_attn.attn(q, k, v)
        hidden_states = hidden_states.view(S, -1)
        # Output projection per segment; each rank adds the saved residual
        # to its own rows only before the all-reduce.
        layer.self_attn.o_proj(attn_output[:dim_0],
                               output=hidden_states[:dim_0],
                               disable_allreduce=True)
        hidden_states_intersect_0.add_(residual_intersect_0)
        torch.distributed.all_reduce(
            hidden_states[:dim_0],
            group=get_tensor_model_parallel_group().device_group)
        # NOTE(review): s1 is not made to wait on the current stream before
        # reading attn_output here - confirm ordering is guaranteed.
        with torch.cuda.stream(s1):
            layer.self_attn.o_proj(attn_output[dim_0:],
                                   output=hidden_states[dim_0:],
                                   disable_allreduce=True)
            hidden_states_intersect_1.add_(residual_intersect_1)
            torch.distributed.all_reduce(
                hidden_states[dim_0:],
                group=get_tensor_model_parallel_group().device_group)
        del attn_output
        # Feed-forward for segment 0, at most mlp_batch_size rows per call,
        # written back into the same rows.
        residual_intersect_0.copy_(hidden_states_intersect_0)
        layer.post_attention_layernorm(hidden_states[:dim_0],
                                       output=hidden_states[:dim_0])
        num_batch_size = (dim_0 + self.mlp_batch_size -
                          1) // self.mlp_batch_size
        for idx in range(num_batch_size):
            start = idx * self.mlp_batch_size
            end = min((idx + 1) * self.mlp_batch_size, dim_0)
            ffn(hidden_states[start:end],
                disable_allreduce=True,
                user_output=hidden_states[start:end])
        hidden_states_intersect_0.add_(residual_intersect_0)
        torch.distributed.all_reduce(
            hidden_states[:dim_0],
            group=get_tensor_model_parallel_group().device_group)
        # Feed-forward for segment 1 on the side stream.
        with torch.cuda.stream(s1):
            residual_intersect_1.copy_(hidden_states_intersect_1)
            layer.post_attention_layernorm(hidden_states[dim_0:],
                                           output=hidden_states[dim_0:])
            num_batch_size = (S - dim_0 + self.mlp_batch_size -
                              1) // self.mlp_batch_size
            for idx in range(num_batch_size):
                start = dim_0 + idx * self.mlp_batch_size
                end = dim_0 + min(
                    (idx + 1) * self.mlp_batch_size, S - dim_0)
                ffn(hidden_states[start:end],
                    disable_allreduce=True,
                    user_output=hidden_states[start:end])
            hidden_states_intersect_1.add_(residual_intersect_1)
            torch.distributed.all_reduce(
                hidden_states[dim_0:],
                group=get_tensor_model_parallel_group().device_group)
        # Re-join the streams before the next layer.
        torch.cuda.current_stream().wait_stream(s1)
    del buffer, mlp_buffer, qkv_buffer, residual
    self.norm(hidden_states, output=hidden_states)
    # Drop the padding rows appended at the top.
    return hidden_states[:S - pad_len]
def forward_split_ffn(
    self,
    hidden_states: torch.Tensor,
    positions: torch.Tensor,
) -> torch.Tensor:
    """Run all decoder layers in place, applying the MLP/MoE in row batches.

    Attention and MLP/MoE are invoked with ``disable_allreduce=True``; this
    method issues the tensor-parallel all-reduce itself. Before every
    all-reduce the saved residual is added only to this rank's column slice
    of ``hidden_states`` (presumably so the reduction contributes each
    residual column once -- verify against the reduce op in use).

    Args:
        hidden_states: 2-D activations; overwritten in place.
        positions: Forwarded unchanged to each layer's ``self_attn``.

    Returns:
        ``hidden_states`` after all layers and ``self.norm``.
    """
    num_tokens = hidden_states.shape[0]
    world_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()
    cols_per_rank = self.config.hidden_size // world_size
    saved_residual = torch.empty(num_tokens,
                                 cols_per_rank,
                                 dtype=hidden_states.dtype,
                                 device=hidden_states.device)

    def reduce_across_tp(tensor: torch.Tensor) -> None:
        # In-place all-reduce over the tensor-parallel device group.
        torch.distributed.all_reduce(
            tensor, group=get_tensor_model_parallel_group().device_group)

    for layer in self.layers:
        # Columns of the activations that this rank pairs with the residual.
        own_columns = hidden_states.narrow(1, tp_rank * cols_per_rank,
                                           cols_per_rank)

        # ---- attention sub-block ----
        saved_residual.copy_(own_columns)
        layer.input_layernorm(hidden_states, output=hidden_states)
        layer.self_attn(positions,
                        hidden_states,
                        residual=None,
                        layernorm=None,
                        disable_allreduce=True,
                        user_output=hidden_states)
        own_columns.add_(saved_residual)
        reduce_across_tp(hidden_states)

        # ---- MLP / MoE sub-block, processed mlp_batch_size rows at a time ----
        saved_residual.copy_(own_columns)
        layer.post_attention_layernorm(hidden_states, output=hidden_states)
        num_batches = (num_tokens + self.mlp_batch_size -
                       1) // self.mlp_batch_size
        hidden_states = hidden_states.view(num_tokens, -1)
        for batch_idx in range(num_batches):
            first_row = batch_idx * self.mlp_batch_size
            rows = slice(first_row,
                         min((batch_idx + 1) * self.mlp_batch_size,
                             num_tokens))
            if layer.use_moe:
                hidden_states[rows] = layer.moe(hidden_states[rows],
                                                disable_allreduce=True)
            else:
                layer.mlp(hidden_states[rows],
                          disable_allreduce=True,
                          user_output=hidden_states[rows])
        own_columns.add_(saved_residual)
        reduce_across_tp(hidden_states)

    del saved_residual
    self.norm(hidden_states, output=hidden_states)
    return hidden_states
def forward_hidden_states(
self,
hidden_states: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
# Apply every decoder layer and the final norm to `hidden_states`.
#
# Sequence-parallel path (taken when self.tp_size > 1 and the row count S
# exceeds self.sequence_parallel_threshold): rows are split into part 0
# (issued on the current stream) and part 1 (issued on side stream s1).
# Between sub-blocks each rank keeps only its own row shard of the residual;
# shards are exchanged with reduce_scatter_tensor / all_gather_into_tensor,
# optionally staged as fp8 when self.allgather_dtype is set.
# Fallback path: each layer is simply called as layer(positions, hidden_states).
#
# Returns the normalized hidden states; padding rows appended below are
# sliced off before returning.
S = hidden_states.shape[0]
if self.tp_size > 1 and self.sequence_parallel_threshold is not None and self.sequence_parallel_threshold < S:
tp_size = get_tensor_model_parallel_world_size()
rank = get_tensor_model_parallel_rank()
# The row count must be divisible by tp_size so rows can be sharded evenly.
if S % tp_size != 0:
# pad to multiple of tp_size with 0
pad_len = tp_size - S % tp_size
hidden_states = torch.cat([
hidden_states,
torch.zeros(pad_len,
hidden_states.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device)
])
S = hidden_states.shape[0]
else:
pad_len = 0
assert S % tp_size == 0
# Rows owned by each rank; `residual` stores only this rank's rows.
chunk_size = S // tp_size
residual = torch.empty(chunk_size,
self.config.hidden_size,
dtype=hidden_states.dtype,
device=hidden_states.device)
# dim_0 / dim_1: row counts of part 0 and part 1.
dim_0 = int((S * self.overlap_ratio + tp_size - 1) // tp_size *
tp_size) # round up to multiple of tp_size
dim_1 = S - dim_0
# Per-rank column widths of the MLP intermediate and the fused QKV output.
mlp_dim = int(self.config.intermediate_size / tp_size)
qkv_dim = int(
(self.config.num_attention_heads +
self.config.num_attention_groups * 2) // tp_size *
(self.config.hidden_size // self.config.num_attention_heads))
if self.allgather_dtype is not None:
# fp8 staging rows are carved from the same buffer and later viewed as
# uint8 with hidden_size columns -- presumably assumes a 2-byte
# activation dtype; TODO confirm.
fp8_dim = int(self.config.hidden_size / 2)
max_buffer_dim = max(mlp_dim, qkv_dim, fp8_dim)
else:
max_buffer_dim = max(mlp_dim, qkv_dim)
# Single scratch allocation; every *_buffer / *_fp8_* tensor below is a
# view into it, so those views alias each other's storage.
buffer = torch.empty(S * max_buffer_dim,
dtype=hidden_states.dtype,
device=hidden_states.device)
buffer_0 = buffer[:dim_0 * max_buffer_dim]
buffer_1 = buffer[dim_0 * max_buffer_dim:]
mlp_buffer_0 = buffer_0[:dim_0 * mlp_dim].view(dim_0, -1)
mlp_buffer_1 = buffer_1[:dim_1 * mlp_dim].view(dim_1, -1)
# The QKV view straddles the buffer_0 / buffer_1 boundary: it ends part 0's
# region and begins part 1's region so that it is one contiguous (S, qkv_dim) block.
qkv_buffer = buffer[dim_0 * max_buffer_dim -
dim_0 * qkv_dim:dim_0 * max_buffer_dim +
dim_1 * qkv_dim].view(S, -1)
# Rows of each part owned by one rank.
chunk_size_0 = dim_0 // tp_size
chunk_size_1 = chunk_size - chunk_size_0
# *_intersect_* tensors are this rank's row shard within each part.
residual_intersect_0 = residual[:chunk_size_0]
hidden_states_0 = hidden_states[:dim_0]
hidden_states_1 = hidden_states[dim_0:]
hidden_states_intersect_0 = hidden_states[rank *
chunk_size_0:(rank + 1) *
chunk_size_0]
residual_intersect_1 = residual[chunk_size_0:]
hidden_states_intersect_1 = hidden_states[dim_0 + rank *
chunk_size_1:dim_0 +
(rank + 1) *
chunk_size_1]
if self.allgather_dtype is not None:
# uint8 views of the scratch buffer used as fp8 all-gather staging.
hidden_states_fp8_0 = buffer_0[:dim_0 * fp8_dim]
hidden_states_fp8_0 = torch.ops.OptimusFp8.as_type(
hidden_states_fp8_0,
torch.uint8).reshape(dim_0, self.config.hidden_size)
hidden_states_fp8_1 = buffer_1[:dim_1 * fp8_dim]
hidden_states_fp8_1 = torch.ops.OptimusFp8.as_type(
hidden_states_fp8_1,
torch.uint8).reshape(dim_1, self.config.hidden_size)
hidden_states_fp8_intersect_0 = hidden_states_fp8_0[
rank * chunk_size_0:(rank + 1) * chunk_size_0]
hidden_states_fp8_intersect_1 = hidden_states_fp8_1[
rank * chunk_size_1:(rank + 1) * chunk_size_1]
# Side stream on which all part-1 work is issued.
# NOTE(review): the only stream synchronization visible in this function is
# torch.cuda.current_stream().wait_stream(s1); nothing makes s1 wait for the
# current stream. Confirm that s1 work cannot start before the current-stream
# results it reads (e.g. attn_output below) are ready.
s1 = torch.cuda.Stream(device=residual.device)
for i in range(len(self.layers)):
layer = self.layers[i]
# Attention Forward
# Part 0, first layer only: later layers get their part-0 QKV from the
# tail of the previous iteration (see the end of this loop body).
if i == 0:
residual_intersect_0.copy_(hidden_states_intersect_0)
layer.input_layernorm(hidden_states[:dim_0],
output=hidden_states[:dim_0])
layer.self_attn.qkv_proj(hidden_states[:dim_0],
output=qkv_buffer[:dim_0])
# Part 1: input layernorm (+ all-gather of the normalized shard for i > 0)
# followed by the QKV projection.
with torch.cuda.stream(s1):
if i == 0:
residual_intersect_1.copy_(hidden_states_intersect_1)
layer.input_layernorm(hidden_states[dim_0:],
output=hidden_states[dim_0:])
else:
if self.allgather_dtype is not None:
if self.allgather_dtype == "static_fp8e4m3":
# NOTE(review): device is hard-coded to "cuda" rather than taken
# from hidden_states.device -- confirm this is intended.
qkv_input_scale_1 = torch.full(
[1],
layer.self_attn.qkv_proj.input_scales,
device="cuda",
dtype=torch.float32)
torch.ops.OptimusFp8.rms_norm_quantize_infer(
residual_intersect_1,
layer.input_layernorm.weight,
qkv_input_scale_1,
out=hidden_states_fp8_intersect_1)
elif self.allgather_dtype == "dynamic_fp8e4m3":
layer.input_layernorm(
residual_intersect_1,
output=hidden_states_intersect_1)
qkv_input_scale_1 = dynamic_fp8_pertensor_quantize(
hidden_states_intersect_1)
# All ranks agree on one scale by taking the minimum.
torch.distributed.all_reduce(
qkv_input_scale_1,
torch.distributed.ReduceOp.MIN,
group=get_tensor_model_parallel_group(
).device_group)
hidden_states_fp8_intersect_1 = torch.ops.OptimusFp8.as_type(
hidden_states_fp8_intersect_1,
torch.float8_e4m3fn)
torch.ops.OptimusFp8.quantize(
hidden_states_intersect_1,
qkv_input_scale_1,
out=hidden_states_fp8_intersect_1)
hidden_states_fp8_intersect_1 = torch.ops.OptimusFp8.as_type(
hidden_states_fp8_intersect_1, torch.uint8)
else:
raise ValueError(
f"Unsupported allgather_dtype: {self.allgather_dtype}"
)
# Gather the fp8 shards (as uint8) and dequantize into hidden_states[dim_0:].
torch.distributed.all_gather_into_tensor(
hidden_states_fp8_1,
hidden_states_fp8_intersect_1,
group=get_tensor_model_parallel_group(
).device_group)
hidden_states_fp8_1 = torch.ops.OptimusFp8.as_type(
hidden_states_fp8_1, torch.float8_e4m3fn)
torch.ops.OptimusFp8.dequantize(
hidden_states_fp8_1,
qkv_input_scale_1.reciprocal(),
torch.bfloat16,
out=hidden_states[dim_0:])
hidden_states_fp8_1 = torch.ops.OptimusFp8.as_type(
hidden_states_fp8_1, torch.uint8)
else:
layer.input_layernorm(
residual_intersect_1,
output=hidden_states_intersect_1)
torch.distributed.all_gather_into_tensor(
hidden_states[dim_0:],
hidden_states_intersect_1,
group=get_tensor_model_parallel_group(
).device_group)
layer.self_attn.qkv_proj(hidden_states[dim_0:],
output=qkv_buffer[dim_0:])
# Attention needs QKV of both parts, so wait for the part-1 projection.
torch.cuda.current_stream().wait_stream(s1)
q, k, v = qkv_buffer.split([
layer.self_attn.q_size, layer.self_attn.kv_size,
layer.self_attn.kv_size
],
dim=-1)
# Padding rows are excluded from attention and re-appended as zeros.
if pad_len > 0:
attn_output = layer.self_attn.attn(q[:S-pad_len], k[:S-pad_len], v[:S-pad_len])
attn_output = torch.cat([attn_output, torch.zeros(pad_len, attn_output.shape[1], dtype=attn_output.dtype, device=attn_output.device)], dim=0)
else:
attn_output = layer.self_attn.attn(q, k, v)
# Part 0: output projection, residual added on this rank's shard only, then
# reduce-scatter so residual_intersect_0 receives this rank's reduced rows.
layer.self_attn.o_proj(attn_output[:dim_0],
output=hidden_states[:dim_0],
disable_allreduce=True)
hidden_states_intersect_0.add_(residual_intersect_0)
torch.distributed.reduce_scatter_tensor(
residual_intersect_0,
hidden_states[:dim_0],
group=get_tensor_model_parallel_group().device_group)
# Part 1: output projection and residual add.
with torch.cuda.stream(s1):
layer.self_attn.o_proj(attn_output[dim_0:],
output=hidden_states[dim_0:],
disable_allreduce=True)
hidden_states_intersect_1.add_(residual_intersect_1)
del attn_output
# Part 0: post-attention layernorm on the shard, then all-gather
# (fp8-staged when allgather_dtype is set).
if self.allgather_dtype is not None:
if self.allgather_dtype == "static_fp8e4m3":
gate_up_input_scale_0 = torch.full(
[1],
layer.mlp.gate_up_proj.input_scales,
device="cuda",
dtype=torch.float32)
torch.ops.OptimusFp8.rms_norm_quantize_infer(
residual_intersect_0,
layer.post_attention_layernorm.weight,
gate_up_input_scale_0,
out=hidden_states_fp8_intersect_0)
elif self.allgather_dtype == "dynamic_fp8e4m3":
layer.post_attention_layernorm(
residual_intersect_0,
output=hidden_states_intersect_0)
gate_up_input_scale_0 = dynamic_fp8_pertensor_quantize(
hidden_states_intersect_0)
torch.distributed.all_reduce(
gate_up_input_scale_0,
torch.distributed.ReduceOp.MIN,
group=get_tensor_model_parallel_group(
).device_group)
hidden_states_fp8_intersect_0 = torch.ops.OptimusFp8.as_type(
hidden_states_fp8_intersect_0, torch.float8_e4m3fn)
torch.ops.OptimusFp8.quantize(
hidden_states_intersect_0,
gate_up_input_scale_0,
out=hidden_states_fp8_intersect_0)
hidden_states_fp8_intersect_0 = torch.ops.OptimusFp8.as_type(
hidden_states_fp8_intersect_0, torch.uint8)
else:
raise ValueError(
f"Unsupported allgather_dtype: {self.allgather_dtype}"
)
torch.distributed.all_gather_into_tensor(
hidden_states_fp8_0,
hidden_states_fp8_intersect_0,
group=get_tensor_model_parallel_group().device_group)
else:
layer.post_attention_layernorm(
residual_intersect_0, output=hidden_states_intersect_0)
torch.distributed.all_gather_into_tensor(
hidden_states[:dim_0],
hidden_states_intersect_0,
group=get_tensor_model_parallel_group().device_group)
# Part 1: reduce-scatter of the attention output into the residual shard.
with torch.cuda.stream(s1):
torch.distributed.reduce_scatter_tensor(
residual_intersect_1,
hidden_states[dim_0:],
group=get_tensor_model_parallel_group().device_group)
# Part 0: dequantize the gathered fp8 rows back into hidden_states[:dim_0].
if self.allgather_dtype is not None:
hidden_states_fp8_0 = torch.ops.OptimusFp8.as_type(
hidden_states_fp8_0, torch.float8_e4m3fn)
torch.ops.OptimusFp8.dequantize(
hidden_states_fp8_0,
gate_up_input_scale_0.reciprocal(),
torch.bfloat16,
out=hidden_states[:dim_0])
hidden_states_fp8_0 = torch.ops.OptimusFp8.as_type(
hidden_states_fp8_0, torch.uint8)
# Part 0: gate/up projection + activation, mlp_batch_size rows at a time,
# written into mlp_buffer_0.
num_batch_size = (dim_0 + self.mlp_batch_size -
1) // self.mlp_batch_size
for idx in range(num_batch_size):
start = idx * self.mlp_batch_size
end = min((idx + 1) * self.mlp_batch_size, dim_0)
w0_out_0, _ = layer.mlp.gate_up_proj(
hidden_states_0[start:end])
layer.mlp.act_fn(w0_out_0, output=mlp_buffer_0[start:end])
del w0_out_0
# Part 1: post-attention layernorm on the shard, then all-gather.
with torch.cuda.stream(s1):
if self.allgather_dtype is not None:
if self.allgather_dtype == "static_fp8e4m3":
gate_up_input_scale_1 = torch.full(
[1],
layer.mlp.gate_up_proj.input_scales,
device="cuda",
dtype=torch.float32)
torch.ops.OptimusFp8.rms_norm_quantize_infer(
residual_intersect_1,
layer.post_attention_layernorm.weight,
gate_up_input_scale_1,
out=hidden_states_fp8_intersect_1)
elif self.allgather_dtype == "dynamic_fp8e4m3":
layer.post_attention_layernorm(
residual_intersect_1,
output=hidden_states_intersect_1)
gate_up_input_scale_1 = dynamic_fp8_pertensor_quantize(
hidden_states_intersect_1)
torch.distributed.all_reduce(
gate_up_input_scale_1,
torch.distributed.ReduceOp.MIN,
group=get_tensor_model_parallel_group(
).device_group)
hidden_states_fp8_intersect_1 = torch.ops.OptimusFp8.as_type(
hidden_states_fp8_intersect_1,
torch.float8_e4m3fn)
torch.ops.OptimusFp8.quantize(
hidden_states_intersect_1,
gate_up_input_scale_1,
out=hidden_states_fp8_intersect_1)
hidden_states_fp8_intersect_1 = torch.ops.OptimusFp8.as_type(
hidden_states_fp8_intersect_1, torch.uint8)
else:
raise ValueError(
f"Unsupported allgather_dtype: {self.allgather_dtype}"
)
torch.distributed.all_gather_into_tensor(
hidden_states_fp8_1,
hidden_states_fp8_intersect_1,
group=get_tensor_model_parallel_group(
).device_group)
else:
layer.post_attention_layernorm(
residual_intersect_1,
output=hidden_states_intersect_1)
torch.distributed.all_gather_into_tensor(
hidden_states[dim_0:],
hidden_states_intersect_1,
group=get_tensor_model_parallel_group(
).device_group)
# Part 0: down projection, residual add on the shard, then reduce-scatter
# (or a full all-reduce after the last layer, where every rank needs all rows).
layer.mlp.down_proj(mlp_buffer_0,
output=hidden_states[:dim_0],
disable_allreduce=True)
hidden_states_intersect_0.add_(residual_intersect_0)
if i < len(self.layers) - 1:
torch.distributed.reduce_scatter_tensor(
residual_intersect_0,
hidden_states[:dim_0],
group=get_tensor_model_parallel_group().device_group)
else:
torch.distributed.all_reduce(
hidden_states[:dim_0],
group=get_tensor_model_parallel_group().device_group)
# Part 1: dequantize (fp8 path) and run gate/up projection + activation in batches.
with torch.cuda.stream(s1):
if self.allgather_dtype is not None:
hidden_states_fp8_1 = torch.ops.OptimusFp8.as_type(
hidden_states_fp8_1, torch.float8_e4m3fn)
torch.ops.OptimusFp8.dequantize(
hidden_states_fp8_1,
gate_up_input_scale_1.reciprocal(),
torch.bfloat16,
out=hidden_states[dim_0:])
hidden_states_fp8_1 = torch.ops.OptimusFp8.as_type(
hidden_states_fp8_1, torch.uint8)
num_batch_size = (dim_1 + self.mlp_batch_size -
1) // self.mlp_batch_size
for idx in range(num_batch_size):
start = idx * self.mlp_batch_size
end = min((idx + 1) * self.mlp_batch_size, dim_1)
w0_out_1, _ = layer.mlp.gate_up_proj(
hidden_states_1[start:end])
layer.mlp.act_fn(w0_out_1,
output=mlp_buffer_1[start:end])
del w0_out_1
# Part 0: start the NEXT layer's input layernorm + all-gather while part 1
# is still finishing the current layer's MLP.
if i < len(self.layers) - 1:
next_layer = self.layers[i + 1]
if self.allgather_dtype is not None:
if self.allgather_dtype == "static_fp8e4m3":
qkv_input_scale_0 = torch.full(
[1],
next_layer.self_attn.qkv_proj.input_scales,
device="cuda",
dtype=torch.float32)
torch.ops.OptimusFp8.rms_norm_quantize_infer(
residual_intersect_0,
next_layer.input_layernorm.weight,
qkv_input_scale_0,
out=hidden_states_fp8_intersect_0)
elif self.allgather_dtype == "dynamic_fp8e4m3":
next_layer.input_layernorm(
residual_intersect_0,
output=hidden_states_intersect_0)
qkv_input_scale_0 = dynamic_fp8_pertensor_quantize(
hidden_states_intersect_0)
torch.distributed.all_reduce(
qkv_input_scale_0,
torch.distributed.ReduceOp.MIN,
group=get_tensor_model_parallel_group(
).device_group)
hidden_states_fp8_intersect_0 = torch.ops.OptimusFp8.as_type(
hidden_states_fp8_intersect_0,
torch.float8_e4m3fn)
torch.ops.OptimusFp8.quantize(
hidden_states_intersect_0,
qkv_input_scale_0,
out=hidden_states_fp8_intersect_0)
hidden_states_fp8_intersect_0 = torch.ops.OptimusFp8.as_type(
hidden_states_fp8_intersect_0, torch.uint8)
else:
raise ValueError(
f"Unsupported allgather_dtype: {self.allgather_dtype}"
)
torch.distributed.all_gather_into_tensor(
hidden_states_fp8_0,
hidden_states_fp8_intersect_0,
group=get_tensor_model_parallel_group(
).device_group)
else:
next_layer.input_layernorm(
residual_intersect_0,
output=hidden_states_intersect_0)
torch.distributed.all_gather_into_tensor(
hidden_states[:dim_0],
hidden_states_intersect_0,
group=get_tensor_model_parallel_group(
).device_group)
# Part 1: down projection, residual add, then reduce-scatter / final all-reduce.
with torch.cuda.stream(s1):
layer.mlp.down_proj(mlp_buffer_1,
output=hidden_states[dim_0:],
disable_allreduce=True)
hidden_states_intersect_1.add_(residual_intersect_1)
if i < len(self.layers) - 1:
torch.distributed.reduce_scatter_tensor(
residual_intersect_1,
hidden_states[dim_0:],
group=get_tensor_model_parallel_group(
).device_group)
else:
torch.distributed.all_reduce(
hidden_states[dim_0:],
group=get_tensor_model_parallel_group(
).device_group)
# Part 0: dequantize (fp8 path) and compute the next layer's QKV projection,
# which is why the `i == 0` branch at the top of the loop runs only once.
if i < len(self.layers) - 1:
next_layer = self.layers[i + 1]
if self.allgather_dtype is not None:
hidden_states_fp8_0 = torch.ops.OptimusFp8.as_type(
hidden_states_fp8_0, torch.float8_e4m3fn)
torch.ops.OptimusFp8.dequantize(
hidden_states_fp8_0,
qkv_input_scale_0.reciprocal(),
torch.bfloat16,
out=hidden_states[:dim_0])
hidden_states_fp8_0 = torch.ops.OptimusFp8.as_type(
hidden_states_fp8_0, torch.uint8)
next_layer.self_attn.qkv_proj(hidden_states[:dim_0],
output=qkv_buffer[:dim_0])
# Join the side stream before the final norm reads both parts.
torch.cuda.current_stream().wait_stream(s1)
del buffer, residual
self.norm(hidden_states, output=hidden_states)
# Drop the zero rows appended for divisibility.
return hidden_states[:S - pad_len]
else:
# Plain path: no row splitting, layers handle their own communication.
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class Step1PretrainedModel(nn.Module, SupportsPP):
    """Base class providing checkpoint loading for the Step1 model heads."""

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this module's parameters.

        Checkpoint entries for ``q_proj``/``k_proj``/``v_proj`` and
        ``gate_proj``/``up_proj`` are routed into the fused ``qkv_proj`` and
        ``gate_up_proj`` parameters through each parameter's
        ``weight_loader``. Names reported as missing on this pipeline rank by
        ``is_pp_missing_parameter`` are skipped.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Raises:
            RuntimeError: If a parameter whose name does not contain
                ``vision_model``, ``latent_query_tokens`` or ``sam_model``
                received no weight.
            KeyError: If a checkpoint name maps to no parameter of this module.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                # Not a fused shard: load the tensor as-is.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)

        # Parameters with these substrings in their name are exempt from the
        # completeness check below.
        exempt_markers = ("vision_model", "latent_query_tokens", "sam_model")
        params_need_to_load = {
            name
            for name in params_dict
            if not any(marker in name for marker in exempt_markers)
        }
        # Test the missing set directly. Comparing the two sets with `!=`
        # also fires when only extra (exempt) parameters were loaded; the
        # difference is then empty and indexing it raised IndexError.
        missing_params = params_need_to_load - loaded_params
        if missing_params:
            # min() keeps the reported example deterministic across runs.
            param_name_example = min(missing_params)
            raise RuntimeError(
                f"Some parameters like {param_name_example} are not in the checkpoint and will falsely use random initialization"
            )

    def load_fp8_input_scales(self, input_scales_path):
        """Attach per-layer fp8 input scales read from ``input_scales_path``.

        Each entry yielded by ``fp8_input_scales_loader`` is matched by name:
        the third dot-separated field (after stripping a leading
        ``refrence_model.``) is used as the layer index, and the scalar value
        is stored as ``input_scales`` on that layer's ``qkv_proj`` or
        ``gate_up_proj``. Entries naming neither projection are ignored.
        """
        for name, loaded_weight in fp8_input_scales_loader(input_scales_path):
            # The misspelled prefix is matched literally; presumably it is the
            # spelling used in the scale files -- verify before "fixing" it.
            if name.startswith("refrence_model."):
                name = name.replace("refrence_model.", "")
            idx = int(name.split(".")[2])
            layer = self.model.layers[idx]
            if "qkv_proj" in name:
                layer.self_attn.qkv_proj.input_scales = loaded_weight[:].item()
            elif "gate_up_proj" in name:
                layer.mlp.gate_up_proj.input_scales = loaded_weight[:].item()
class Step1ForCausalLM(Step1PretrainedModel):
    """Step1 decoder with a language-modeling head.

    The LM head, logits processor and sampler are built only on the last
    pipeline-parallel rank; other ranks hold a ``PPMissingLayer`` placeholder
    for ``lm_head``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        lora_cfg = vllm_config.lora_config
        self.config = hf_config
        self.model = Step1Model(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))

        if not get_pp_group().is_last_rank:
            self.lm_head = PPMissingLayer()
        else:
            # The vocabulary grows by the LoRA extra tokens when LoRA is on.
            self.unpadded_vocab_size = hf_config.vocab_size
            if lora_cfg:
                self.unpadded_vocab_size += lora_cfg.lora_extra_vocab_size

            # We need bigger padding if using lora for kernel compatibility.
            if lora_cfg:
                vocab_padding = lora_cfg.lora_vocab_padding_size
            else:
                vocab_padding = DEFAULT_VOCAB_PADDING_SIZE
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                hf_config.hidden_size,
                org_num_embeddings=hf_config.vocab_size,
                padding_size=vocab_padding,
                prefix=maybe_prefix(prefix, "lm_head"),
            )

            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size,
                hf_config.vocab_size,
                getattr(hf_config, "logit_scale", 1.0),
                need_fp32_logits=False)
            self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the wrapped ``Step1Model`` and return its output."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to logits through the logits processor."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        return self.sampler(logits, sampling_metadata)
class Step1ForSequenceClassification(Step1PretrainedModel):
    """Step1 Transformer with a sequence classification head."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, the scoring layer and the pooler.

        The scoring layer ``self.score`` is created only on the last
        pipeline-parallel rank.

        Args:
            vllm_config: Engine configuration; ``model_config.hf_config`` must
                provide ``id2label``, ``num_labels`` and ``hidden_size``.
            prefix: Dotted module path of this model within its parent.
        """
        super().__init__()
        # Pass the arguments by keyword and nest the backbone under "model",
        # exactly as Step1ForCausalLM does; a positional call breaks if the
        # Step1Model constructor is keyword-only.
        self.model = Step1Model(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))
        config = vllm_config.model_config.hf_config
        assert len(config.id2label.keys()) == config.num_labels
        if get_pp_group().is_last_rank:
            self.score = ReplicatedLinear(config.hidden_size,
                                          config.num_labels,
                                          bias=False)
        pooler_config = vllm_config.model_config.pooler_config
        self._pooler = Pooler.from_config_with_defaults(
            pooler_config,
            pooling_type=PoolingType.ALL,
            normalize=False,
            softmax=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Return the backbone output for the given inputs."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Score ``hidden_states`` and pool the resulting logits.

        Uses ``self.score``, which exists only on the last pipeline rank.
        """
        logits, _ = self.score(hidden_states)
        ret = self._pooler(logits, pooling_metadata)
        return ret
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
"""Inference-only Jurassic model."""
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_dp_group, get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.step1 import Step1MoEMLP
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
logger = init_logger(__name__)
# Globally shared CUDA graph memory pool, similar to the implementation in model_runner.py
_graph_memory_pool: Optional[Tuple[int, int]] = None
class FusedMoEBlock(nn.Module):
    """Router gate followed by a ``FusedMoE`` expert layer.

    The expert layer is built with ``reduce_results=False``; the
    tensor-parallel all-reduce is performed in ``forward`` instead.
    """

    def __init__(self,
                 config: ModelConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()

        more_ranks_than_experts = self.tp_size > config.moe_num_experts
        if more_ranks_than_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.moe_num_experts}.")
        assert config.moe_dynamic_exp_p == 1, "Only support dynamic exp p=1"

        self.experts = FusedMoE(
            num_experts=config.moe_num_experts,
            top_k=config.moe_top_k,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_expert_weight,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
        )
        # The router itself is never quantized.
        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.moe_num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and return the input's shape.

        ``hidden_states`` may be 1-D or 2-D; it is flattened to
        ``(num_tokens, hidden_dim)`` for routing and reshaped on return.
        """
        input_shape = hidden_states.shape
        tokens = hidden_states.view(-1, input_shape[-1])

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(tokens)
        expert_out = self.experts(hidden_states=tokens,
                                  router_logits=router_logits)
        if self.tp_size > 1:
            expert_out = tensor_model_parallel_all_reduce(expert_out)
        return expert_out.view(input_shape)
class Step2MiniMLP(nn.Module):
    """Gated MLP: fused gate/up projection, SiLU-and-mul, down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        activation_supported = hidden_act == "silu"
        if not activation_supported:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()
        self.hidden_size = hidden_size
        self.prefix = prefix

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation and down projection."""
        fused_gate_up, _ = self.gate_up_proj(hidden_states)
        projected, _ = self.down_proj(self.act_fn(fused_gate_up))
        return projected
class Step2MiniAttention(nn.Module):
    """Attention block with a shared low-width query that is normalized and
    then expanded to all heads by ``wq`` before rotary embedding.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        norm_eps: float,
        rope_theta: int,
        share_q_dim: Optional[int] = None,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embedding: int = 8192,
        head_dim: int = 256,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        world_size = get_tensor_model_parallel_world_size()

        # Query heads are split evenly across tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        # KV heads are partitioned when there are at least as many as ranks,
        # otherwise they are replicated across ranks.
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < world_size:
            assert world_size % self.total_num_kv_heads == 0
        else:
            assert self.total_num_kv_heads % world_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

        self.head_dim = head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Width of the shared query; falls back to one head's width.
        self.q_size = share_q_dim or self.head_dim

        self.qkv_proj = ReplicatedLinear(
            hidden_size,
            self.q_size + 2 * self.kv_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.inter_norm = RMSNorm(self.q_size, eps=norm_eps)
        self.wq = ColumnParallelLinear(
            self.q_size,
            self.head_dim * self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wq",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embedding,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.head_dim**-0.5,
            self.num_kv_heads,
            cache_config=cache_config,
            prefix=f"{prefix}.attn",
        )
        self.prefix = prefix

    def forward(self,
                positions: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute attention for ``hidden_states`` and return the projected output."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        shared_query, key, value = fused_qkv.split(
            [self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Normalize the shared query, then expand it to the per-head queries.
        normed_query = self.inter_norm(shared_query.contiguous())
        query = self.wq(normed_query)[0]
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value)
        projected, _ = self.o_proj(context)
        return projected
class Step2MiniDecoderLayer(nn.Module):
def __init__(self,
             config: ModelConfig,
             cache_config: Optional[CacheConfig] = None,
             quant_config: Optional[QuantizationConfig] = None,
             use_fused_moe: bool = False,
             prefix: str = "") -> None:
    """Build attention, the dense MLP or MoE block, norms and graph state.

    Args:
        config: Object whose ``hf_config`` attribute carries the model
            hyper-parameters read below.
        cache_config: Forwarded to ``Step2MiniAttention``.
        quant_config: Forwarded to the attention and MLP/MoE sub-modules.
        use_fused_moe: Use ``FusedMoEBlock`` instead of ``Step1MoEMLP`` for
            MoE layers.
        prefix: Dotted module path; it must contain ``layers.<idx>`` because
            the layer index is parsed from it.
    """
    super().__init__()
    config = config.hf_config
    self.hidden_size = config.hidden_size
    rope_scaling = getattr(config, "rope_scaling", None)
    self.self_attn = Step2MiniAttention(
        hidden_size=self.hidden_size,
        num_heads=config.num_attention_heads,
        num_kv_heads=1,
        cache_config=cache_config,
        quant_config=quant_config,
        norm_eps=config.rms_norm_eps,
        max_position_embedding=config.max_position_embedding,
        head_dim=config.head_dim,
        share_q_dim=config.share_q_dim,
        rope_theta=config.rope_theta,
        rope_scaling=rope_scaling,
        prefix=f"{prefix}.self_attn")

    self.use_moe = False
    layer_idx = int(prefix.split("layers.")[1].split(".")[0])
    moe_layers_enum = getattr(config, "moe_layers_enum", None)
    if moe_layers_enum is not None:
        # Comma-separated list of layer indices that use MoE.
        moe_layers_idx = [
            int(i) for i in moe_layers_enum.strip().split(',')
        ]
    else:
        # Default to 1dense: every layer except layer 0 is MoE.
        moe_layers_idx = list(range(1, config.num_hidden_layers))

    if layer_idx in moe_layers_idx:
        if not use_fused_moe:
            self.moe = Step1MoEMLP(
                config.moe_num_experts,
                config.moe_top_k,
                config.moe_dynamic_exp_p,
                hidden_size=self.hidden_size,
                intermediate_size=config.moe_intermediate_size,
                hidden_act="silu",
                quant_config=quant_config,
                norm_expert_weight=config.norm_expert_weight,
                prefix=f"{prefix}.moe",
                enable_cudagraph=False)  # FIXME: TODO: enable cudagraph
        else:
            self.moe = FusedMoEBlock(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.moe")
        self.share_expert = Step2MiniMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.share_expert_dim,
            hidden_act="silu",
            quant_config=quant_config,
            prefix=f"{prefix}.share_expert")
        self.use_moe = True
    else:
        self.mlp = Step2MiniMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act="silu",
            quant_config=quant_config,
            prefix=f"{prefix}.mlp")
    self.use_fused_moe = use_fused_moe
    self.input_layernorm = RMSNorm(config.hidden_size,
                                   eps=config.rms_norm_eps)
    self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
    self.prefix = prefix

    # CUDA graph parameters - simplified version that uses the shared
    # memory pool. Graphs are captured only with data parallelism on a
    # CUDA-like platform.
    self.should_capture_graph = get_dp_group().world_size > 1 and current_platform.is_cuda_alike()
    self.cuda_graphs_captured = False
    # Per token-count entries: (graph, positions, hidden_states, q, k, v).
    self.graph_runners_fwd1: dict[int, Tuple[torch.cuda.CUDAGraph, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = {}
    # (graph, attn_output, residual, hidden_states_out, residual_out).
    self.graph_runners_fwd2: dict[int, Tuple[torch.cuda.CUDAGraph, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = {}
    # (graph, hidden_states, residual, ffn_output, router_logits): five
    # elements, matching what _capture_cuda_graph stores.
    self.graph_runners_fwd3: dict[int, Tuple[torch.cuda.CUDAGraph, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = {}
    self.max_graph_tokens = 64
    self.graph_token_step = 32
    # Token counts for which graphs are captured (multiples of the step).
    self.decoder_captured_sizes = list(range(self.graph_token_step,
                                             self.max_graph_tokens + 1,
                                             self.graph_token_step)) if self.should_capture_graph else []
@torch.inference_mode()
def _capture_cuda_graph(self, device: torch.device, hs_dtype: torch.dtype, pos_dtype: torch.dtype):
    """Capture CUDA graphs for the three forward stages, once per layer.

    For every size in ``self.decoder_captured_sizes`` (largest first) each
    stage is run once eagerly as warm-up and then captured. The graph and its
    static input/output tensors are stored in ``self.graph_runners_fwd1/2/3``
    keyed by token count. All graphs share the module-level
    ``_graph_memory_pool``, which is initialised from the first captured graph.

    Does nothing if graphs were already captured or capturing is disabled.
    """
    global _graph_memory_pool
    if self.cuda_graphs_captured or not self.should_capture_graph:
        return

    attn_width = self.self_attn.num_heads * self.self_attn.head_dim

    # Capture on a dedicated stream that first catches up with the current one.
    capture_stream = torch.cuda.Stream()
    capture_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(capture_stream):
        for num_tokens in reversed(self.decoder_captured_sizes):
            hidden_shape = (num_tokens, self.hidden_size)

            # ---- stage 1: positions + hidden states -> q, k, v ----
            stage1_graph = torch.cuda.CUDAGraph()
            pos_in = torch.ones((num_tokens,), dtype=pos_dtype, device=device)
            hs_in = torch.randn(hidden_shape, dtype=hs_dtype, device=device)
            # Eager warm-up run before capture.
            _, _, _ = self._forward_1_impl(pos_in, hs_in)
            with torch.cuda.graph(stage1_graph, pool=_graph_memory_pool, stream=capture_stream):
                q_out, k_out, v_out = self._forward_1_impl(pos_in, hs_in)
            # The first captured graph supplies the shared memory pool.
            if _graph_memory_pool is None:
                _graph_memory_pool = stage1_graph.pool()
            self.graph_runners_fwd1[num_tokens] = (
                stage1_graph, pos_in, hs_in, q_out, k_out, v_out)

            # ---- stage 2: attention output + residual ----
            stage2_graph = torch.cuda.CUDAGraph()
            attn_in = torch.randn((num_tokens, attn_width), dtype=hs_dtype, device=device)
            residual_in = torch.randn(hidden_shape, dtype=hs_dtype, device=device)
            _, _ = self._forward_2_impl(attn_in, residual_in)
            with torch.cuda.graph(stage2_graph, pool=_graph_memory_pool, stream=capture_stream):
                hs_out, residual_out = self._forward_2_impl(attn_in, residual_in)
            self.graph_runners_fwd2[num_tokens] = (
                stage2_graph, attn_in, residual_in, hs_out, residual_out)

            # ---- stage 3: hidden states + residual (fresh input buffers) ----
            stage3_graph = torch.cuda.CUDAGraph()
            hs_in_3 = torch.randn(hidden_shape, dtype=hs_dtype, device=device)
            residual_in_3 = torch.randn(hidden_shape, dtype=hs_dtype, device=device)
            _, _ = self._forward_3_impl(hs_in_3, residual_in_3)
            with torch.cuda.graph(stage3_graph, pool=_graph_memory_pool, stream=capture_stream):
                ffn_out, router_logits_out = self._forward_3_impl(hs_in_3, residual_in_3)
            self.graph_runners_fwd3[num_tokens] = (
                stage3_graph, hs_in_3, residual_in_3, ffn_out,
                router_logits_out)

    torch.cuda.current_stream().wait_stream(capture_stream)
    self.cuda_graphs_captured = True
def _ensure_cuda_graphs_captured(self, device: torch.device, hs_dtype: torch.dtype, pos_dtype: torch.dtype):
    """Run graph capture once, the first time it is needed.

    Does nothing when graphs were already captured or when graph capture
    is disabled for this layer (``should_capture_graph`` is falsy).
    """
    if self.cuda_graphs_captured or not self.should_capture_graph:
        return
    self._capture_cuda_graph(device, hs_dtype, pos_dtype)
# Separate implementation logic from graph handling
def _forward_1_impl(self, positions: torch.Tensor, hidden_states: torch.Tensor):
    """Eager pre-attention stage: produce the q, k, v fed to attention.

    Applies the input layernorm, the fused qkv projection, the extra
    norm + ``wq`` projection on q, and finally rotary embeddings on q/k.
    """
    attn = self.self_attn
    normed = self.input_layernorm(hidden_states)
    # A single fused projection is used here (rather than separate
    # q_proj / kv_proj calls); its output is split along the last dim.
    qkv, _ = attn.qkv_proj(normed)
    q, k, v = qkv.split([attn.q_size, attn.kv_size, attn.kv_size], dim=-1)
    # The split yields a view of qkv; make q contiguous before the norm.
    q = attn.inter_norm(q.contiguous())
    q = attn.wq(q)[0]
    q, k = attn.rotary_emb(positions, q, k)
    return q, k, v
def forward_1(self, positions: torch.Tensor, hidden_states: torch.Tensor):
    """Pre-attention stage, replayed from a captured graph when possible.

    The token count is rounded up to a multiple of ``graph_token_step`` to
    pick a captured runner. When one exists and the token count does not
    exceed ``max_graph_tokens``, the inputs are copied into the runner's
    input buffers, the graph is replayed, and slices of its output buffers
    are returned. Otherwise the eager implementation is used.
    """
    if self.should_capture_graph:
        self._ensure_cuda_graphs_captured(hidden_states.device, hidden_states.dtype, positions.dtype)
        # Bucket key: token count rounded up to the capture granularity.
        bucket = (hidden_states.shape[0] + self.graph_token_step - 1) // self.graph_token_step * self.graph_token_step
        runner = self.graph_runners_fwd1.get(bucket) if self.cuda_graphs_captured else None
        if runner is not None and hidden_states.shape[0] <= self.max_graph_tokens:
            graph, pos_buf, hs_buf, q_buf, k_buf, v_buf = runner
            num_tokens = hidden_states.shape[0]
            # Only the first num_tokens rows of the buffers carry real data.
            pos_buf[:num_tokens].copy_(positions)
            hs_buf[:num_tokens].copy_(hidden_states)
            graph.replay()
            # NOTE(review): these are slices of the runner's output buffers,
            # presumably overwritten by the next replay -- confirm callers
            # consume them before then.
            return q_buf[:num_tokens], k_buf[:num_tokens], v_buf[:num_tokens]
    # Fallback to eager execution
    return self._forward_1_impl(positions, hidden_states)
# Separate implementation logic from graph handling
def _forward_2_impl(self, attn_output: torch.Tensor, residual: torch.Tensor):
    """Eager post-attention stage: output projection, residual add, norm.

    Returns ``(normed_hidden_states, new_residual)`` where the new residual
    is the projected attention output with the incoming residual added.
    """
    summed, _ = self.self_attn.o_proj(attn_output)
    # In-place add, as in the original: the projection output is reused as
    # the residual stream.
    summed += residual
    return self.post_attention_layernorm(summed), summed
def forward_2(self, attn_output: torch.Tensor, residual: torch.Tensor):
    """Post-attention stage, replayed from a captured graph when possible.

    Picks the runner whose key is the token count rounded up to a multiple
    of ``graph_token_step``. Falls back to the eager implementation when
    capture is disabled, no runner matches, or the token count exceeds
    ``max_graph_tokens``.
    """
    if self.should_capture_graph:
        bucket = (attn_output.shape[0] + self.graph_token_step - 1) // self.graph_token_step * self.graph_token_step
        runner = self.graph_runners_fwd2.get(bucket) if self.cuda_graphs_captured else None
        if runner is not None and attn_output.shape[0] <= self.max_graph_tokens:
            graph, attn_buf, residual_buf, hs_out_buf, residual_out_buf = runner
            num_tokens = attn_output.shape[0]
            # Stage the live inputs into the runner's input buffers.
            attn_buf[:num_tokens].copy_(attn_output)
            residual_buf[:num_tokens].copy_(residual)
            graph.replay()
            return hs_out_buf[:num_tokens], residual_out_buf[:num_tokens]
    # Fallback to eager execution
    return self._forward_2_impl(attn_output, residual)
# Separate implementation logic from graph handling
def _forward_3_impl(self, hidden_states: torch.Tensor, residual: torch.Tensor):
    """Eager FFN stage, excluding the routed-experts contribution.

    Returns ``(ffn_output + residual, router_logits)``. For MoE layers the
    FFN output is the shared expert's output and the router logits come
    from ``moe.gate``; for dense layers the logits are ``None``.
    """
    if not self.use_moe:
        return self.mlp(hidden_states) + residual, None
    shared_out = self.share_expert(hidden_states)
    gate_logits, _ = self.moe.gate(hidden_states)
    # Base output only; the caller adds the routed-experts output.
    return shared_out + residual, gate_logits
def forward_3(self, hidden_states: torch.Tensor, residual: torch.Tensor):
    """FFN stage, replayed from a captured graph when possible.

    Returns ``(ffn_output_plus_residual, router_logits)``; the logits are
    ``None`` when the selected runner recorded no router logits. Falls
    back to the eager implementation when no captured runner applies.
    """
    if self.should_capture_graph:
        bucket = (hidden_states.shape[0] + self.graph_token_step - 1) // self.graph_token_step * self.graph_token_step
        runner = self.graph_runners_fwd3.get(bucket) if self.cuda_graphs_captured else None
        if runner is not None and hidden_states.shape[0] <= self.max_graph_tokens:
            graph, hs_buf, residual_buf, ffn_out_buf, logits_buf = runner
            num_tokens = hidden_states.shape[0]
            hs_buf[:num_tokens].copy_(hidden_states)
            residual_buf[:num_tokens].copy_(residual)
            graph.replay()
            if logits_buf is None:
                logits = None
            else:
                logits = logits_buf[:num_tokens]
            return ffn_out_buf[:num_tokens], logits
    # Fallback to eager execution
    return self._forward_3_impl(hidden_states, residual)
def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
    """Run the decoder layer.

    With graph capture disabled this simply delegates to ``forward_old``.
    Otherwise the layer runs as three stages (``forward_1``, ``forward_2``,
    ``forward_3``) with the attention call and the routed-experts call made
    directly in between / afterwards.
    """
    if not self.should_capture_graph:
        return self.forward_old(positions, hidden_states)

    residual = hidden_states
    q, k, v = self.forward_1(positions, hidden_states)
    attn_output = self.self_attn.attn(q, k, v)
    hidden_states, residual = self.forward_2(attn_output, residual)
    ffn_output_plus_residual, router_logits = self.forward_3(hidden_states, residual)
    if not self.use_moe:
        return ffn_output_plus_residual
    # Routed experts consume the post-attention normed hidden states.
    moe_output = self.moe.experts(hidden_states, router_logits)
    return ffn_output_plus_residual + moe_output
def forward_old(self, positions: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
    """Run the decoder layer in a single eager pass.

    Input norm -> self-attention -> residual add -> post-attention norm ->
    FFN (shared expert + MoE, or the dense MLP) -> residual add.
    """
    skip = hidden_states
    attn_out = self.self_attn(
        positions=positions,
        hidden_states=self.input_layernorm(hidden_states),
    )
    # In-place add, as in the original.
    attn_out += skip
    skip = attn_out
    normed = self.post_attention_layernorm(attn_out)
    if self.use_moe:
        share_output = self.share_expert(normed)
        moe_output = self.moe(normed)
        ffn_output = share_output + moe_output
    else:
        ffn_output = self.mlp(normed)
    return ffn_output + skip
class Step2MiniModel(nn.Module):
    """Decoder stack: token embedding, decoder layers and final RMSNorm.

    Pieces that do not belong to the current pipeline-parallel rank are
    replaced by ``PPMissingLayer`` placeholders.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "", use_fused_moe: bool = False) -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size
        self.use_fused_moe = use_fused_moe

        # The embedding lives on the first rank, and also on the last rank
        # when the word embeddings are tied.
        needs_embedding = get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank)
        if needs_embedding:
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        def _make_decoder_layer(prefix: str):
            # The parameter keeps the name ``prefix`` used by the original
            # lambda, in case make_layers passes it by keyword.
            return Step2MiniDecoderLayer(config=vllm_config.model_config,
                                         cache_config=cache_config,
                                         quant_config=quant_config,
                                         use_fused_moe=self.use_fused_moe,
                                         prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            _make_decoder_layer,
            prefix=f"{prefix}.layers",
        )

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids with the model's token embedding."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run this rank's slice of decoder layers.

        The first rank starts from ``inputs_embeds`` (if given) or embedded
        ``input_ids``; later ranks start from ``intermediate_tensors``.
        Non-final ranks return ``IntermediateTensors``; the last rank
        returns the normed hidden states.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        elif inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)

        for layer_idx in range(self.start_layer, self.end_layer):
            hidden_states = self.layers[layer_idx](positions, hidden_states)

        if get_pp_group().is_last_rank:
            return self.norm(hidden_states)
        return IntermediateTensors({
            "hidden_states": hidden_states,
        })
@support_torch_compile
class Step3FlashModelFusedMoE(Step2MiniModel):
    """``Step2MiniModel`` variant that always enables ``use_fused_moe``."""

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         use_fused_moe=True)
class Step2MiniPretrainedModel(nn.Module, SupportsPP):
    """Base class that implements checkpoint loading for the Step2-mini models.

    Subclasses are expected to set ``self.config``, ``self.vllm_config`` and
    ``self.model`` (the latter exposing ``use_fused_moe``), since
    ``load_weights`` reads them.
    """

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this module's parameters.

        Each checkpoint tensor is routed by the first matching rule:

        1. ``gate_proj`` / ``up_proj`` -> shard of the stacked ``gate_up_proj``
           (skipped for names that belong to the fused-MoE mapping);
        2. fused-MoE expert tensors -> per-expert load into ``w13`` / ``w2``;
        3. ``q_proj`` / ``k_proj`` / ``v_proj`` -> slice of the fused
           ``qkv_proj`` parameter;
        4. anything else -> the parameter's own ``weight_loader`` (or
           ``default_weight_loader``).

        Raises:
            RuntimeError: if any parameter of this module was not loaded.
        """
        # The fused qkv_proj output dim is laid out as [q | k | v] with sizes
        # proportional to share_q_dim : head_dim : head_dim. Boundaries are
        # kept as integer numerators over `qkv_total` so the slice bounds are
        # computed exactly; the previous float fractions could truncate one
        # below the true boundary after rounding error.
        q_dim = self.config.share_q_dim
        kv_dim = self.config.head_dim
        qkv_total = q_dim + kv_dim * 2
        qkv_params_mapping = [
            # (param_name, shard_name, start_numerator, end_numerator)
            (".qkv_proj", ".q_proj", 0, q_dim),
            (".qkv_proj", ".k_proj", q_dim, q_dim + kv_dim),
            (".qkv_proj", ".v_proj", q_dim + kv_dim, qkv_total),
        ]
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()

        if self.model.use_fused_moe:
            quant_config = self.vllm_config.quant_config
            if quant_config is not None and quant_config.get_name() == "groupwise_quant":
                expert_params_mapping = [
                    (".moe.experts.w13_weight", ".moe.gate_proj.qweight", "w1"),
                    (".moe.experts.w13_weight", ".moe.up_proj.qweight", "w3"),
                    (".moe.experts.w2_weight", ".moe.down_proj.qweight", "w2"),
                    (".moe.experts.w13_weight_scale", ".moe.gate_proj.scales", "w1"),
                    (".moe.experts.w13_weight_scale", ".moe.up_proj.scales", "w3"),
                    (".moe.experts.w2_weight_scale", ".moe.down_proj.scales", "w2"),
                ]
            else:
                expert_params_mapping = [
                    (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
                    (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
                    (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"),
                ]
        else:
            expert_params_mapping = []
        # Checkpoint names handled by the expert mapping must not be caught by
        # the generic gate/up stacking rule below.
        disable_moe_stacked_params = [data[1] for data in expert_params_mapping]

        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if any(disable_moe_stacked_param in name
                       for disable_moe_stacked_param in disable_moe_stacked_params):
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    # The checkpoint tensor is indexed by expert along dim 0.
                    for expert_id in range(loaded_weight.shape[0]):
                        loaded_weight_expert = loaded_weight[expert_id]
                        weight_loader(param,
                                      loaded_weight_expert,
                                      name,
                                      shard_id=shard_id,
                                      expert_id=expert_id)
                    loaded_params.add(name)
                    break
                else:
                    for (param_name, weight_name, start_num, end_num) in qkv_params_mapping:
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        dim = param.shape[param.output_dim]
                        # Exact integer slice bounds along the output dim.
                        begin_idx = dim * start_num // qkv_total
                        stop_idx = dim * end_num // qkv_total
                        param_slice = param.narrow(param.output_dim, begin_idx,
                                                   stop_idx - begin_idx)
                        param_slice.copy_(loaded_weight)
                        loaded_params.add(name)
                        break
                    else:
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        weight_loader = getattr(param, "weight_loader",
                                                default_weight_loader)
                        weight_loader(param, loaded_weight)
                        loaded_params.add(name)

        # Every local parameter must have been populated from the checkpoint.
        missing_params = set(params_dict) - loaded_params
        if missing_params:
            param_name_example = sorted(missing_params)[0]
            raise RuntimeError(
                f"Some parameters like {param_name_example} are not in the checkpoint and will falsely use random initialization"
            )
class Step2MiniForCausalLM(Step2MiniPretrainedModel):
    """Causal-LM wrapper: decoder stack plus LM head, logits and sampling."""

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.vllm_config = vllm_config

        # FIXME: hack for step3 flash model
        # A 42-layer config selects the plain model; any other layer count
        # selects the fused-MoE variant.
        if self.config.num_hidden_layers == 42:
            model_cls = Step2MiniModel
        else:
            model_cls = Step3FlashModelFusedMoE
        self.model = model_cls(vllm_config=vllm_config, prefix=prefix)

        if not get_pp_group().is_last_rank:
            self.lm_head = PPMissingLayer()
        else:
            unpadded_vocab_size = config.vocab_size
            if lora_config:
                unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.unpadded_vocab_size = unpadded_vocab_size

            if lora_config:
                padding_size = lora_config.lora_vocab_padding_size
            else:
                padding_size = DEFAULT_VOCAB_PADDING_SIZE
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=padding_size,
            )
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    need_fp32_logits=False)
            self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None):
        """Return whatever the wrapped decoder stack returns for these inputs."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to logits through the LM head."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the configured sampler."""
        return self.sampler(logits, sampling_metadata)
class Step2MiniForSequenceClassification(Step2MiniPretrainedModel):
    """Decoder stack with a linear scoring head and a pooler."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.model = Step2MiniModel(vllm_config, prefix=prefix)

        if not get_pp_group().is_last_rank:
            # Note: `self.score` is only created on the last rank.
            self._pooler = PPMissingLayer()
        else:
            self.score = ReplicatedLinear(self.config.hidden_size,
                                          self.config.num_labels,
                                          bias=False)
            self._pooler = Pooler.from_config_with_defaults(
                vllm_config.model_config.pooler_config,
                pooling_type=PoolingType.ALL,
                normalize=False,
                softmax=False)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> SamplerOutput:
        """Return the wrapped decoder stack's output for these inputs."""
        # NOTE(review): the declared return type is SamplerOutput, but the
        # value returned is the decoder stack's output -- confirm and fix
        # the annotation.
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Score the hidden states, then pool the resulting logits."""
        scores, _ = self.score(hidden_states)
        return self._pooler(scores, pooling_metadata)

    def sequence_flops(self, input_length, context_length):
        """Add the scoring head's cost (scaled by 1e-12) to the base estimate."""
        # NOTE(review): relies on a base class providing `sequence_flops`;
        # none is visible in this file section -- confirm it exists.
        head_flops = self.config.hidden_size * self.config.num_labels * 2.0 / 1e12
        base_flops = super().sequence_flops(input_length, context_length)
        return base_flops + head_flops
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
import math
from typing import Iterable, Optional, Tuple
import torch
import torchvision
#from optimus import flash_attn_func
from torch import nn
from torch.nn import functional as F
from torchvision.transforms.functional import InterpolationMode
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import OptimusLayerNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.transformers_utils.configs import CLIPVisionConfig
def get_abs_pos(abs_pos, tgt_size):
    """Resize a position-embedding table to a different patch-grid size.

    ``abs_pos`` is treated as one leading class-token row followed by a
    square grid of patch rows (after squeezing dim 0). ``tgt_size`` is the
    target number of patch positions; its integer square root is used as
    the target grid side. When the grid side is unchanged, ``abs_pos`` is
    returned as-is; otherwise the patch rows are bicubically interpolated
    (in float32, with antialiasing) and the class-token row is kept.
    """
    embed_dim = abs_pos.size(-1)
    table = abs_pos.squeeze(0)
    src_side = int(math.sqrt(table.shape[0] - 1))
    tgt_side = int(math.sqrt(tgt_size))
    if src_side == tgt_side:
        return abs_pos

    orig_dtype = abs_pos.dtype
    cls_row = table[:1]
    # (1, dim, side, side) layout for F.interpolate; computed in float32
    # and cast back to the input dtype afterwards.
    grid = table[1:].view(1, src_side, src_side, embed_dim)
    grid = grid.permute(0, 3, 1, 2).contiguous().to(torch.float32)
    resized = F.interpolate(
        grid,
        size=(tgt_side, tgt_side),
        mode='bicubic',
        antialias=True,
        align_corners=False,
    ).to(orig_dtype)
    resized = resized.permute(0, 2, 3, 1).view(tgt_side * tgt_side, embed_dim)
    merged = torch.cat([cls_row, resized], dim=0)
    return merged.view(1, tgt_side * tgt_side + 1, embed_dim)
class StepCLIPVisionEmbeddings(nn.Module):
    """Patch + class-token embeddings with learned position embeddings.

    The output sequence is prefixed with ``pad_tp_size - 1`` extra copies
    of the first (class) token.
    """

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(1, self.embed_dim))
        # Non-overlapping patches: kernel size and stride both equal the
        # patch size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=True,
        )

        self.num_patches = (self.image_size // self.patch_size)**2
        self.pad_tp_size = 4  # hard code for padding

        # To load the pretrained weights, we still use P+1 as the seqlen
        num_positions = self.num_patches + 1
        self.position_embedding = torch.nn.Embedding(num_positions,
                                                     self.embed_dim)
        self.register_buffer("position_ids",
                             torch.arange(num_positions).expand((1, -1)),
                             persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed ``pixel_values`` into a padded token sequence."""
        batch_size = pixel_values.shape[0]
        # Conv output is [*, width, grid, grid]; flatten the grid and move
        # it in front of the channel dim.
        patch_embeds = self.patch_embedding(pixel_values)
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        tokens = torch.cat([class_embeds, patch_embeds], dim=1)
        # Position table is resized to the actual number of patches.
        pos_embed = get_abs_pos(self.position_embedding(self.position_ids),
                                patch_embeds.size(1))
        tokens = tokens + pos_embed

        # pad: repeat the first token so the sequence gains
        # pad_tp_size - 1 leading entries.
        pad_tokens = tokens[:, 0, :].unsqueeze(1).repeat(
            1, self.pad_tp_size - 1, 1)
        return torch.cat([pad_tokens, tokens], dim=1)
class StepCLIPAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self,
                 config,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 need_dp: bool = False):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.total_num_heads
        self.scale = self.head_dim**-0.5

        if need_dp:
            # Replicated weights: every rank holds all heads.
            self.num_heads = self.total_num_heads
            self.qkv_proj = ReplicatedLinear(
                self.embed_dim,
                self.embed_dim * 3,
                bias=True,
                quant_config=quant_config,
                prefix=prefix,
            )
            out_proj_cls = ReplicatedLinear
        else:
            # Tensor-parallel weights: heads are divided across ranks.
            tp_size = get_tensor_model_parallel_world_size()
            assert self.total_num_heads % tp_size == 0
            self.num_heads = self.total_num_heads // tp_size
            self.qkv_proj = QKVParallelLinear(self.embed_dim,
                                              self.head_dim,
                                              self.total_num_heads,
                                              bias=True,
                                              quant_config=quant_config,
                                              prefix=prefix)
            out_proj_cls = RowParallelLinear
        self.out_proj = out_proj_cls(self.embed_dim,
                                     self.embed_dim,
                                     bias=True,
                                     quant_config=quant_config,
                                     prefix=prefix)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape to (bsz, num_heads, seq_len, head_dim), contiguous."""
        per_head = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual=None,
        layernorm=None,
    ):
        """Input shape: Batch x Time x Channel"""
        if layernorm is not None:
            hidden_states = layernorm(hidden_states)
        bsz, tgt_len, _ = hidden_states.size()

        qkv, _ = self.qkv_proj(hidden_states)
        # Split into q/k/v and move the head dim in front of the time dim,
        # as expected by scaled_dot_product_attention.
        q, k, v = (
            part.view(bsz, tgt_len, self.num_heads,
                      self.head_dim).transpose(1, 2)
            for part in qkv.chunk(chunks=3, dim=-1))
        # Non-causal attention via PyTorch SDPA (the flash-attention path
        # is not used here).
        attn_output = F.scaled_dot_product_attention(q,
                                                     k,
                                                     v,
                                                     scale=self.scale,
                                                     is_causal=False)
        attn_output = attn_output.transpose(1, 2).reshape(
            bsz, tgt_len, self.num_heads * self.head_dim)

        attn_output, _ = self.out_proj(attn_output, residual=residual)
        return attn_output
class StepCLIPMLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2."""

    def __init__(self,
                 config,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 need_dp: bool = False):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        # Replicated layers when need_dp is set, otherwise a
        # column-parallel / row-parallel pair.
        if need_dp:
            fc1_cls, fc2_cls = ReplicatedLinear, ReplicatedLinear
        else:
            fc1_cls, fc2_cls = ColumnParallelLinear, RowParallelLinear
        self.fc1 = fc1_cls(config.hidden_size,
                           config.intermediate_size,
                           bias=True,
                           quant_config=quant_config,
                           prefix=prefix)
        self.fc2 = fc2_cls(config.intermediate_size,
                           config.hidden_size,
                           bias=True,
                           quant_config=quant_config,
                           prefix=prefix)

    def forward(self,
                hidden_states: torch.Tensor,
                residual=None,
                layernorm=None) -> torch.Tensor:
        """Apply the optional ``layernorm``, then fc1, activation and fc2.

        ``residual`` is forwarded to ``fc2`` as a keyword argument.
        """
        if layernorm is not None:
            hidden_states = layernorm(hidden_states)
        projected, _ = self.fc1(hidden_states)
        activated = self.activation_fn(projected)
        output, _ = self.fc2(activated, residual=residual)
        return output
class StepCLIPEncoderLayer(nn.Module):
    """Encoder layer whose layer norms are applied to the sub-block outputs.

    Computes ``h = x + LN1(attn(x))`` followed by ``out = h + LN2(mlp(h))``.
    """

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 need_dp: bool = False):
        super().__init__()
        self.embed_dim = config.hidden_size
        norm_eps = config.layer_norm_eps
        self.self_attn = StepCLIPAttention(config,
                                           quant_config,
                                           prefix=f"{prefix}.self_attn",
                                           need_dp=need_dp)
        self.layer_norm1 = OptimusLayerNorm(self.embed_dim, eps=norm_eps)
        self.mlp = StepCLIPMLP(config,
                               quant_config,
                               prefix=f"{prefix}.mlp",
                               need_dp=need_dp)
        self.layer_norm2 = OptimusLayerNorm(self.embed_dim, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.FloatTensor:
        """Run attention and MLP, each normed and added to its input."""
        attn_out = self.self_attn(hidden_states=hidden_states,
                                  residual=None,
                                  layernorm=None)
        h = hidden_states + self.layer_norm1(attn_out)
        mlp_out = self.mlp(h)
        return h + self.layer_norm2(mlp_out)
class StepCLIPEncoder(nn.Module):
    """
    Stack of `config.num_hidden_layers` encoder layers applied in sequence.
    Each layer is a [`StepCLIPEncoderLayer`].

    Args:
        config: CLIPVisionConfig
    """

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 need_dp: bool = False):
        super().__init__()
        self.config = config
        encoder_layers = [
            StepCLIPEncoderLayer(config,
                                 quant_config,
                                 prefix=f"{prefix}.layers.{layer_idx}",
                                 need_dp=need_dp)
            for layer_idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(encoder_layers)

    def forward(
        self,
        inputs_embeds,
    ):
        """Feed ``inputs_embeds`` through every layer and return the result."""
        output = inputs_embeds
        for layer in self.layers:
            output = layer(output)
        return output
class StepCLIPVisionTransformer(nn.Module):
    """Vision tower body: embeddings followed by the encoder stack."""

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 need_dp: bool = False):
        super().__init__()
        self.config = config
        self.image_size = config.image_size
        # Bicubic, antialiased resize to (image_size, image_size). It is
        # constructed here but not applied in forward().
        target_hw = (self.image_size, self.image_size)
        self.vision_model_preprocessor = torchvision.transforms.Resize(
            target_hw,
            interpolation=InterpolationMode.BICUBIC,
            antialias=True)
        self.embeddings = StepCLIPVisionEmbeddings(config)
        self.transformer = StepCLIPEncoder(config,
                                           quant_config,
                                           prefix=f"{prefix}.transformer",
                                           need_dp=need_dp)

    def forward(
        self,
        pixel_values: torch.Tensor,
    ):
        """Return ``(encoder_output, None)`` for ``pixel_values``."""
        embedded = self.embeddings(pixel_values)
        encoded = self.transformer(inputs_embeds=embedded)
        return encoded, None
class StepCLIPVisionModel(nn.Module):
    """Wrapper around the vision transformer with checkpoint loading."""

    # Only checkpoint names containing one of these substrings are loaded.
    _PARAMS_KEYS_TO_SELECT = ["vision_model"]

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 need_dp: bool = False):
        super().__init__()
        quant_config = None  # FIXME(ys): step encoder does not support quantization
        self.vision_model = StepCLIPVisionTransformer(
            config,
            quant_config,
            prefix=f"{prefix}.vision_model",
            need_dp=need_dp)

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding module of the vision tower."""
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
    ):
        """Run the vision transformer on ``pixel_values``."""
        return self.vision_model(pixel_values=pixel_values)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load the selected vision weights from ``weights``.

        Names matching the ignore list, or matching none of
        ``_PARAMS_KEYS_TO_SELECT``, are skipped. A leading
        ``model.vision_tower.`` prefix is stripped, and separate q/k/v
        projections are loaded as shards of ``qkv_proj``.

        Raises:
            RuntimeError: if any parameter of this module was not loaded.
        """
        ignored_substrings = [
            "text_model", "logit_scale",
            "vision_model.embeddings.position_ids", "visual_projection.weight",
            "text_projection.weight"
        ]
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()

        for name, loaded_weight in weights:
            if any(key in name for key in ignored_substrings):
                continue
            if not any(key in name for key in self._PARAMS_KEYS_TO_SELECT):
                continue

            if name.startswith("model.vision_tower.vision_tower"):
                name = name.replace("model.vision_tower.vision_tower.", "")
            elif name.startswith("model.vision_tower"):
                name = name.replace("model.vision_tower.", "")

            # A stacked rule applies only when the shard name is a whole
            # dotted component of the name; the first such rule wins.
            components = name.split(".")
            stacked = next(
                (rule for rule in stacked_params_mapping
                 if rule[1] in components), None)
            if stacked is None:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                param_name, weight_name, shard_id = stacked
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)

        params_need_to_load = set(params_dict.keys())
        if params_need_to_load != loaded_params:
            param_name_example = list(params_need_to_load - loaded_params)[0]
            raise RuntimeError(
                f"Some parameters like {param_name_example} are not in the checkpoint and will falsely use random initialization"
            )
class StepCLIPVisionModelWithPostprocess(StepCLIPVisionModel):
    """Vision model followed by two strided-convolution downsamplers."""

    # Also load the downsampler weights ("vit_downsampler" matches
    # "vit_downsampler2" as well, since selection is by substring).
    _PARAMS_KEYS_TO_SELECT = ["vision_model", "vit_downsampler"]

    def __init__(self, config: CLIPVisionConfig, need_dp: bool = True):
        super().__init__(config, need_dp=need_dp)
        self.config = config
        self.vit_downsampler = nn.Conv2d(self.config.hidden_size,
                                         self.config.output_hidden_size,
                                         kernel_size=2,
                                         stride=2)
        self.vit_downsampler2 = nn.Conv2d(
            self.config.output_hidden_size,
            self.config.output_hidden_size * 2,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x: torch.Tensor):
        """Encode ``x``, drop the non-patch tokens and downsample the grid.

        Returns a tensor of shape ``(B, tokens, output_hidden_size * 2)``.
        """
        # The embeddings prepend pad_tp_size non-patch tokens (the class
        # token plus pad_tp_size - 1 copies of it). Read the count from the
        # embeddings module instead of hard-coding 4 so both stay in sync.
        num_prefix_tokens = self.vision_model.embeddings.pad_tp_size
        x = super().forward(x)[0][:, num_prefix_tokens:]
        B, P = x.shape[:2]
        # Patch tokens are laid back out as a square HW x HW grid.
        HW = int(math.sqrt(P))
        x = x.permute(0, 2, 1).view(B, self.config.hidden_size, HW, HW)
        x = self.vit_downsampler(x)
        x = self.vit_downsampler2(x)
        # Flatten the spatial dims and return to (B, tokens, channels).
        x = x.view(B, self.config.output_hidden_size * 2, -1).permute(0, 2, 1)
        return x
\ No newline at end of file
......@@ -41,6 +41,15 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
OvisConfig, RWConfig,
Step3TextConfig, Step3VLConfig,
SkyworkR1VChatConfig, SolarConfig,
MMGPTStep1Config,
MMGPTStep1ConfigV2, MPTConfig,
NemotronConfig, NVLM_D_Config,
RWConfig, SkyworkR1VChatConfig,
SolarConfig, Step1AudioConfig,
Step1Config, Step1oConfig,
Step2Config, Step2MiniConfig,
Step3vConfig,
StepAudioQwen2Config,
Telechat2Config, UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.utils import check_gguf_file
......@@ -75,6 +84,20 @@ _CONFIG_REGISTRY_OVERRIDE_HF: dict[str, type[PretrainedConfig]] = {
"mllama": MllamaConfig
}
# Maps Step-family `model_type` strings to their config classes. The
# surrounding diff merges this dict into `_CONFIG_REGISTRY` via
# `**_CUSTOM_CONFIG_STEP`. The two mmgpt_qwen2 entries are disabled.
_CUSTOM_CONFIG_STEP = {
"step1": Step1Config,
"step2": Step2Config,
"step2_mini": Step2MiniConfig,
"mmgpt_step1": MMGPTStep1Config,
"mmgpt_step1_v2": MMGPTStep1ConfigV2,
#"mmgpt_qwen2": MMGPTQwen2Config,
#"mmgpt_qwen2_v2": MMGPTQwen2ConfigV2,
"step1o": Step1oConfig,
"step1_audio": Step1AudioConfig,
"step_audio_qwen2": StepAudioQwen2Config,
"step3v": Step3vConfig,
}
_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
"chatglm": ChatGLMConfig,
"cohere2": Cohere2Config,
......@@ -100,7 +123,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
"ultravox": UltravoxConfig,
"step3_vl": Step3VLConfig,
"step3_text": Step3TextConfig,
**_CONFIG_REGISTRY_OVERRIDE_HF
**_CONFIG_REGISTRY_OVERRIDE_HF,
**_CUSTOM_CONFIG_STEP
}
_CONFIG_ATTRS_MAPPING: dict[str, str] = {
......
......@@ -32,6 +32,18 @@ from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig,
Step3VisionEncoderConfig,
Step3VLConfig)
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.transformers_utils.configs.mmgpt import (CLIPVisionConfig,
MMGPTQwen2Config,
MMGPTQwen2ConfigV2,
MMGPTStep1Config,
MMGPTStep1ConfigV2,
SamViTConfig, Step1oConfig,
Step3vConfig)
from vllm.transformers_utils.configs.step import (Step1Config, Step2Config,
Step2MiniConfig)
from vllm.transformers_utils.configs.step1f import (Step1AudioConfig,
Step1fAudioEncoderConfig,
StepAudioQwen2Config)
__all__ = [
"ChatGLMConfig",
......@@ -62,4 +74,21 @@ __all__ = [
"Step3VLConfig",
"Step3VisionEncoderConfig",
"Step3TextConfig",
"Step1Config",
"Step2Config",
"Step2MiniConfig",
"CLIPVisionConfig",
"MMGPTBaiChuanConfig",
"MMGPTLlamaConfig",
"MMGPTLlamaConfigV2",
"MMGPTQwen2Config",
"MMGPTQwen2ConfigV2",
"MMGPTStep1Config",
"MMGPTStep1ConfigV2",
"Step3vConfig",
"SamViTConfig",
"Step1oConfig",
"Step1AudioConfig",
"Step1fAudioEncoderConfig",
"StepAudioQwen2Config",
]
# SPDX-License-Identifier: Apache-2.0
from typing import Any, List, Optional, Union
from transformers import Qwen2Config
from transformers.configuration_utils import PretrainedConfig
from vllm.transformers_utils.configs.step import Step1Config, Step2MiniConfig
class CLIPVisionConfig(PretrainedConfig):
    """Configuration for the CLIP-style vision encoder.

    Every constructor argument is stored as an attribute of the same name;
    extra keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "clip_vision_model"

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        projection_dim=512,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        image_size=224,
        patch_size=32,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Transformer dimensions.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Input image geometry.
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size

        # Normalisation, activation and initialisation settings.
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
class SamViTConfig(PretrainedConfig):
    """Configuration for the SAM-style ViT encoder.

    Every constructor argument is stored as an attribute of the same name;
    extra keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "sam_vit_model"

    def __init__(
        self,
        depth=24,
        embed_dim=1024,
        image_size=1280,
        mlp_ratio=4,
        num_heads=16,
        patch_size=16,
        qkv_bias=True,
        use_abs_pos=True,
        use_rel_pos=True,
        global_attn_indexes=(5, 11, 17, 23),
        window_size=14,
        out_channels=256,
        layer_norm_eps=1e-6,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Backbone size and input geometry.
        self.depth = depth
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.patch_size = patch_size

        # Attention options.
        self.qkv_bias = qkv_bias
        self.use_abs_pos = use_abs_pos
        self.use_rel_pos = use_rel_pos
        self.global_attn_indexes = global_attn_indexes
        self.window_size = window_size

        # Output channels and norm epsilon.
        self.out_channels = out_channels
        self.layer_norm_eps = layer_norm_eps
class MMGPTStep1Config(Step1Config):
    """Multimodal Step1 configuration (for step1.5).

    Stores the language-model and image-related settings, builds an
    optional ``CLIPVisionConfig`` for the vision tower, and builds a
    ``Step1Config`` in ``text_config`` from the language-model settings.
    """

    # for step1.5
    model_type = "mmgpt_step1"

    def __init__(
        self,
        hidden_size: int = 5120,
        intermediate_size: int = 13312,
        num_attention_heads: int = 40,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 48,
        max_seq_len: int = 4096,
        vocab_size: int = 65536,
        rms_norm_eps: float = 1e-5,
        use_im_start_end=True,
        vision_select_layer=-2,
        image_token_len=None,
        projector_stride=1,
        vision_tower_config=None,
        image_token_id=13,
        image_seq_length=400,
        bos_token_id: Optional[Union[List[int], int]] = None,
        eos_token_id: Optional[Union[List[int], int]] = None,
        **kwargs,
    ) -> None:
        # Defaults applied when the caller does not supply token ids.
        if bos_token_id is None:
            bos_token_id = 1
        if eos_token_id is None:
            eos_token_id = [2, 3]
        super().__init__(bos_token_id=bos_token_id,
                         eos_token_id=eos_token_id,
                         **kwargs)

        # Language-model settings.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_hidden_layers = num_hidden_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps

        # Image-related settings.
        self.use_im_start_end = use_im_start_end
        self.vision_select_layer = vision_select_layer
        self.image_token_len = image_token_len
        self.projector_stride = projector_stride
        self.image_token_id = image_token_id
        self.image_seq_length = image_seq_length

        # vision_tower_config is unpacked as keyword arguments when given.
        if vision_tower_config is None:
            self.vision_tower_config = None
        else:
            self.vision_tower_config = CLIPVisionConfig(**vision_tower_config)

        # Text-only sub-config mirroring the language-model settings;
        # torch_dtype falls back to "bfloat16" when not set on this config.
        self.text_config = Step1Config(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_attention_groups=num_attention_groups,
            num_hidden_layers=num_hidden_layers,
            max_seq_len=max_seq_len,
            vocab_size=vocab_size,
            rms_norm_eps=rms_norm_eps,
            architectures=["Step1ForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )
class MMGPTStep1ConfigV2(Step1Config):
    """Multimodal Step1 configuration with both a ViT and a SAM encoder.

    The language-model hyper-parameters are stored on this config and also
    used to build ``text_config`` (a ``Step1Config``). ``vision_tower_config``
    and ``sam_model_config`` are expected to be mappings of keyword arguments
    for ``CLIPVisionConfig`` and ``SamViTConfig`` respectively; passing
    ``None`` leaves the corresponding attribute as ``None``.

    ``bos_token_id`` defaults to ``1`` and ``eos_token_id`` to ``[2, 3]`` when
    not supplied.
    """

    # for step1.5c/step1u, models with both vit and sam encoders
    model_type = "mmgpt_step1_v2"

    def __init__(
        self,
        hidden_size: int = 5120,
        intermediate_size: int = 13312,
        num_attention_heads: int = 40,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 48,
        max_seq_len: int = 4096,
        vocab_size: int = 65536,
        rms_norm_eps: float = 1e-5,
        use_im_start_end=True,
        vision_select_layer=-1,
        image_token_len=None,
        understand_projector_stride=1,
        vit_scale=1.0,
        projector_bias=True,
        vision_tower_config=None,
        sam_model_config=None,
        image_token_id=13,
        bos_token_id: Optional[Union[List[int], int]] = None,
        eos_token_id: Optional[Union[List[int], int]] = None,
        **kwargs,
    ) -> None:
        if bos_token_id is None:
            bos_token_id = 1
        if eos_token_id is None:
            eos_token_id = [2, 3]
        super().__init__(bos_token_id=bos_token_id,
                         eos_token_id=eos_token_id,
                         **kwargs)

        # Language-model hyper-parameters: mirrored on this config and
        # forwarded verbatim to the nested text config below.
        llm_params = {
            "hidden_size": hidden_size,
            "intermediate_size": intermediate_size,
            "num_attention_heads": num_attention_heads,
            "num_attention_groups": num_attention_groups,
            "num_hidden_layers": num_hidden_layers,
            "max_seq_len": max_seq_len,
            "vocab_size": vocab_size,
            "rms_norm_eps": rms_norm_eps,
        }
        for param_name, param_value in llm_params.items():
            setattr(self, param_name, param_value)

        # Image-token / projector settings.
        self.use_im_start_end = use_im_start_end
        self.vision_select_layer = vision_select_layer
        self.image_token_len = image_token_len
        self.image_token_id = image_token_id
        self.understand_projector_stride = understand_projector_stride
        self.vit_scale = vit_scale
        self.projector_bias = projector_bias

        # Optional encoder sub-configs, built from plain keyword mappings.
        if vision_tower_config is None:
            self.vision_tower_config = None
        else:
            self.vision_tower_config = CLIPVisionConfig(**vision_tower_config)
        if sam_model_config is None:
            self.sam_model_config = None
        else:
            self.sam_model_config = SamViTConfig(**sam_model_config)

        self.text_config = Step1Config(
            **llm_params,
            architectures=["Step1ForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )
class Step1oConfig(Step1Config):
    """Multimodal configuration for step1o: Step1 language model + CLIP tower.

    The language-model hyper-parameters are stored on this config and also
    used to build ``text_config`` (a ``Step1Config``). ``patch_token_len``
    falls back to ``image_token_len`` when it is not given.
    ``vision_tower_config`` is expected to be a mapping of
    ``CLIPVisionConfig`` keyword arguments, or ``None``.

    ``bos_token_id`` defaults to ``1`` and ``eos_token_id`` to ``[2, 3]`` when
    not supplied.
    """

    # for step1o
    model_type = "step1o"

    def __init__(
        self,
        hidden_size: int = 5120,
        intermediate_size: int = 13312,
        num_attention_heads: int = 40,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 48,
        max_seq_len: int = 4096,
        vocab_size: int = 65536,
        rms_norm_eps: float = 1e-5,
        use_im_start_end=True,
        vision_select_layer=-1,
        image_token_len=None,
        image_token_id=13,
        understand_projector_stride=1,
        vit_scale=1.0,
        projector_bias=True,
        patch_token_len=None,
        vision_tower_config=None,
        bos_token_id: Optional[Union[List[int], int]] = None,
        eos_token_id: Optional[Union[List[int], int]] = None,
        **kwargs,
    ) -> None:
        if bos_token_id is None:
            bos_token_id = 1
        if eos_token_id is None:
            eos_token_id = [2, 3]
        super().__init__(bos_token_id=bos_token_id,
                         eos_token_id=eos_token_id,
                         **kwargs)

        # Language-model hyper-parameters: mirrored on this config and
        # forwarded verbatim to the nested text config below.
        llm_params = {
            "hidden_size": hidden_size,
            "intermediate_size": intermediate_size,
            "num_attention_heads": num_attention_heads,
            "num_attention_groups": num_attention_groups,
            "num_hidden_layers": num_hidden_layers,
            "max_seq_len": max_seq_len,
            "vocab_size": vocab_size,
            "rms_norm_eps": rms_norm_eps,
        }
        for param_name, param_value in llm_params.items():
            setattr(self, param_name, param_value)

        # Image-token / projector settings.
        self.use_im_start_end = use_im_start_end
        self.vision_select_layer = vision_select_layer
        self.image_token_len = image_token_len
        self.image_token_id = image_token_id
        self.understand_projector_stride = understand_projector_stride
        self.vit_scale = vit_scale
        self.projector_bias = projector_bias

        # Patch tokens share the image-token length unless overridden.
        if patch_token_len is None:
            self.patch_token_len = self.image_token_len
        else:
            self.patch_token_len = patch_token_len

        if vision_tower_config is None:
            self.vision_tower_config = None
        else:
            self.vision_tower_config = CLIPVisionConfig(**vision_tower_config)

        self.text_config = Step1Config(
            **llm_params,
            architectures=["Step1ForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )
class MMGPTQwen2Config(PretrainedConfig):
    """Multimodal configuration built around a Qwen2 language model.

    The language-model hyper-parameters are stored on this config and, except
    for ``num_attention_groups``, also used to build ``text_config`` (a
    ``Qwen2Config``). ``vision_tower_config`` and ``sam_model_config`` are
    expected to be mappings of keyword arguments for ``CLIPVisionConfig`` and
    ``SamViTConfig``; ``None`` leaves the corresponding attribute as ``None``.

    The ids ``151643`` and ``151646`` are always part of ``eos_token_id``; a
    caller-supplied id or list of ids is merged with them.
    """

    # for step1.5t
    model_type = "mmgpt_qwen2"

    def __init__(
        self,
        vocab_size=64012,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=48,
        num_attention_heads=32,
        num_attention_groups=4,
        num_key_value_heads=4,
        hidden_act="silu",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        rope_theta=1000000.0,
        rope_scaling=None,
        use_im_start_end=True,
        vision_select_layer=-1,
        image_token_len=None,
        image_token_id=151656,
        understand_projector_stride=1,
        vit_scale=1.0,
        projector_bias=True,
        pad_token_id=-1,
        vision_tower_config=None,
        sam_model_config=None,
        eos_token_id=None,
        **kwargs,
    ) -> None:
        # Merge the caller's EOS id(s) with the built-in ones. Only the list
        # form is de-duplicated (through a set).
        builtin_eos = [151643, 151646]
        if eos_token_id is None:
            eos_token_id = builtin_eos
        elif isinstance(eos_token_id, list):
            eos_token_id = list(set(builtin_eos + eos_token_id))
        else:
            eos_token_id = builtin_eos + [eos_token_id]
        super().__init__(eos_token_id=eos_token_id, **kwargs)

        # Language-model hyper-parameters, in attribute-assignment order.
        llm_fields = {
            "vocab_size": vocab_size,
            "hidden_size": hidden_size,
            "intermediate_size": intermediate_size,
            "num_hidden_layers": num_hidden_layers,
            "num_attention_heads": num_attention_heads,
            "num_attention_groups": num_attention_groups,
            "num_key_value_heads": num_key_value_heads,
            "hidden_act": hidden_act,
            "max_position_embeddings": max_position_embeddings,
            "initializer_range": initializer_range,
            "rms_norm_eps": rms_norm_eps,
            "rope_theta": rope_theta,
            "rope_scaling": rope_scaling,
        }
        for field_name, field_value in llm_fields.items():
            setattr(self, field_name, field_value)

        # Image-token / projector settings.
        self.use_im_start_end = use_im_start_end
        self.vision_select_layer = vision_select_layer
        self.image_token_len = image_token_len
        self.image_token_id = image_token_id
        self.understand_projector_stride = understand_projector_stride
        self.vit_scale = vit_scale
        self.projector_bias = projector_bias
        self.pad_token_id = pad_token_id

        # Optional encoder sub-configs, built from plain keyword mappings.
        if vision_tower_config is None:
            self.vision_tower_config = None
        else:
            self.vision_tower_config = CLIPVisionConfig(**vision_tower_config)
        if sam_model_config is None:
            self.sam_model_config = None
        else:
            self.sam_model_config = SamViTConfig(**sam_model_config)

        # ``num_attention_groups`` is kept on this config only; it is not
        # forwarded to the Qwen2 text config.
        qwen2_kwargs = {
            key: value
            for key, value in llm_fields.items()
            if key != "num_attention_groups"
        }
        self.text_config = Qwen2Config(
            **qwen2_kwargs,
            architectures=["Qwen2ForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )
class MMGPTQwen2ConfigV2(MMGPTQwen2Config):
    """Second revision of the Qwen2-based multimodal configuration.

    Accepts the same arguments as ``MMGPTQwen2Config``; it differs in the
    default ``image_token_id`` (``151675``) and in the EOS ids that are always
    merged into ``eos_token_id`` (``151643``, ``151645`` and ``151665``). All
    attribute assignment and sub-config construction is delegated to the
    parent initializer.
    """

    model_type = "mmgpt_qwen2_v2"

    def __init__(
        self,
        vocab_size=64012,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=48,
        num_attention_heads=32,
        num_attention_groups=4,
        num_key_value_heads=4,
        hidden_act="silu",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        rope_theta=1000000.0,
        rope_scaling=None,
        use_im_start_end=True,
        vision_select_layer=-1,
        image_token_len=None,
        image_token_id=151675,
        understand_projector_stride=1,
        vit_scale=1.0,
        projector_bias=True,
        pad_token_id=-1,
        vision_tower_config=None,
        sam_model_config=None,
        eos_token_id=None,
        **kwargs,
    ) -> None:
        # Merge the caller's EOS id(s) with this revision's built-in ones.
        # Only the list form is de-duplicated (through a set).
        builtin_eos = [151643, 151645, 151665]
        if eos_token_id is None:
            eos_token_id = builtin_eos
        elif isinstance(eos_token_id, list):
            eos_token_id = list(set(builtin_eos + eos_token_id))
        else:
            eos_token_id = builtin_eos + [eos_token_id]

        # NOTE(review): the parent initializer merges its own built-in ids
        # (151643, 151646) into the list passed here, so 151646 also ends up
        # in this config's ``eos_token_id`` - confirm that this is intended.
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_attention_groups=num_attention_groups,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            use_im_start_end=use_im_start_end,
            vision_select_layer=vision_select_layer,
            image_token_len=image_token_len,
            image_token_id=image_token_id,
            understand_projector_stride=understand_projector_stride,
            vit_scale=vit_scale,
            projector_bias=projector_bias,
            pad_token_id=pad_token_id,
            vision_tower_config=vision_tower_config,
            sam_model_config=sam_model_config,
            eos_token_id=eos_token_id,
            **kwargs,
        )
class Step3vConfig(Step1Config):
    """Multimodal configuration for step3v: language model + CLIP vision tower.

    The language-model hyper-parameters (including the MoE and RoPE settings
    and ``moe_layers_enum``) are stored on this config and forwarded to
    ``text_config``, which is a ``Step2MiniConfig``. ``patch_token_len`` falls
    back to ``image_token_len`` when it is not given. ``vision_tower_config``
    is expected to be a mapping of ``CLIPVisionConfig`` keyword arguments, or
    ``None`` to leave the vision tower unset.

    ``bos_token_id`` defaults to ``0`` and ``eos_token_id`` to ``[1, 128805]``
    when not supplied.
    """

    model_type = "step3v"

    def __init__(
        self,
        hidden_size: int = 5120,
        intermediate_size: int = 13312,
        num_attention_heads: int = 40,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 48,
        max_seq_len: int = 4096,
        vocab_size: int = 65536,
        rms_norm_eps: float = 1e-5,
        # 2 means 50% layers use MoE, interleaved with normal non-MoE layers.
        moe_every_n_layer: int = 2,
        use_moe: bool = False,
        moe_intermediate_size: int = 10240,
        moe_num_experts: int = 16,
        moe_top_k: int = 4,
        max_pos_interp_ratio: float = 1,
        alibi_slopes: Optional[List[float]] = None,
        moe_layer_offset: int = 0,
        moe_dynamic_exp_p: float = 1.0,
        rope_theta: float = 500000,
        rope_scaling: Optional[dict[str, Any]] = None,
        head_dim: Optional[int] = None,
        max_position_embedding: int = 16384,
        share_expert_dim: Optional[int] = None,
        allgather_dtype: Optional[str] = None,
        share_q_dim: Optional[int] = None,
        norm_expert_weight: bool = True,
        moe_layers_enum: Optional[str] = None,
        use_im_start_end: bool = True,
        vision_select_layer: int = -1,
        image_token_len: Optional[int] = None,
        image_token_id: int = 128001,
        understand_projector_stride: int = 1,
        vit_scale: float = 1.0,
        projector_bias: bool = True,
        patch_token_len: Optional[int] = None,
        vision_tower_config: Optional[dict[str, Any]] = None,
        bos_token_id: Optional[Union[List[int], int]] = None,
        eos_token_id: Optional[Union[List[int], int]] = None,
        **kwargs,
    ) -> None:
        if bos_token_id is None:
            bos_token_id = 0
        if eos_token_id is None:
            eos_token_id = [1, 128805]
        super().__init__(bos_token_id=bos_token_id,
                         eos_token_id=eos_token_id,
                         **kwargs)

        # Language-model hyper-parameters: every entry is stored on this
        # config and forwarded verbatim to the nested text config, so the two
        # can never drift apart. ``moe_layers_enum`` is a named parameter and
        # therefore never reaches ``**kwargs``; it has to be stored here
        # explicitly or it would be missing from the top-level config.
        # NOTE: ``share_expert_dim`` is stored as given (possibly ``None``);
        # the ``moe_intermediate_size * moe_top_k`` fallback implemented by
        # ``StepConfig`` is applied on ``text_config`` only.
        llm_params = {
            "hidden_size": hidden_size,
            "intermediate_size": intermediate_size,
            "num_attention_heads": num_attention_heads,
            "num_attention_groups": num_attention_groups,
            "num_hidden_layers": num_hidden_layers,
            "max_seq_len": max_seq_len,
            "vocab_size": vocab_size,
            "rms_norm_eps": rms_norm_eps,
            "moe_every_n_layer": moe_every_n_layer,
            "use_moe": use_moe,
            "moe_intermediate_size": moe_intermediate_size,
            "moe_num_experts": moe_num_experts,
            "moe_top_k": moe_top_k,
            "max_pos_interp_ratio": max_pos_interp_ratio,
            "alibi_slopes": alibi_slopes,
            "moe_layer_offset": moe_layer_offset,
            "moe_dynamic_exp_p": moe_dynamic_exp_p,
            "rope_theta": rope_theta,
            "rope_scaling": rope_scaling,
            "head_dim": head_dim,
            "max_position_embedding": max_position_embedding,
            "share_expert_dim": share_expert_dim,
            "allgather_dtype": allgather_dtype,
            "share_q_dim": share_q_dim,
            "norm_expert_weight": norm_expert_weight,
            "moe_layers_enum": moe_layers_enum,
        }
        for param_name, param_value in llm_params.items():
            setattr(self, param_name, param_value)

        # Image-token / projector settings.
        self.use_im_start_end = use_im_start_end
        self.vision_select_layer = vision_select_layer
        self.image_token_len = image_token_len
        self.image_token_id = image_token_id
        self.understand_projector_stride = understand_projector_stride
        self.vit_scale = vit_scale
        self.projector_bias = projector_bias

        # Patch tokens share the image-token length unless overridden.
        if patch_token_len is None:
            self.patch_token_len = self.image_token_len
        else:
            self.patch_token_len = patch_token_len

        if vision_tower_config is None:
            self.vision_tower_config = None
        else:
            self.vision_tower_config = CLIPVisionConfig(**vision_tower_config)

        self.text_config = Step2MiniConfig(
            **llm_params,
            architectures=["Step2MiniForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional, Union
from transformers import PretrainedConfig
class StepConfig(PretrainedConfig):
    """Base configuration shared by the Step language-model configs.

    All model hyper-parameters are stored as attributes *before*
    ``PretrainedConfig.__init__`` runs. ``share_expert_dim`` defaults to
    ``moe_intermediate_size * moe_top_k`` when not given. ``bos_token_id``
    defaults to ``1`` and ``eos_token_id`` to ``[2, 3]``.

    Raises:
        ValueError: if ``alibi_slopes`` is given and its length differs from
            ``num_attention_heads``.
    """

    model_type = "step"

    def __init__(
        self,
        hidden_size: int = 5120,
        intermediate_size: int = 13312,
        num_attention_heads: int = 40,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 48,
        max_seq_len: int = 4096,
        vocab_size: int = 65536,
        rms_norm_eps: float = 1e-5,
        # 2 means 50% layers use MoE, interleaved with normal non-MoE layers.
        moe_every_n_layer: int = 2,
        use_moe: bool = False,
        moe_intermediate_size: int = 10240,
        moe_num_experts: int = 16,
        moe_top_k: int = 4,
        max_pos_interp_ratio: float = 1,
        alibi_slopes: Optional[List[float]] = None,
        moe_layer_offset: int = 0,
        moe_dynamic_exp_p: float = 1.0,
        rope_theta: float = 500000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        head_dim: Optional[int] = None,
        max_position_embedding: int = 16384,
        share_expert_dim: Optional[int] = None,
        allgather_dtype: Optional[str] = None,
        share_q_dim: Optional[int] = None,
        norm_expert_weight: bool = True,
        bos_token_id: Optional[Union[List[int], int]] = None,
        eos_token_id: Optional[Union[List[int], int]] = None,
        **kwargs,
    ) -> None:
        # Fields that are stored exactly as given, in assignment order.
        passthrough_fields = {
            "hidden_size": hidden_size,
            "intermediate_size": intermediate_size,
            "num_attention_heads": num_attention_heads,
            "num_attention_groups": num_attention_groups,
            "num_hidden_layers": num_hidden_layers,
            "max_seq_len": max_seq_len,
            "vocab_size": vocab_size,
            "rms_norm_eps": rms_norm_eps,
            "use_moe": use_moe,
            "moe_intermediate_size": moe_intermediate_size,
            "moe_every_n_layer": moe_every_n_layer,
            "moe_num_experts": moe_num_experts,
            "moe_top_k": moe_top_k,
            "max_pos_interp_ratio": max_pos_interp_ratio,
            "alibi_slopes": alibi_slopes,
            "moe_layer_offset": moe_layer_offset,
            "moe_dynamic_exp_p": moe_dynamic_exp_p,
            # for step2 mini
            "rope_theta": rope_theta,
            "rope_scaling": rope_scaling,
            "head_dim": head_dim,
            "max_position_embedding": max_position_embedding,
        }
        for field_name, field_value in passthrough_fields.items():
            setattr(self, field_name, field_value)

        # Derive the shared-expert width from the MoE settings when the
        # caller did not specify one.
        if share_expert_dim is not None:
            self.share_expert_dim = share_expert_dim
        else:
            self.share_expert_dim = (self.moe_intermediate_size *
                                     self.moe_top_k)
        self.share_q_dim = share_q_dim
        self.norm_expert_weight = norm_expert_weight
        self.allgather_dtype = allgather_dtype

        self._verify_slopes()

        if bos_token_id is None:
            bos_token_id = 1
        if eos_token_id is None:
            eos_token_id = [2, 3]
        super().__init__(bos_token_id=bos_token_id,
                         eos_token_id=eos_token_id,
                         **kwargs)

    def _verify_slopes(self):
        """Check that ``alibi_slopes``, when set, has one entry per head."""
        slopes = self.alibi_slopes
        if slopes is not None and len(slopes) != self.num_attention_heads:
            raise ValueError(
                f"Number of alibi_slopes ({len(self.alibi_slopes)}) does not match num_attention_heads ({self.num_attention_heads})"
            )
class Step1Config(StepConfig):
    """``StepConfig`` registered under the ``step1`` model type."""

    model_type = "step1"
class Step2Config(StepConfig):
    """``StepConfig`` for the ``step2`` model type.

    Adds the ``use_offline_input_scales`` flag; every other keyword argument
    is forwarded unchanged to ``StepConfig``.
    """

    model_type = "step2"

    def __init__(self, use_offline_input_scales: bool = True, **kwargs):
        # Stored first, then the base class handles the remaining arguments.
        self.use_offline_input_scales = use_offline_input_scales
        super().__init__(**kwargs)
class Step2MiniConfig(StepConfig):
    """``StepConfig`` registered under the ``step2_mini`` model type."""

    model_type = "step2_mini"
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
from transformers import Qwen2Config
from transformers.configuration_utils import PretrainedConfig
from vllm.transformers_utils.configs.step import Step1Config
class Step1fAudioEncoderConfig(PretrainedConfig):
    """Hyper-parameter container for the Step audio encoder.

    Every constructor argument is stored unchanged as an attribute of the
    same name. Any additional keyword arguments are handed to
    ``PretrainedConfig.__init__``.
    """

    model_type = "stepasr_encoder"

    def __init__(
        self,
        n_mels: int = 128,
        n_audio_ctx: int = 1500,
        n_audio_state: int = 1280,
        n_audio_head: int = 20,
        n_audio_layer: int = 32,
        n_codebook_size: int = 4096,
        llm_dim: int = 3072,
        kernel_size: int = 3,
        adapter_stride: int = 2,
        adapter_state: int = 2048,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        # Listed in assignment order so the instance attribute layout is
        # deterministic.
        encoder_fields = {
            "n_mels": n_mels,
            "n_audio_ctx": n_audio_ctx,
            "n_audio_state": n_audio_state,
            "n_audio_head": n_audio_head,
            "n_audio_layer": n_audio_layer,
            "n_codebook_size": n_codebook_size,
            "llm_dim": llm_dim,
            "kernel_size": kernel_size,
            "adapter_stride": adapter_stride,
            "adapter_state": adapter_state,
        }
        for field_name, field_value in encoder_fields.items():
            setattr(self, field_name, field_value)
class Step1AudioConfig(PretrainedConfig):
    """Audio-language configuration: Step1 language model + audio encoder.

    The language-model hyper-parameters are stored on this config and also
    used to build ``text_config`` (a ``Step1Config``).
    ``audio_encoder_config`` is expected to be a mapping of
    ``Step1fAudioEncoderConfig`` keyword arguments, or ``None``.

    The ids ``2`` and ``3`` are always part of ``eos_token_id``; a
    caller-supplied id or list of ids is merged with them.
    """

    # for step1.5t
    model_type = "step1_audio"

    def __init__(
        self,
        hidden_size: int = 5120,
        intermediate_size: int = 13312,
        num_attention_heads: int = 40,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 48,
        max_seq_len: int = 4096,
        vocab_size: int = 65536,
        rms_norm_eps: float = 1e-5,
        audio_token_id: int = 29,
        eos_token_id=None,
        audio_encoder_config=None,
        **kwargs,
    ) -> None:
        # Merge the caller's EOS id(s) with the built-in ones. Only the list
        # form is de-duplicated (through a set).
        builtin_eos = [2, 3]
        if eos_token_id is None:
            eos_token_id = builtin_eos
        elif isinstance(eos_token_id, list):
            eos_token_id = list(set(builtin_eos + eos_token_id))
        else:
            eos_token_id = builtin_eos + [eos_token_id]
        super().__init__(eos_token_id=eos_token_id, **kwargs)

        # Language-model hyper-parameters, in attribute-assignment order;
        # the same mapping is forwarded to the nested text config below.
        llm_params = {
            "vocab_size": vocab_size,
            "hidden_size": hidden_size,
            "intermediate_size": intermediate_size,
            "num_hidden_layers": num_hidden_layers,
            "num_attention_heads": num_attention_heads,
            "num_attention_groups": num_attention_groups,
            "max_seq_len": max_seq_len,
            "rms_norm_eps": rms_norm_eps,
        }
        for param_name, param_value in llm_params.items():
            setattr(self, param_name, param_value)

        self.audio_token_id = audio_token_id
        if audio_encoder_config is None:
            self.audio_encoder_config = None
        else:
            self.audio_encoder_config = Step1fAudioEncoderConfig(
                **audio_encoder_config)

        self.text_config = Step1Config(
            **llm_params,
            architectures=["Step1ForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )
class StepAudioQwen2Config(PretrainedConfig):
    """Audio-language configuration: Qwen2 language model + audio encoder.

    The language-model hyper-parameters are stored on this config and, except
    for ``num_attention_groups``, also used to build ``text_config`` (a
    ``Qwen2Config``). ``num_attention_groups`` must equal
    ``num_key_value_heads``; this is enforced with an ``assert``.
    ``audio_encoder_config`` is expected to be a mapping of
    ``Step1fAudioEncoderConfig`` keyword arguments, or ``None``.

    The ids ``151643``, ``151645`` and ``151665`` are always part of
    ``eos_token_id``; a caller-supplied id or list of ids is merged with them.
    """

    model_type = "step_audio_qwen2"

    def __init__(
        self,
        vocab_size=64012,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=48,
        num_attention_heads=32,
        num_attention_groups=4,
        num_key_value_heads=4,
        hidden_act="silu",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        rope_theta=1000000.0,
        rope_scaling=None,
        audio_token_id=151690,
        eos_token_id=None,
        audio_encoder_config=None,
        **kwargs
    ):
        # Merge the caller's EOS id(s) with the built-in ones. Only the list
        # form is de-duplicated (through a set).
        builtin_eos = [151643, 151645, 151665]
        if eos_token_id is None:
            eos_token_id = builtin_eos
        elif isinstance(eos_token_id, list):
            eos_token_id = list(set(builtin_eos + eos_token_id))
        else:
            eos_token_id = builtin_eos + [eos_token_id]
        super().__init__(eos_token_id=eos_token_id, **kwargs)

        # Size / head settings are stored first so the consistency check
        # below runs against the stored attributes.
        size_fields = {
            "vocab_size": vocab_size,
            "hidden_size": hidden_size,
            "intermediate_size": intermediate_size,
            "num_hidden_layers": num_hidden_layers,
            "num_attention_heads": num_attention_heads,
            "num_attention_groups": num_attention_groups,
            "num_key_value_heads": num_key_value_heads,
        }
        for field_name, field_value in size_fields.items():
            setattr(self, field_name, field_value)
        assert self.num_attention_groups == self.num_key_value_heads, "num_attention_groups must be equal to num_key_value_heads"

        remaining_fields = {
            "hidden_act": hidden_act,
            "max_position_embeddings": max_position_embeddings,
            "initializer_range": initializer_range,
            "rms_norm_eps": rms_norm_eps,
            "rope_theta": rope_theta,
            "rope_scaling": rope_scaling,
        }
        for field_name, field_value in remaining_fields.items():
            setattr(self, field_name, field_value)

        if audio_encoder_config is None:
            self.audio_encoder_config = None
        else:
            self.audio_encoder_config = Step1fAudioEncoderConfig(
                **audio_encoder_config)
        self.audio_token_id = audio_token_id

        # ``num_attention_groups`` is kept on this config only; it is not
        # forwarded to the Qwen2 text config.
        qwen2_kwargs = {
            key: value
            for key, value in {**size_fields, **remaining_fields}.items()
            if key != "num_attention_groups"
        }
        self.text_config = Qwen2Config(
            **qwen2_kwargs,
            architectures=["Qwen2ForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )
\ No newline at end of file
......@@ -4,6 +4,8 @@
from typing import Optional
from .tokenizer import AnyTokenizer
# from vllm.transformers_utils.tokenizers.sentencepiece_tokenizer import (
# SentencePieceTokenizer)
def _replace_none_with_empty(tokens: list[Optional[str]]):
......@@ -171,6 +173,13 @@ def detokenize_incrementally(
# The prefix text is necessary only to defeat cleanup algorithms in
# the decode which decide to add a space or not depending on the
# surrounding ids.
# FIXME(ys): for step1 sentencepiece tokenizer, we need to handle the special tokens in convert_tokens_to_string
# if isinstance(tokenizer, SentencePieceTokenizer):
# prefix_text = tokenizer.convert_tokens_to_string(
# output_tokens[prefix_offset:read_offset], skip_special_tokens=skip_special_tokens)
# new_text = tokenizer.convert_tokens_to_string(
# output_tokens[prefix_offset:], skip_special_tokens=skip_special_tokens)
if tokenizer.is_fast or not tokenizer.get_added_vocab():
prefix_text = tokenizer.convert_tokens_to_string(
output_tokens[prefix_offset:read_offset])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment