Commit 8cdc3a30 authored by zhuwenwen
Browse files

fix qwen3-next nn layout

parent 440222e9
......@@ -44,6 +44,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
import vllm.envs as envs
# Added by the IBM Team, 2024
......@@ -171,6 +172,15 @@ def mamba_v2_sharded_weight_loader(
# - track boundary of (sharded) param, and loaded_weight, respectively
boundary, loaded_boundary = 0, 0
if envs.VLLM_USE_NN:
loaded_total_dim = sum(full_dim - extra
for full_dim, extra, _ in shard_spec)
param_out_axis = 0 if param.dim() == 1 else (param.dim() - 1)
loaded_out_axis = 0
if (loaded_weight.dim() > 1 and loaded_weight.shape[-1] == loaded_total_dim
and loaded_weight.shape[0] != loaded_total_dim):
loaded_out_axis = loaded_weight.dim() - 1
# - iterate over the shard specs
for full_dim, extra, duplicate_groups in shard_spec:
# - full dim is the model dim (before TP).
......@@ -201,6 +211,32 @@ def mamba_v2_sharded_weight_loader(
# - the ignore is for a mundane mypy error as it does not
# seem to handle slices well.
# https://github.com/python/mypy/issues/2410
if envs.VLLM_USE_NN:
if take > 0:
param_slice = param.data.narrow(param_out_axis, boundary, take)
loaded_slice = loaded_weight.narrow(loaded_out_axis,
loaded_start_idx, take)
if (param_slice.dim() == loaded_slice.dim() + 1
and param_slice.shape[1] == 1):
loaded_slice = loaded_slice.unsqueeze(1)
elif (loaded_slice.dim() == param_slice.dim() + 1
and loaded_slice.shape[1] == 1):
loaded_slice = loaded_slice.squeeze(1)
if param_slice.shape != loaded_slice.shape:
loaded_slice = loaded_slice.permute(*reversed(range(loaded_slice.dim())))
if param_slice.shape != loaded_slice.shape:
raise RuntimeError(
"mamba_v2_sharded_weight_loader shape mismatch: "
f"param_slice={tuple(param_slice.shape)} "
f"loaded_slice={tuple(loaded_slice.shape)} "
f"(param_out_axis={param_out_axis}, "
f"loaded_out_axis={loaded_out_axis})")
param_slice.copy_(loaded_slice)
else:
param.data[
boundary : (boundary + take), ... # type: ignore[misc]
] = loaded_weight[
......@@ -428,6 +464,10 @@ class MambaMixer2(MambaBase, CustomOp):
# `ColumnParallelLinear` and `MergedColumnParallelLinear`,
# and `set_weight_attrs` doesn't allow to override it
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
if envs.VLLM_USE_NN:
conv_weights = self.conv1d.weight.squeeze(1).transpose(
0, 1).contiguous()
else:
conv_weights = self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
)
......
......@@ -95,6 +95,7 @@ from .utils import (
make_layers,
maybe_prefix,
)
import vllm.envs as envs
logger = init_logger(__name__)
......@@ -533,6 +534,10 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
a = a[:num_actual_tokens]
# 1. Convolution sequence transformation
if envs.VLLM_USE_NN:
conv_weights = self.conv1d.weight.squeeze(1).transpose(
0, 1).contiguous()
else:
conv_weights = self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment