Commit 6cd43ae5 authored by zhuwenwen's avatar zhuwenwen
Browse files

fix qwen3-next nn layout

parent 88411543
...@@ -36,6 +36,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -36,6 +36,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
import vllm.envs as envs
# Added by the IBM Team, 2024 # Added by the IBM Team, 2024
...@@ -160,6 +161,15 @@ def mamba_v2_sharded_weight_loader( ...@@ -160,6 +161,15 @@ def mamba_v2_sharded_weight_loader(
# - track boundary of (sharded) param, and loaded_weight, respectively # - track boundary of (sharded) param, and loaded_weight, respectively
boundary, loaded_boundary = 0, 0 boundary, loaded_boundary = 0, 0
if envs.VLLM_USE_NN:
loaded_total_dim = sum(full_dim - extra
for full_dim, extra, _ in shard_spec)
param_out_axis = 0 if param.dim() == 1 else (param.dim() - 1)
loaded_out_axis = 0
if (loaded_weight.dim() > 1 and loaded_weight.shape[-1] == loaded_total_dim
and loaded_weight.shape[0] != loaded_total_dim):
loaded_out_axis = loaded_weight.dim() - 1
# - iterate over the shard specs # - iterate over the shard specs
for full_dim, extra, duplicate_groups in shard_spec: for full_dim, extra, duplicate_groups in shard_spec:
# - full dim is the model dim (before TP). # - full dim is the model dim (before TP).
...@@ -190,12 +200,38 @@ def mamba_v2_sharded_weight_loader( ...@@ -190,12 +200,38 @@ def mamba_v2_sharded_weight_loader(
# - the ignore is for a mundane mypy error as it does not # - the ignore is for a mundane mypy error as it does not
# seem to handle slices well. # seem to handle slices well.
# https://github.com/python/mypy/issues/2410 # https://github.com/python/mypy/issues/2410
param.data[ if envs.VLLM_USE_NN:
boundary:(boundary + take), if take > 0:
... # type: ignore[misc] param_slice = param.data.narrow(param_out_axis, boundary, take)
] = loaded_weight[loaded_start_idx:(loaded_start_idx + loaded_slice = loaded_weight.narrow(loaded_out_axis,
take) # type: ignore[misc] loaded_start_idx, take)
] # type: ignore[misc]
if (param_slice.dim() == loaded_slice.dim() + 1
and param_slice.shape[1] == 1):
loaded_slice = loaded_slice.unsqueeze(1)
elif (loaded_slice.dim() == param_slice.dim() + 1
and loaded_slice.shape[1] == 1):
loaded_slice = loaded_slice.squeeze(1)
if param_slice.shape != loaded_slice.shape:
loaded_slice = loaded_slice.permute(*reversed(range(loaded_slice.dim())))
if param_slice.shape != loaded_slice.shape:
raise RuntimeError(
"mamba_v2_sharded_weight_loader shape mismatch: "
f"param_slice={tuple(param_slice.shape)} "
f"loaded_slice={tuple(loaded_slice.shape)} "
f"(param_out_axis={param_out_axis}, "
f"loaded_out_axis={loaded_out_axis})")
param_slice.copy_(loaded_slice)
else:
param.data[
boundary:(boundary + take),
... # type: ignore[misc]
] = loaded_weight[loaded_start_idx:(loaded_start_idx +
take) # type: ignore[misc]
] # type: ignore[misc]
# move indexing boundaries # move indexing boundaries
boundary += shard_size boundary += shard_size
...@@ -522,8 +558,12 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -522,8 +558,12 @@ class MambaMixer2(MambaBase, CustomOp):
dim=-1, dim=-1,
) )
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), if envs.VLLM_USE_NN:
self.conv1d.weight.size(2)) conv_weights = self.conv1d.weight.squeeze(1).transpose(
0, 1).contiguous()
else:
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
# - get hidden_states, B and C after depthwise convolution. # - get hidden_states, B and C after depthwise convolution.
split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split( split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
......
...@@ -63,6 +63,7 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, ...@@ -63,6 +63,7 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
import vllm.envs as envs
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -432,8 +433,12 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ...@@ -432,8 +433,12 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
mixed_qkv = torch.cat((query, key, value), dim=-1) mixed_qkv = torch.cat((query, key, value), dim=-1)
# 2. Convolution sequence transformation # 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), if envs.VLLM_USE_NN:
self.conv1d.weight.size(2)) conv_weights = self.conv1d.weight.squeeze(1).transpose(
0, 1).contiguous()
else:
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if spec_sequence_masks is not None: if spec_sequence_masks is not None:
if (attn_metadata.num_prefills == 0 if (attn_metadata.num_prefills == 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment