Commit 42efe609 (unverified), authored by Isotr0py, committed by GitHub
Browse files

[MM][Bugfix] Replace `PatchEmbed`'s conv3d to linear layer (#27418)


Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Roger Wang <hey@rogerw.io>
parent 88d3141e
...@@ -60,6 +60,7 @@ from vllm.model_executor.layers.linear import ( ...@@ -60,6 +60,7 @@ from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -98,7 +99,11 @@ from .utils import ( ...@@ -98,7 +99,11 @@ from .utils import (
init_vllm_registered_model, init_vllm_registered_model,
maybe_prefix, maybe_prefix,
) )
from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model from .vision import (
conv3d_to_linear_weight,
get_vit_attn_backend,
run_dp_sharded_mrope_vision_model,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -478,18 +483,15 @@ class Glm4vVisionPatchEmbed(nn.Module): ...@@ -478,18 +483,15 @@ class Glm4vVisionPatchEmbed(nn.Module):
self.hidden_size = hidden_size self.hidden_size = hidden_size
kernel_size = (temporal_patch_size, patch_size, patch_size) kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = nn.Conv3d( self.proj = ReplicatedLinear(
in_channels, in_channels * math.prod(kernel_size),
hidden_size, hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=True, bias=True,
return_bias=False,
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
L, C = x.shape x = self.proj(x)
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
x = self.proj(x).view(L, self.hidden_size)
return x return x
...@@ -887,6 +889,9 @@ class Glm4vVisionTransformer(nn.Module): ...@@ -887,6 +889,9 @@ class Glm4vVisionTransformer(nn.Module):
loaded_params: set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if name.endswith("patch_embed.proj.weight"):
loaded_weight = conv3d_to_linear_weight(loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" """Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
import math
from collections.abc import Callable, Iterable, Mapping, Sequence from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import lru_cache, partial from functools import lru_cache, partial
from typing import Annotated, Any, Literal, TypeAlias from typing import Annotated, Any, Literal, TypeAlias
...@@ -56,6 +57,7 @@ from vllm.model_executor.layers.linear import ( ...@@ -56,6 +57,7 @@ from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -98,7 +100,11 @@ from .utils import ( ...@@ -98,7 +100,11 @@ from .utils import (
init_vllm_registered_model, init_vllm_registered_model,
maybe_prefix, maybe_prefix,
) )
from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model from .vision import (
conv3d_to_linear_weight,
get_vit_attn_backend,
run_dp_sharded_mrope_vision_model,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -532,18 +538,15 @@ class Qwen2_5_VisionPatchEmbed(nn.Module): ...@@ -532,18 +538,15 @@ class Qwen2_5_VisionPatchEmbed(nn.Module):
self.hidden_size = hidden_size self.hidden_size = hidden_size
kernel_size = (temporal_patch_size, patch_size, patch_size) kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = nn.Conv3d( self.proj = ReplicatedLinear(
in_channels, in_channels * math.prod(kernel_size),
hidden_size, hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=False, bias=False,
return_bias=False,
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
L, C = x.shape x = self.proj(x)
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
x = self.proj(x).view(L, self.hidden_size)
return x return x
...@@ -950,6 +953,9 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -950,6 +953,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
loaded_params: set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if name.endswith("patch_embed.proj.weight"):
loaded_weight = conv3d_to_linear_weight(loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
import math
from collections.abc import Callable, Iterable, Mapping, Sequence from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import partial from functools import partial
from typing import Annotated, Any, Literal, TypeAlias from typing import Annotated, Any, Literal, TypeAlias
...@@ -53,7 +54,11 @@ from vllm.distributed import parallel_state, tensor_model_parallel_all_gather ...@@ -53,7 +54,11 @@ from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import ( from vllm.model_executor.layers.rotary_embedding.common import (
dispatch_rotary_emb_function, dispatch_rotary_emb_function,
...@@ -100,7 +105,11 @@ from .utils import ( ...@@ -100,7 +105,11 @@ from .utils import (
init_vllm_registered_model, init_vllm_registered_model,
maybe_prefix, maybe_prefix,
) )
from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model from .vision import (
conv3d_to_linear_weight,
get_vit_attn_backend,
run_dp_sharded_mrope_vision_model,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -561,18 +570,15 @@ class Qwen2VisionPatchEmbed(nn.Module): ...@@ -561,18 +570,15 @@ class Qwen2VisionPatchEmbed(nn.Module):
self.embed_dim = embed_dim self.embed_dim = embed_dim
kernel_size = (temporal_patch_size, patch_size, patch_size) kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = nn.Conv3d( self.proj = ReplicatedLinear(
in_channels, in_channels * math.prod(kernel_size),
embed_dim, embed_dim,
kernel_size=kernel_size,
stride=kernel_size,
bias=False, bias=False,
return_bias=False,
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
L, C = x.shape x = self.proj(x)
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
x = self.proj(x).view(L, self.embed_dim)
return x return x
...@@ -835,6 +841,9 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -835,6 +841,9 @@ class Qwen2VisionTransformer(nn.Module):
loaded_params: set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if name.endswith("patch_embed.proj.weight"):
loaded_weight = conv3d_to_linear_weight(loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen3-Omni-Moe model (thinker part).""" """Inference-only Qwen3-Omni-Moe model (thinker part)."""
import math
from collections.abc import Callable, Iterable, Mapping, Sequence from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import partial from functools import partial
from typing import Any from typing import Any
...@@ -53,7 +54,11 @@ from vllm.config import VllmConfig ...@@ -53,7 +54,11 @@ from vllm.config import VllmConfig
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
...@@ -98,7 +103,11 @@ from .utils import ( ...@@ -98,7 +103,11 @@ from .utils import (
_merge_multimodal_embeddings, _merge_multimodal_embeddings,
maybe_prefix, maybe_prefix,
) )
from .vision import get_llm_pos_ids_for_vision, get_vit_attn_backend from .vision import (
conv3d_to_linear_weight,
get_llm_pos_ids_for_vision,
get_vit_attn_backend,
)
try: try:
import flash_attn import flash_attn
...@@ -131,18 +140,16 @@ class Qwen3_VisionPatchEmbed(nn.Module): ...@@ -131,18 +140,16 @@ class Qwen3_VisionPatchEmbed(nn.Module):
self.hidden_size = hidden_size self.hidden_size = hidden_size
kernel_size = (temporal_patch_size, patch_size, patch_size) kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = nn.Conv3d( self.proj = ReplicatedLinear(
in_channels, in_channels * math.prod(kernel_size),
hidden_size, hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=True, bias=True,
return_bias=False,
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
L, C = x.shape L, C = x.shape
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) x = self.proj(x)
x = self.proj(x).view(L, self.hidden_size)
return x return x
...@@ -559,6 +566,9 @@ class Qwen3Omni_VisionTransformer(nn.Module): ...@@ -559,6 +566,9 @@ class Qwen3Omni_VisionTransformer(nn.Module):
loaded_params: set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if name.endswith("patch_embed.proj.weight"):
loaded_weight = conv3d_to_linear_weight(loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen3VL model compatible with HuggingFace weights.""" """Inference-only Qwen3VL model compatible with HuggingFace weights."""
import math
from collections.abc import Callable, Iterable, Mapping, Sequence from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import partial from functools import partial
from itertools import islice from itertools import islice
...@@ -56,7 +57,11 @@ from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions ...@@ -56,7 +57,11 @@ from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
...@@ -107,7 +112,11 @@ from .utils import ( ...@@ -107,7 +112,11 @@ from .utils import (
_merge_multimodal_embeddings, _merge_multimodal_embeddings,
maybe_prefix, maybe_prefix,
) )
from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model from .vision import (
conv3d_to_linear_weight,
get_vit_attn_backend,
run_dp_sharded_mrope_vision_model,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -129,18 +138,15 @@ class Qwen3_VisionPatchEmbed(nn.Module): ...@@ -129,18 +138,15 @@ class Qwen3_VisionPatchEmbed(nn.Module):
self.hidden_size = hidden_size self.hidden_size = hidden_size
kernel_size = (temporal_patch_size, patch_size, patch_size) kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = nn.Conv3d( self.proj = ReplicatedLinear(
in_channels, in_channels * math.prod(kernel_size),
hidden_size, hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=True, bias=True,
return_bias=False,
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
L, C = x.shape x = self.proj(x)
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
x = self.proj(x).view(L, self.hidden_size)
return x return x
...@@ -576,6 +582,9 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -576,6 +582,9 @@ class Qwen3_VisionTransformer(nn.Module):
loaded_params: set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if name.endswith("patch_embed.proj.weight"):
loaded_weight = conv3d_to_linear_weight(loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
......
...@@ -544,3 +544,19 @@ def get_llm_pos_ids_for_vision( ...@@ -544,3 +544,19 @@ def get_llm_pos_ids_for_vision(
llm_pos_ids_list.append(_llm_pos_ids + start_idx) llm_pos_ids_list.append(_llm_pos_ids + start_idx)
llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
return llm_pos_ids return llm_pos_ids
# Due to a performance regression with Conv3D in PyTorch 2.9, we reshape
# Conv3D weights to Linear weights for better performance.
# See: https://github.com/vllm-project/vllm/issues/27406
# and https://github.com/pytorch/pytorch/issues/166122
# FIXME(Isotr0py): Revert the PR that introduced this workaround
# (https://github.com/vllm-project/vllm/pull/27418)
# once the performance issue is resolved in PyTorch.
def conv3d_to_linear_weight(conv3d_weight: torch.Tensor) -> torch.Tensor:
    """Flatten a Conv3D weight into the weight of an equivalent Linear layer.

    The resulting Linear layer matches the convolution only when the
    convolution's ``kernel_size`` equals its ``stride`` and each input row
    holds exactly one patch flattened in ``(in_channels, kt, kh, kw)`` order,
    because that is the order in which the kernel dimensions are merged here.

    Args:
        conv3d_weight: Conv3D weight of shape
            ``(out_channels, in_channels, kt, kh, kw)``.

    Returns:
        A tensor of shape ``(out_channels, in_channels * kt * kh * kw)``.

    Raises:
        ValueError: If ``conv3d_weight`` is not 5-dimensional.
    """
    # Fail with a descriptive message instead of an opaque tuple-unpacking
    # error when a non-Conv3D weight is passed in by mistake.
    if conv3d_weight.ndim != 5:
        raise ValueError(
            "Expected a 5-D Conv3D weight of shape "
            "(out_channels, in_channels, kt, kh, kw), "
            f"got shape {tuple(conv3d_weight.shape)}"
        )
    out_channels, in_channels, kt, kh, kw = conv3d_weight.shape
    # Merge every non-output dimension into a single input-feature axis.
    return conv3d_weight.reshape(out_channels, in_channels * kt * kh * kw)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment