"vscode:/vscode.git/clone" did not exist on "0f621c2c7dfe409b6e24e8810dc039745b9a8a7a"
Unverified Commit 41b92f7d authored by Shanshan Shen's avatar Shanshan Shen Committed by GitHub
Browse files

[Model][MM] Extract conv layer as CustomOp (#28455)


Signed-off-by: default avatarshen-shanshan <467638484@qq.com>
Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent 360bd876
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Conv Layer Class."""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.model_executor.custom_op import CustomOp
from vllm.utils.torch_utils import is_torch_equal
class ConvLayerBase(CustomOp):
    """Shared parameter handling for N-dimensional convolution layers.

    Subclasses set ``num_dim`` to the number of spatial dimensions and
    provide the forward implementations. This base class normalizes the
    hyper-parameters to per-dimension tuples, allocates the (uninitialized)
    weight/bias parameters and decides whether the convolution can be
    computed as a plain matrix multiplication (``enable_linear``).
    """

    # Number of spatial dimensions; defined by each concrete subclass.
    num_dim: int

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple[int, ...],
        stride: int | tuple[int, ...] = 1,
        padding: int | tuple[int, ...] = 0,
        dilation: int | tuple[int, ...] = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        *,
        params_dtype: torch.dtype | None = None,
    ) -> None:
        """Create the layer parameters.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Kernel extent; an int is repeated ``num_dim`` times.
            stride: Stride; an int is repeated ``num_dim`` times.
            padding: Padding; an int is repeated ``num_dim`` times.
            dilation: Dilation; an int is repeated ``num_dim`` times.
            groups: Number of channel groups.
            bias: Whether to allocate a bias parameter.
            padding_mode: Stored on the instance as-is.
            params_dtype: dtype of the parameters; defaults to
                ``torch.get_default_dtype()``.
        """
        super().__init__()

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        def _per_dim(value: int | tuple[int, ...]) -> tuple[int, ...]:
            # Broadcast a scalar setting to every spatial dimension.
            return (value,) * self.num_dim if isinstance(value, int) else value

        kernel_size = _per_dim(kernel_size)
        stride = _per_dim(stride)
        padding = _per_dim(padding)
        dilation = _per_dim(dilation)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.padding_mode = padding_mode

        # The unfold + linear fast path is only equivalent to a convolution
        # when the kernel tiles the input without overlap (kernel == stride),
        # no padding is applied, taps are contiguous (dilation == 1) and all
        # channels belong to a single group. A dilation > 1 would make the
        # real convolution read strided taps that unfold() never sees.
        self.enable_linear = (
            (self.kernel_size == self.stride)
            and not any(self.padding)
            and all(d == 1 for d in self.dilation)
            and self.groups == 1
        )
        # Row width of the flattened weight used by the linear fast path.
        self.input_size = in_channels * math.prod(self.kernel_size)

        # Parameters are allocated uninitialized (torch.empty).
        self.weight = nn.Parameter(
            torch.empty(
                out_channels,
                in_channels // groups,
                *kernel_size,
                dtype=params_dtype,
            ),
        )

        if bias:
            self.bias = nn.Parameter(torch.empty(self.out_channels, dtype=params_dtype))
        else:
            self.register_parameter("bias", None)

    def extra_repr(self) -> str:
        """Return the hyper-parameter summary shown in the module repr."""
        s = f"in_channels={self.in_channels}, "
        s += f"out_channels={self.out_channels}, "
        s += f"kernel_size={self.kernel_size}, "
        s += f"stride={self.stride}, "
        s += f"padding={self.padding}, "
        s += f"bias={self.bias is not None}"
        return s
@CustomOp.register("conv2d")
class Conv2dLayer(ConvLayerBase):
    """2-D convolution layer with an unfold + linear alternative path."""

    num_dim = 2

    def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the convolution as one matrix multiplication over patches."""
        assert x.dim() == 4
        batch, _, height, width = x.shape
        k_h, k_w = self.kernel_size
        out_h = height // k_h
        out_w = width // k_w

        # Cut the input into kernel-sized tiles along height, then width.
        patches = x.unfold(2, k_h, k_h).unfold(3, k_w, k_w)
        # Move channels next to the tile contents and flatten each tile
        # into a row of length input_size.
        rows = patches.permute(0, 2, 3, 1, 4, 5).reshape(-1, self.input_size)

        flat_weight = self.weight.view(self.out_channels, self.input_size)
        projected = F.linear(rows, flat_weight, self.bias)

        # Restore (batch, out_channels, out_h, out_w) layout.
        return projected.view(batch, out_h, out_w, self.out_channels).permute(
            0, 3, 1, 2
        )

    def _forward_conv(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the convolution with ``F.conv2d``."""
        assert x.dim() == 4
        return F.conv2d(
            x,
            self.weight,
            self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Expected input shape: (batch_size, in_channels, height, width)"""
        assert x.dim() == 4
        impl = self._forward_mulmat if self.enable_linear else self._forward_conv
        return impl(x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the CUDA forward pass."""
        # By default, we use CUDNN's convolution ops with optimization.
        return self._forward_conv(x)
class CausalConv2dLayer(Conv2dLayer):
    """Conv2d layer that pads the last input dimension before convolving.

    The input is padded along its last dimension with ``kernel_size - 1``
    elements on the left and ``stride - 1`` elements on the right; the
    second-to-last dimension is not padded. The underlying convolution is
    then run with zero padding.

    All arguments are the same as ``Conv2dLayer`` except ``padding``, which
    must be left as ``None`` because the padding is derived from
    ``kernel_size`` and ``stride``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        padding: int | None = None,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        *,
        params_dtype: torch.dtype | None = None,
    ) -> None:
        """Create the layer.

        Raises:
            ValueError: If ``padding`` is anything other than ``None``.
        """
        # The default must be None: with the previous default of 0 this
        # check rejected every construction that relied on the default.
        if padding is not None:
            raise ValueError(
                "Argument padding should be set to None for CausalConv2dLayer."
            )

        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            0,  # explicit padding is applied in forward() instead
            dilation,
            groups,
            bias,
            padding_mode,
            params_dtype=params_dtype,
        )

        self._left_padding: int = kernel_size - 1
        self._right_padding: int = stride - 1

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Pad the last dimension of ``x`` and apply the convolution."""
        # F.pad consumes pairs starting from the last dimension:
        # (last_left, last_right, second_to_last_left, second_to_last_right).
        x = F.pad(x, pad=(self._left_padding, self._right_padding, 0, 0))
        x = super().forward(x)
        return x
@CustomOp.register("conv3d")
class Conv3dLayer(ConvLayerBase):
    """3-D convolution layer with an unfold + linear alternative path."""

    num_dim = 3

    def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the convolution as one matrix multiplication over patches."""
        assert x.dim() == 5
        batch, _, frames, height, width = x.shape
        k_t, k_h, k_w = self.kernel_size
        out_t = frames // k_t
        out_h = height // k_h
        out_w = width // k_w

        # Cut the input into kernel-sized tiles along time, height, width.
        patches = x.unfold(2, k_t, k_t).unfold(3, k_h, k_h).unfold(4, k_w, k_w)
        # Move channels next to the tile contents and flatten each tile
        # into a row of length input_size.
        rows = patches.permute(0, 2, 3, 4, 1, 5, 6, 7).reshape(-1, self.input_size)

        flat_weight = self.weight.view(self.out_channels, self.input_size)
        projected = F.linear(rows, flat_weight, self.bias)

        # Restore (batch, out_channels, out_t, out_h, out_w) layout.
        return projected.view(
            batch, out_t, out_h, out_w, self.out_channels
        ).permute(0, 4, 1, 2, 3)

    def _forward_conv(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the convolution with ``F.conv3d``."""
        assert x.dim() == 5
        return F.conv3d(
            x,
            self.weight,
            self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Expected input shape: (batch_size, in_channels, time, height, width)"""
        impl = self._forward_mulmat if self.enable_linear else self._forward_conv
        return impl(x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the CUDA forward pass, avoiding Conv3D on PyTorch 2.9.0."""
        # PyTorch2.9.0 disabled CUDNN's Conv3D, which caused a
        # significant performance regression.
        # See: https://github.com/vllm-project/vllm/issues/27406
        # and https://github.com/pytorch/pytorch/issues/166122
        use_linear = self.enable_linear and is_torch_equal("2.9.0")
        if use_linear:
            return self._forward_mulmat(x)
        # By default, we use CUDNN's convolution ops with optimization.
        return self._forward_conv(x)
...@@ -20,6 +20,7 @@ from vllm.config import VllmConfig ...@@ -20,6 +20,7 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -315,7 +316,7 @@ class CLIPVisionEmbeddings(nn.Module): ...@@ -315,7 +316,7 @@ class CLIPVisionEmbeddings(nn.Module):
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
self.patch_embedding = nn.Conv2d( self.patch_embedding = Conv2dLayer(
in_channels=config.num_channels, in_channels=config.num_channels,
out_channels=self.embed_dim, out_channels=self.embed_dim,
kernel_size=self.patch_size, kernel_size=self.patch_size,
......
...@@ -56,12 +56,12 @@ from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions ...@@ -56,12 +56,12 @@ from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.conv import Conv3dLayer
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -103,7 +103,6 @@ from .utils import ( ...@@ -103,7 +103,6 @@ from .utils import (
maybe_prefix, maybe_prefix,
) )
from .vision import ( from .vision import (
conv3d_to_linear_weight,
get_vit_attn_backend, get_vit_attn_backend,
run_dp_sharded_mrope_vision_model, run_dp_sharded_mrope_vision_model,
) )
...@@ -486,15 +485,18 @@ class Glm4vVisionPatchEmbed(nn.Module): ...@@ -486,15 +485,18 @@ class Glm4vVisionPatchEmbed(nn.Module):
self.hidden_size = hidden_size self.hidden_size = hidden_size
kernel_size = (temporal_patch_size, patch_size, patch_size) kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = ReplicatedLinear( self.proj = Conv3dLayer(
in_channels * math.prod(kernel_size), in_channels,
hidden_size, hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=True, bias=True,
return_bias=False,
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x) L, C = x.shape
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
x = self.proj(x).view(L, self.hidden_size)
return x return x
...@@ -893,9 +895,6 @@ class Glm4vVisionTransformer(nn.Module): ...@@ -893,9 +895,6 @@ class Glm4vVisionTransformer(nn.Module):
loaded_params: set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if name.endswith("patch_embed.proj.weight"):
loaded_weight = conv3d_to_linear_weight(loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
......
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" """Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
import math
from collections.abc import Callable, Iterable, Mapping, Sequence from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import lru_cache, partial from functools import lru_cache, partial
from typing import Annotated, Any, Literal, TypeAlias from typing import Annotated, Any, Literal, TypeAlias
...@@ -56,12 +55,12 @@ from vllm.distributed import utils as dist_utils ...@@ -56,12 +55,12 @@ from vllm.distributed import utils as dist_utils
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.conv import Conv3dLayer
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -110,7 +109,6 @@ from .utils import ( ...@@ -110,7 +109,6 @@ from .utils import (
maybe_prefix, maybe_prefix,
) )
from .vision import ( from .vision import (
conv3d_to_linear_weight,
get_vit_attn_backend, get_vit_attn_backend,
run_dp_sharded_mrope_vision_model, run_dp_sharded_mrope_vision_model,
) )
...@@ -525,15 +523,18 @@ class Qwen2_5_VisionPatchEmbed(nn.Module): ...@@ -525,15 +523,18 @@ class Qwen2_5_VisionPatchEmbed(nn.Module):
self.hidden_size = hidden_size self.hidden_size = hidden_size
kernel_size = (temporal_patch_size, patch_size, patch_size) kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = ReplicatedLinear( self.proj = Conv3dLayer(
in_channels * math.prod(kernel_size), in_channels,
hidden_size, hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=False, bias=False,
return_bias=False,
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x) L, C = x.shape
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
x = self.proj(x).view(L, self.hidden_size)
return x return x
...@@ -957,9 +958,6 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -957,9 +958,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
loaded_params: set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if name.endswith("patch_embed.proj.weight"):
loaded_weight = conv3d_to_linear_weight(loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
import math
from collections.abc import Callable, Iterable, Mapping, Sequence from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import partial from functools import partial
from typing import Annotated, Any, Literal, TypeAlias from typing import Annotated, Any, Literal, TypeAlias
...@@ -54,9 +53,9 @@ from vllm.distributed import parallel_state ...@@ -54,9 +53,9 @@ from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.conv import Conv3dLayer
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -107,7 +106,6 @@ from .utils import ( ...@@ -107,7 +106,6 @@ from .utils import (
maybe_prefix, maybe_prefix,
) )
from .vision import ( from .vision import (
conv3d_to_linear_weight,
get_vit_attn_backend, get_vit_attn_backend,
run_dp_sharded_mrope_vision_model, run_dp_sharded_mrope_vision_model,
) )
...@@ -566,15 +564,18 @@ class Qwen2VisionPatchEmbed(nn.Module): ...@@ -566,15 +564,18 @@ class Qwen2VisionPatchEmbed(nn.Module):
self.embed_dim = embed_dim self.embed_dim = embed_dim
kernel_size = (temporal_patch_size, patch_size, patch_size) kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = ReplicatedLinear( self.proj = Conv3dLayer(
in_channels * math.prod(kernel_size), in_channels,
embed_dim, embed_dim,
kernel_size=kernel_size,
stride=kernel_size,
bias=False, bias=False,
return_bias=False,
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x) L, C = x.shape
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
x = self.proj(x).view(L, self.embed_dim)
return x return x
...@@ -844,9 +845,6 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -844,9 +845,6 @@ class Qwen2VisionTransformer(nn.Module):
loaded_params: set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if name.endswith("patch_embed.proj.weight"):
loaded_weight = conv3d_to_linear_weight(loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen3-Omni-Moe model (thinker part).""" """Inference-only Qwen3-Omni-Moe model (thinker part)."""
import math
from collections.abc import Callable, Iterable, Mapping, Sequence from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import partial from functools import partial
from typing import Any from typing import Any
...@@ -54,9 +53,9 @@ from vllm.config import VllmConfig ...@@ -54,9 +53,9 @@ from vllm.config import VllmConfig
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.conv import Conv3dLayer
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
...@@ -102,7 +101,6 @@ from .utils import ( ...@@ -102,7 +101,6 @@ from .utils import (
maybe_prefix, maybe_prefix,
) )
from .vision import ( from .vision import (
conv3d_to_linear_weight,
get_llm_pos_ids_for_vision, get_llm_pos_ids_for_vision,
get_vit_attn_backend, get_vit_attn_backend,
) )
...@@ -138,16 +136,18 @@ class Qwen3_VisionPatchEmbed(nn.Module): ...@@ -138,16 +136,18 @@ class Qwen3_VisionPatchEmbed(nn.Module):
self.hidden_size = hidden_size self.hidden_size = hidden_size
kernel_size = (temporal_patch_size, patch_size, patch_size) kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = ReplicatedLinear( self.proj = Conv3dLayer(
in_channels * math.prod(kernel_size), in_channels,
hidden_size, hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=True, bias=True,
return_bias=False,
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
L, C = x.shape L, C = x.shape
x = self.proj(x) x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
x = self.proj(x).view(L, self.hidden_size)
return x return x
...@@ -566,9 +566,6 @@ class Qwen3Omni_VisionTransformer(nn.Module): ...@@ -566,9 +566,6 @@ class Qwen3Omni_VisionTransformer(nn.Module):
loaded_params: set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if name.endswith("patch_embed.proj.weight"):
loaded_weight = conv3d_to_linear_weight(loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen3VL model compatible with HuggingFace weights.""" """Inference-only Qwen3VL model compatible with HuggingFace weights."""
import math
from collections.abc import Callable, Iterable, Mapping, Sequence from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import partial from functools import partial
from itertools import islice from itertools import islice
...@@ -57,9 +56,9 @@ from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions ...@@ -57,9 +56,9 @@ from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.conv import Conv3dLayer
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
...@@ -114,7 +113,6 @@ from .utils import ( ...@@ -114,7 +113,6 @@ from .utils import (
maybe_prefix, maybe_prefix,
) )
from .vision import ( from .vision import (
conv3d_to_linear_weight,
get_vit_attn_backend, get_vit_attn_backend,
run_dp_sharded_mrope_vision_model, run_dp_sharded_mrope_vision_model,
) )
...@@ -139,15 +137,18 @@ class Qwen3_VisionPatchEmbed(nn.Module): ...@@ -139,15 +137,18 @@ class Qwen3_VisionPatchEmbed(nn.Module):
self.hidden_size = hidden_size self.hidden_size = hidden_size
kernel_size = (temporal_patch_size, patch_size, patch_size) kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = ReplicatedLinear( self.proj = Conv3dLayer(
in_channels * math.prod(kernel_size), in_channels,
hidden_size, hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=True, bias=True,
return_bias=False,
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x) L, C = x.shape
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
x = self.proj(x).view(L, self.hidden_size)
return x return x
...@@ -579,9 +580,6 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -579,9 +580,6 @@ class Qwen3_VisionTransformer(nn.Module):
loaded_params: set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if name.endswith("patch_embed.proj.weight"):
loaded_weight = conv3d_to_linear_weight(loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
......
...@@ -550,19 +550,3 @@ def get_llm_pos_ids_for_vision( ...@@ -550,19 +550,3 @@ def get_llm_pos_ids_for_vision(
llm_pos_ids_list.append(_llm_pos_ids + start_idx) llm_pos_ids_list.append(_llm_pos_ids + start_idx)
llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
return llm_pos_ids return llm_pos_ids
# Due to a performance regression with Conv3D in PyTorch2.9, we reshape
# Conv3D weights to Linear weights for better performance.
# See: https://github.com/vllm-project/vllm/issues/27406
# and https://github.com/pytorch/pytorch/issues/166122
# FIXME(Isotr0py): Revert the PR that introduced this workaround
# (https://github.com/vllm-project/vllm/pull/27418),
# once the performance issue is resolved in PyTorch.
def conv3d_to_linear_weight(conv3d_weight: torch.Tensor) -> torch.Tensor:
    """
    Flatten a 5-D Conv3D weight into a 2-D Linear weight.

    The first dimension is kept and the remaining four are merged into one.
    Only works as a Conv3D replacement when kernel_size==stride.
    """
    num_filters, channels, depth, height, width = conv3d_weight.shape
    flat_features = channels * depth * height * width
    return conv3d_weight.reshape(num_filters, flat_features)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment