"vscode:/vscode.git/clone" did not exist on "c312320764193e7d0ffa99d247c61efe5458a635"
Unverified commit 5400014d, authored by Isotr0py and committed by GitHub
Browse files

[Chore] Remove `use_data_parallel` kwargs from ViT implementation (#33310)


Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent 3a92c6f3
...@@ -39,7 +39,7 @@ from vllm.model_executor.layers.linear import ( ...@@ -39,7 +39,7 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .vision import run_dp_sharded_vision_model from .vision import is_vit_use_data_parallel, run_dp_sharded_vision_model
class Idefics2VisionEmbeddings(nn.Module): class Idefics2VisionEmbeddings(nn.Module):
...@@ -126,9 +126,9 @@ class Idefics2VisionAttention(nn.Module): ...@@ -126,9 +126,9 @@ class Idefics2VisionAttention(nn.Module):
config: Idefics2VisionConfig, config: Idefics2VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
use_data_parallel = is_vit_use_data_parallel()
self.config = config self.config = config
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads
...@@ -187,11 +187,12 @@ class Idefics2VisionMLP(nn.Module): ...@@ -187,11 +187,12 @@ class Idefics2VisionMLP(nn.Module):
config: Idefics2VisionConfig, config: Idefics2VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.activation_fn = get_act_fn(config.hidden_act) self.activation_fn = get_act_fn(config.hidden_act)
use_data_parallel = is_vit_use_data_parallel()
self.fc1 = ColumnParallelLinear( self.fc1 = ColumnParallelLinear(
config.hidden_size, config.hidden_size,
config.intermediate_size, config.intermediate_size,
...@@ -222,7 +223,6 @@ class Idefics2EncoderLayer(nn.Module): ...@@ -222,7 +223,6 @@ class Idefics2EncoderLayer(nn.Module):
config: Idefics2Config, config: Idefics2Config,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
...@@ -230,14 +230,12 @@ class Idefics2EncoderLayer(nn.Module): ...@@ -230,14 +230,12 @@ class Idefics2EncoderLayer(nn.Module):
config, config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
use_data_parallel=use_data_parallel,
) )
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Idefics2VisionMLP( self.mlp = Idefics2VisionMLP(
config, config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
) )
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
...@@ -279,7 +277,6 @@ class Idefics2Encoder(nn.Module): ...@@ -279,7 +277,6 @@ class Idefics2Encoder(nn.Module):
*, *,
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -296,7 +293,6 @@ class Idefics2Encoder(nn.Module): ...@@ -296,7 +293,6 @@ class Idefics2Encoder(nn.Module):
config, config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}", prefix=f"{prefix}.layers.{layer_idx}",
use_data_parallel=use_data_parallel,
) )
for layer_idx in range(num_hidden_layers) for layer_idx in range(num_hidden_layers)
] ]
...@@ -331,20 +327,18 @@ class Idefics2VisionTransformer(nn.Module): ...@@ -331,20 +327,18 @@ class Idefics2VisionTransformer(nn.Module):
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
require_post_norm: bool = True, require_post_norm: bool = True,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
embed_dim = config.hidden_size embed_dim = config.hidden_size
self.config = config self.config = config
self.use_data_parallel = use_data_parallel self.use_data_parallel = is_vit_use_data_parallel()
self.embeddings = Idefics2VisionEmbeddings(config) self.embeddings = Idefics2VisionEmbeddings(config)
self.encoder = Idefics2Encoder( self.encoder = Idefics2Encoder(
config, config,
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
use_data_parallel=use_data_parallel,
) )
num_hidden_layers = config.num_hidden_layers num_hidden_layers = config.num_hidden_layers
......
...@@ -34,7 +34,7 @@ from vllm.model_executor.layers.linear import ( ...@@ -34,7 +34,7 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .vision import run_dp_sharded_vision_model from .vision import is_vit_use_data_parallel, run_dp_sharded_vision_model
NORM2FN = { NORM2FN = {
"rms_norm": RMSNorm, "rms_norm": RMSNorm,
...@@ -148,7 +148,6 @@ class InternParallelAttention(nn.Module): ...@@ -148,7 +148,6 @@ class InternParallelAttention(nn.Module):
*, *,
num_dummy_heads: int = 0, num_dummy_heads: int = 0,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -163,9 +162,14 @@ class InternParallelAttention(nn.Module): ...@@ -163,9 +162,14 @@ class InternParallelAttention(nn.Module):
f" {self.num_heads})." f" {self.num_heads})."
) )
self.tp_size = ( use_data_parallel = is_vit_use_data_parallel()
1 if use_data_parallel else get_tensor_model_parallel_world_size() # if the number of heads is not divisible by tp_size,
# we also disable Attention's TP
tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
use_data_parallel = (
use_data_parallel or (self.num_heads + num_dummy_heads) % tp_size != 0
) )
self.tp_size = 1 if use_data_parallel else tp_size
self.tp_rank = 0 if use_data_parallel else get_tensor_model_parallel_rank() self.tp_rank = 0 if use_data_parallel else get_tensor_model_parallel_rank()
# Additional dummy heads are used to enable TP for common GPU counts. # Additional dummy heads are used to enable TP for common GPU counts.
...@@ -242,12 +246,12 @@ class InternMLP(nn.Module): ...@@ -242,12 +246,12 @@ class InternMLP(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.activation_fn = get_act_fn(config.hidden_act) self.activation_fn = get_act_fn(config.hidden_act)
use_data_parallel = is_vit_use_data_parallel()
self.fc1 = ColumnParallelLinear( self.fc1 = ColumnParallelLinear(
config.hidden_size, config.hidden_size,
config.intermediate_size, config.intermediate_size,
...@@ -281,11 +285,9 @@ class InternVisionEncoderLayer(nn.Module): ...@@ -281,11 +285,9 @@ class InternVisionEncoderLayer(nn.Module):
*, *,
num_dummy_heads: int = 0, num_dummy_heads: int = 0,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
attn_cls: type[InternParallelAttention] = InternParallelAttention, attn_cls: type[InternParallelAttention] = InternParallelAttention,
) -> None: ) -> None:
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.intermediate_size = config.intermediate_size self.intermediate_size = config.intermediate_size
self.norm_type = config.norm_type self.norm_type = config.norm_type
...@@ -296,14 +298,12 @@ class InternVisionEncoderLayer(nn.Module): ...@@ -296,14 +298,12 @@ class InternVisionEncoderLayer(nn.Module):
quant_config, quant_config,
num_dummy_heads=num_dummy_heads, num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
) )
self.mlp = InternMLP( self.mlp = InternMLP(
config, config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
) )
self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps) self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps) self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
...@@ -318,23 +318,12 @@ class InternVisionEncoderLayer(nn.Module): ...@@ -318,23 +318,12 @@ class InternVisionEncoderLayer(nn.Module):
*, *,
num_dummy_heads: int, num_dummy_heads: int,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
# fallback to sdpa attention if tp unavailable
tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
# if the number of heads is not divisible by tp_size,
# we also disable Attention's TP
use_data_parallel = (
use_data_parallel or (num_heads + num_dummy_heads) % tp_size != 0
)
return self.attn_cls( return self.attn_cls(
config, config,
quant_config=quant_config, quant_config=quant_config,
num_dummy_heads=num_dummy_heads, num_dummy_heads=num_dummy_heads,
prefix=prefix, prefix=prefix,
use_data_parallel=use_data_parallel,
) )
def forward( def forward(
...@@ -357,7 +346,6 @@ class InternVisionEncoder(nn.Module): ...@@ -357,7 +346,6 @@ class InternVisionEncoder(nn.Module):
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
num_dummy_heads: int = 0, num_dummy_heads: int = 0,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
layer_cls: type[InternVisionEncoderLayer] = InternVisionEncoderLayer, layer_cls: type[InternVisionEncoderLayer] = InternVisionEncoderLayer,
): ):
super().__init__() super().__init__()
...@@ -377,7 +365,6 @@ class InternVisionEncoder(nn.Module): ...@@ -377,7 +365,6 @@ class InternVisionEncoder(nn.Module):
quant_config, quant_config,
num_dummy_heads=num_dummy_heads, num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.layers.{layer_idx}", prefix=f"{prefix}.layers.{layer_idx}",
use_data_parallel=use_data_parallel,
) )
for layer_idx in range(num_hidden_layers) for layer_idx in range(num_hidden_layers)
] ]
...@@ -404,12 +391,11 @@ class InternVisionModel(nn.Module): ...@@ -404,12 +391,11 @@ class InternVisionModel(nn.Module):
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
num_dummy_heads: int = 0, num_dummy_heads: int = 0,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.use_data_parallel = use_data_parallel self.use_data_parallel = is_vit_use_data_parallel()
self.embeddings = InternVisionEmbeddings(config) self.embeddings = InternVisionEmbeddings(config)
self.encoder = InternVisionEncoder( self.encoder = InternVisionEncoder(
...@@ -418,7 +404,6 @@ class InternVisionModel(nn.Module): ...@@ -418,7 +404,6 @@ class InternVisionModel(nn.Module):
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
num_dummy_heads=num_dummy_heads, num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
use_data_parallel=use_data_parallel,
) )
def get_input_embeddings(self): def get_input_embeddings(self):
......
...@@ -1153,7 +1153,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA) ...@@ -1153,7 +1153,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers, num_hidden_layers_override=num_hidden_layers,
prefix=prefix, prefix=prefix,
use_data_parallel=self.use_data_parallel,
) )
else: else:
return InternVisionPatchModel(config.vision_config) return InternVisionPatchModel(config.vision_config)
......
...@@ -81,7 +81,7 @@ from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig ...@@ -81,7 +81,7 @@ from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
from .vision import run_dp_sharded_mrope_vision_model from .vision import is_vit_use_data_parallel, run_dp_sharded_mrope_vision_model
# For dummy input only # For dummy input only
...@@ -93,10 +93,12 @@ class MaxImageTokenMeta: ...@@ -93,10 +93,12 @@ class MaxImageTokenMeta:
class KimiVLMultiModalProjector(nn.Module): class KimiVLMultiModalProjector(nn.Module):
def __init__( def __init__(
self, config: KimiVLConfig, use_data_parallel: bool = False, prefix: str = "" self,
config: KimiVLConfig,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.use_data_parallel = use_data_parallel self.use_data_parallel = is_vit_use_data_parallel()
self.hidden_size = ( self.hidden_size = (
config.vision_config.hidden_size config.vision_config.hidden_size
...@@ -321,7 +323,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -321,7 +323,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
) )
self.multi_modal_projector = KimiVLMultiModalProjector( self.multi_modal_projector = KimiVLMultiModalProjector(
config=config, config=config,
use_data_parallel=self.use_data_parallel,
prefix=maybe_prefix(prefix, "multi_modal_projector"), prefix=maybe_prefix(prefix, "multi_modal_projector"),
) )
......
...@@ -57,6 +57,7 @@ from .utils import ( ...@@ -57,6 +57,7 @@ from .utils import (
init_vllm_registered_model, init_vllm_registered_model,
maybe_prefix, maybe_prefix,
) )
from .vision import is_vit_use_data_parallel
class Lfm2VLImagePixelInputs(TensorSchema): class Lfm2VLImagePixelInputs(TensorSchema):
...@@ -426,10 +427,12 @@ class Lfm2VLMultiModalProcessor(BaseMultiModalProcessor[Lfm2VLProcessingInfo]): ...@@ -426,10 +427,12 @@ class Lfm2VLMultiModalProcessor(BaseMultiModalProcessor[Lfm2VLProcessingInfo]):
class Lfm2VLMultiModalProjector(nn.Module): class Lfm2VLMultiModalProjector(nn.Module):
def __init__( def __init__(
self, config: Lfm2VlConfig, use_data_parallel: bool = False, prefix: str = "" self,
config: Lfm2VlConfig,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.use_data_parallel = use_data_parallel self.use_data_parallel = is_vit_use_data_parallel()
in_channels = config.vision_config.hidden_size * (config.downsample_factor**2) in_channels = config.vision_config.hidden_size * (config.downsample_factor**2)
self.factor = config.downsample_factor self.factor = config.downsample_factor
...@@ -607,7 +610,6 @@ class Lfm2VLForConditionalGeneration( ...@@ -607,7 +610,6 @@ class Lfm2VLForConditionalGeneration(
self.multi_modal_projector = Lfm2VLMultiModalProjector( self.multi_modal_projector = Lfm2VLMultiModalProjector(
config=config, config=config,
use_data_parallel=self.use_data_parallel,
prefix=maybe_prefix(prefix, "multi_modal_projector"), prefix=maybe_prefix(prefix, "multi_modal_projector"),
) )
......
...@@ -1335,7 +1335,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): ...@@ -1335,7 +1335,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
config.vision_config, config.vision_config,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=prefix,
use_data_parallel=self.use_data_parallel,
) )
if self.config.drop_vision_last_layer: if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1] model.encoder.layers = model.encoder.layers[:-1]
...@@ -1428,7 +1427,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): ...@@ -1428,7 +1427,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
config.vision_config, config.vision_config,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=prefix,
use_data_parallel=self.use_data_parallel,
) )
if self.config.drop_vision_last_layer: if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1] model.encoder.layers = model.encoder.layers[:-1]
...@@ -1526,7 +1524,6 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA): ...@@ -1526,7 +1524,6 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
config.vision_config, config.vision_config,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=prefix,
use_data_parallel=self.use_data_parallel,
) )
if self.config.drop_vision_last_layer: if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1] model.encoder.layers = model.encoder.layers[:-1]
...@@ -1624,7 +1621,6 @@ class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA): ...@@ -1624,7 +1621,6 @@ class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA):
config.vision_config, config.vision_config,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=prefix,
use_data_parallel=self.use_data_parallel,
) )
if self.config.drop_vision_last_layer: if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1] model.encoder.layers = model.encoder.layers[:-1]
......
...@@ -79,7 +79,7 @@ from .interfaces import ( ...@@ -79,7 +79,7 @@ from .interfaces import (
) )
from .llama4 import Llama4ForCausalLM from .llama4 import Llama4ForCausalLM
from .utils import AutoWeightsLoader, StageMissingLayer, maybe_prefix from .utils import AutoWeightsLoader, StageMissingLayer, maybe_prefix
from .vision import run_dp_sharded_vision_model from .vision import is_vit_use_data_parallel, run_dp_sharded_vision_model
class Llama4ImagePatchInputs(TensorSchema): class Llama4ImagePatchInputs(TensorSchema):
...@@ -124,9 +124,9 @@ class Llama4VisionMLP(nn.Module): ...@@ -124,9 +124,9 @@ class Llama4VisionMLP(nn.Module):
output_activation: bool, output_activation: bool,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
use_data_parallel = is_vit_use_data_parallel()
self.fc1 = ColumnParallelLinear( self.fc1 = ColumnParallelLinear(
input_size=input_size, input_size=input_size,
output_size=intermediate_size, output_size=intermediate_size,
...@@ -208,7 +208,6 @@ class Llama4VisionPixelShuffleMLP(nn.Module): ...@@ -208,7 +208,6 @@ class Llama4VisionPixelShuffleMLP(nn.Module):
config, config,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
self.pixel_shuffle_ratio = config.pixel_shuffle_ratio self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
...@@ -224,7 +223,6 @@ class Llama4VisionPixelShuffleMLP(nn.Module): ...@@ -224,7 +223,6 @@ class Llama4VisionPixelShuffleMLP(nn.Module):
output_activation=True, output_activation=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
) )
def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor: def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
...@@ -238,10 +236,10 @@ class Llama4VisionAttention(nn.Module): ...@@ -238,10 +236,10 @@ class Llama4VisionAttention(nn.Module):
config: Llama4VisionConfig, config: Llama4VisionConfig,
quant_config: QuantizationConfig | None, quant_config: QuantizationConfig | None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
use_data_parallel = is_vit_use_data_parallel()
self.tp_size = ( self.tp_size = (
1 if use_data_parallel else get_tensor_model_parallel_world_size() 1 if use_data_parallel else get_tensor_model_parallel_world_size()
) )
...@@ -336,7 +334,6 @@ class Llama4VisionEncoderLayer(nn.Module): ...@@ -336,7 +334,6 @@ class Llama4VisionEncoderLayer(nn.Module):
config: Llama4VisionConfig, config: Llama4VisionConfig,
quant_config: QuantizationConfig | None, quant_config: QuantizationConfig | None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -347,7 +344,6 @@ class Llama4VisionEncoderLayer(nn.Module): ...@@ -347,7 +344,6 @@ class Llama4VisionEncoderLayer(nn.Module):
config, config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
use_data_parallel=use_data_parallel,
) )
self.mlp = Llama4VisionMLP( self.mlp = Llama4VisionMLP(
input_size=config.hidden_size, input_size=config.hidden_size,
...@@ -357,7 +353,6 @@ class Llama4VisionEncoderLayer(nn.Module): ...@@ -357,7 +353,6 @@ class Llama4VisionEncoderLayer(nn.Module):
output_activation=False, output_activation=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
) )
self.input_layernorm = nn.LayerNorm(config.hidden_size) self.input_layernorm = nn.LayerNorm(config.hidden_size)
...@@ -389,7 +384,6 @@ class Llama4VisionEncoder(nn.Module): ...@@ -389,7 +384,6 @@ class Llama4VisionEncoder(nn.Module):
config: Llama4VisionConfig, config: Llama4VisionConfig,
quant_config: QuantizationConfig | None, quant_config: QuantizationConfig | None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -399,7 +393,6 @@ class Llama4VisionEncoder(nn.Module): ...@@ -399,7 +393,6 @@ class Llama4VisionEncoder(nn.Module):
config, config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}", prefix=f"{prefix}.layers.{layer_idx}",
use_data_parallel=use_data_parallel,
) )
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
] ]
...@@ -432,13 +425,13 @@ class Llama4UnfoldConvolution(nn.Module): ...@@ -432,13 +425,13 @@ class Llama4UnfoldConvolution(nn.Module):
config: Llama4VisionConfig, config: Llama4VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
kernel_size = config.patch_size kernel_size = config.patch_size
if isinstance(kernel_size, int): if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size) kernel_size = (kernel_size, kernel_size)
self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size) self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size)
use_data_parallel = is_vit_use_data_parallel()
self.linear = ColumnParallelLinear( self.linear = ColumnParallelLinear(
input_size=config.num_channels * kernel_size[0] * kernel_size[1], input_size=config.num_channels * kernel_size[0] * kernel_size[1],
output_size=config.hidden_size, output_size=config.hidden_size,
...@@ -465,7 +458,6 @@ class Llama4VisionModel(nn.Module): ...@@ -465,7 +458,6 @@ class Llama4VisionModel(nn.Module):
config: Llama4VisionConfig, config: Llama4VisionConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -481,7 +473,6 @@ class Llama4VisionModel(nn.Module): ...@@ -481,7 +473,6 @@ class Llama4VisionModel(nn.Module):
config, config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.patch_embedding", prefix=f"{prefix}.patch_embedding",
use_data_parallel=use_data_parallel,
) )
self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size)) self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size))
...@@ -498,14 +489,12 @@ class Llama4VisionModel(nn.Module): ...@@ -498,14 +489,12 @@ class Llama4VisionModel(nn.Module):
config, config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.model", prefix=f"{prefix}.model",
use_data_parallel=use_data_parallel,
) )
self.vision_adapter = Llama4VisionPixelShuffleMLP( self.vision_adapter = Llama4VisionPixelShuffleMLP(
config, config,
quant_config, quant_config,
prefix=f"{prefix}.vision_adapter", prefix=f"{prefix}.vision_adapter",
use_data_parallel=use_data_parallel,
) )
def forward( def forward(
...@@ -780,7 +769,6 @@ class Llama4ForConditionalGeneration( ...@@ -780,7 +769,6 @@ class Llama4ForConditionalGeneration(
config=config.vision_config, config=config.vision_config,
quant_config=None, quant_config=None,
prefix=maybe_prefix(prefix, "vision_model"), prefix=maybe_prefix(prefix, "vision_model"),
use_data_parallel=self.use_data_parallel,
) )
self.multi_modal_projector = Llama4MultiModalProjector( self.multi_modal_projector = Llama4MultiModalProjector(
......
...@@ -54,7 +54,7 @@ from .utils import ( ...@@ -54,7 +54,7 @@ from .utils import (
init_vllm_registered_model, init_vllm_registered_model,
maybe_prefix, maybe_prefix,
) )
from .vision import run_dp_sharded_vision_model from .vision import is_vit_use_data_parallel, run_dp_sharded_vision_model
class Step3VLImagePixelInputs(TensorSchema): class Step3VLImagePixelInputs(TensorSchema):
...@@ -724,7 +724,6 @@ class Step3VisionAttention(nn.Module): ...@@ -724,7 +724,6 @@ class Step3VisionAttention(nn.Module):
config, config,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -734,6 +733,7 @@ class Step3VisionAttention(nn.Module): ...@@ -734,6 +733,7 @@ class Step3VisionAttention(nn.Module):
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
use_data_parallel = is_vit_use_data_parallel()
tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size() tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_size == 0 assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size self.num_heads = self.total_num_heads // tp_size
...@@ -786,11 +786,11 @@ class Step3VisionMLP(nn.Module): ...@@ -786,11 +786,11 @@ class Step3VisionMLP(nn.Module):
config, config,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.activation_fn = get_act_fn(config.hidden_act) self.activation_fn = get_act_fn(config.hidden_act)
use_data_parallel = is_vit_use_data_parallel()
self.fc1 = ColumnParallelLinear( self.fc1 = ColumnParallelLinear(
config.hidden_size, config.hidden_size,
config.intermediate_size, config.intermediate_size,
...@@ -821,23 +821,19 @@ class Step3VisionEncoderLayer(nn.Module): ...@@ -821,23 +821,19 @@ class Step3VisionEncoderLayer(nn.Module):
config: Step3VisionEncoderConfig, config: Step3VisionEncoderConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
self.use_data_parallel = use_data_parallel
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.self_attn = Step3VisionAttention( self.self_attn = Step3VisionAttention(
config, config,
quant_config, quant_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
use_data_parallel=self.use_data_parallel,
) )
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Step3VisionMLP( self.mlp = Step3VisionMLP(
config, config,
quant_config, quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
use_data_parallel=self.use_data_parallel,
) )
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
...@@ -856,18 +852,15 @@ class Step3VisionEncoder(nn.Module): ...@@ -856,18 +852,15 @@ class Step3VisionEncoder(nn.Module):
config: Step3VisionEncoderConfig, config: Step3VisionEncoderConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.use_data_parallel = use_data_parallel
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
Step3VisionEncoderLayer( Step3VisionEncoderLayer(
config, config,
quant_config, quant_config,
prefix=f"{prefix}.layers.{i}", prefix=f"{prefix}.layers.{i}",
use_data_parallel=self.use_data_parallel,
) )
for i in range(config.num_hidden_layers) for i in range(config.num_hidden_layers)
] ]
...@@ -889,18 +882,16 @@ class Step3VisionTransformer(nn.Module): ...@@ -889,18 +882,16 @@ class Step3VisionTransformer(nn.Module):
config: Step3VisionEncoderConfig, config: Step3VisionEncoderConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.use_data_parallel = use_data_parallel self.use_data_parallel = is_vit_use_data_parallel()
self.image_size = config.image_size self.image_size = config.image_size
self.embeddings = Step3VisionEmbeddings(config) self.embeddings = Step3VisionEmbeddings(config)
self.transformer = Step3VisionEncoder( self.transformer = Step3VisionEncoder(
config, config,
quant_config, quant_config,
prefix=f"{prefix}.transformer", prefix=f"{prefix}.transformer",
use_data_parallel=self.use_data_parallel,
) )
def forward( def forward(
...@@ -952,7 +943,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -952,7 +943,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
config.vision_config, config.vision_config,
None, None,
prefix=maybe_prefix(prefix, "vision_model"), prefix=maybe_prefix(prefix, "vision_model"),
use_data_parallel=self.use_data_parallel,
) )
self.vit_downsampler = Conv2dLayer( self.vit_downsampler = Conv2dLayer(
config.vision_config.hidden_size, config.vision_config.hidden_size,
......
...@@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from .step3_vl import Step3VLForConditionalGeneration from .step3_vl import Step3VLForConditionalGeneration
from .utils import WeightsMapper, init_vllm_registered_model, maybe_prefix from .utils import WeightsMapper, init_vllm_registered_model, maybe_prefix
from .vision import run_dp_sharded_vision_model from .vision import is_vit_use_data_parallel, run_dp_sharded_vision_model
_DEFAULT_NORM_LAYER = partial(nn.LayerNorm, eps=1e-5) _DEFAULT_NORM_LAYER = partial(nn.LayerNorm, eps=1e-5)
...@@ -151,9 +151,9 @@ class PerceptionEncoderMLP(nn.Module): ...@@ -151,9 +151,9 @@ class PerceptionEncoderMLP(nn.Module):
act_layer: Callable[[], nn.Module], act_layer: Callable[[], nn.Module],
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
use_data_parallel = is_vit_use_data_parallel()
self.fc1 = ColumnParallelLinear( self.fc1 = ColumnParallelLinear(
input_dim, input_dim,
hidden_dim, hidden_dim,
...@@ -189,7 +189,6 @@ class PerceptionEncoderVisionAttention(nn.Module): ...@@ -189,7 +189,6 @@ class PerceptionEncoderVisionAttention(nn.Module):
use_cls_token: bool = False, use_cls_token: bool = False,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
self.embed_dim = embed_dim self.embed_dim = embed_dim
...@@ -197,6 +196,7 @@ class PerceptionEncoderVisionAttention(nn.Module): ...@@ -197,6 +196,7 @@ class PerceptionEncoderVisionAttention(nn.Module):
self.head_dim = embed_dim // num_heads self.head_dim = embed_dim // num_heads
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
use_data_parallel = is_vit_use_data_parallel()
tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size() tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_size == 0, ( assert self.total_num_heads % tp_size == 0, (
"embed_dim must be divisible by num_heads" "embed_dim must be divisible by num_heads"
...@@ -258,7 +258,6 @@ class PerceptionEncoderVisionBlock(nn.Module): ...@@ -258,7 +258,6 @@ class PerceptionEncoderVisionBlock(nn.Module):
use_cls_token: bool = False, use_cls_token: bool = False,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
self.attn = PerceptionEncoderVisionAttention( self.attn = PerceptionEncoderVisionAttention(
...@@ -269,7 +268,6 @@ class PerceptionEncoderVisionBlock(nn.Module): ...@@ -269,7 +268,6 @@ class PerceptionEncoderVisionBlock(nn.Module):
use_cls_token=use_cls_token, use_cls_token=use_cls_token,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
) )
self.ls_1 = ( self.ls_1 = (
PerceptionEncoderLayerScale(d_model, ls_init_value) PerceptionEncoderLayerScale(d_model, ls_init_value)
...@@ -290,7 +288,6 @@ class PerceptionEncoderVisionBlock(nn.Module): ...@@ -290,7 +288,6 @@ class PerceptionEncoderVisionBlock(nn.Module):
act_layer, act_layer,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
) )
def forward(self, x: torch.Tensor, grid_hw: tuple[int, int]): def forward(self, x: torch.Tensor, grid_hw: tuple[int, int]):
...@@ -314,7 +311,6 @@ class PerceptionEncoderVisionTransformer(nn.Module): ...@@ -314,7 +311,6 @@ class PerceptionEncoderVisionTransformer(nn.Module):
use_cls_token: bool = False, use_cls_token: bool = False,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
self.width = width self.width = width
...@@ -333,7 +329,6 @@ class PerceptionEncoderVisionTransformer(nn.Module): ...@@ -333,7 +329,6 @@ class PerceptionEncoderVisionTransformer(nn.Module):
use_cls_token=use_cls_token, use_cls_token=use_cls_token,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.resblocks.{i}", prefix=f"{prefix}.resblocks.{i}",
use_data_parallel=use_data_parallel,
) )
for i in range(layers) for i in range(layers)
] ]
...@@ -353,7 +348,6 @@ class PerceptionEncoder(nn.Module): ...@@ -353,7 +348,6 @@ class PerceptionEncoder(nn.Module):
norm_layer: Callable = _DEFAULT_NORM_LAYER, norm_layer: Callable = _DEFAULT_NORM_LAYER,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False,
): ):
super().__init__() super().__init__()
self.patch_size = config.patch_size self.patch_size = config.patch_size
...@@ -394,7 +388,6 @@ class PerceptionEncoder(nn.Module): ...@@ -394,7 +388,6 @@ class PerceptionEncoder(nn.Module):
use_cls_token=self.use_cls_token, use_cls_token=self.use_cls_token,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.transformer", prefix=f"{prefix}.transformer",
use_data_parallel=use_data_parallel,
) )
self.vit_downsampler1 = Conv2dLayer( self.vit_downsampler1 = Conv2dLayer(
...@@ -511,7 +504,6 @@ class StepVLForConditionalGeneration(Step3VLForConditionalGeneration): ...@@ -511,7 +504,6 @@ class StepVLForConditionalGeneration(Step3VLForConditionalGeneration):
get_act_fn(config.vision_config.hidden_act), get_act_fn(config.vision_config.hidden_act),
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "vision_model"), prefix=maybe_prefix(prefix, "vision_model"),
use_data_parallel=self.use_data_parallel,
) )
self.vit_large_projector = ColumnParallelLinear( self.vit_large_projector = ColumnParallelLinear(
config.vision_config.width * 4, config.vision_config.width * 4,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment