voyage.py 4.82 KB
Newer Older
chengchengpei's avatar
chengchengpei committed
1
2
3
4
5
6
7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterable

Harry Mellor's avatar
Harry Mellor committed
8
import regex as re
chengchengpei's avatar
chengchengpei committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import torch
import torch.nn as nn

from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen3 import Qwen3Model
from vllm.model_executor.models.utils import WeightsMapper

# One checkpoint entry in the weight stream consumed by load_weights:
# (parameter name, tensor).
WeightItem = tuple[str, torch.Tensor]

# Matches per-layer parameter names such as "layers.3.self_attn.q_proj.weight".
# The fusion helpers apply it after hf_to_vllm_mapper has stripped the "model."
# prefix. group(1) is the decimal layer index; group(2) is the rest of the name
# within that layer.
_LAYER_RE = re.compile(r"^layers\.(\d+)\.(.+)$")


class VoyageQwen3BidirectionalEmbedModel(Qwen3Model):
    """
    Qwen3Model + Voyage embedding head + bidirectional attention.

    Checkpoint conventions (HF):
      - MLP: gate_proj + up_proj (unfused)
      - Attn: q_proj + k_proj + v_proj (unfused)
      - Linear head: linear.weight
      - Weights prefixed with "model." (e.g., model.layers.0...)

    vLLM Qwen3Model expects:
      - mlp.gate_up_proj (fused)
      - self_attn.qkv_proj (fused)
      - No "model." prefix

    We remap/fuse weights using a generator pipeline and load directly
    (bypassing the parent's stacked_params_mapping, which would cause
    double-transformation like qkv_proj -> qkqkv_proj).

    Fusion is streamed: a layer's fused tensor is emitted as soon as all of
    its parts have been seen. Only layers whose parts are still incomplete
    are buffered, instead of every q/k/v/gate/up tensor of the checkpoint.
    """

    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Embedding head (hidden_size -> num_labels, bias=False)
        self.linear = nn.Linear(
            self.config.hidden_size,
            self.config.num_labels,
            bias=False,
        )

    def forward(self, *args, **kwargs):
        """Run the Qwen3 backbone, then project its output through self.linear."""
        out = super().forward(*args, **kwargs)
        return self.linear(out)

    @staticmethod
    def _fuse_layer_parts(
        weights: Iterable[WeightItem],
        suffix_to_part: dict[str, str],
        part_order: tuple[str, ...],
        fused_suffix: str,
        label: str,
    ) -> Iterable[WeightItem]:
        """Concatenate per-layer partial weights into a single fused weight.

        Args:
            weights: (name, tensor) stream whose per-layer names look like
                "layers.<i>.<suffix>" (i.e. the "model." prefix is stripped).
            suffix_to_part: maps a per-layer suffix to a short part key.
            part_order: part keys in the order they are concatenated on dim 0.
            fused_suffix: per-layer suffix of the emitted fused weight.
            label: tag used in the error message (e.g. "QKV", "MLP").

        Yields:
            Every non-matching item unchanged, plus one
            ("layers.<i>.<fused_suffix>", fused) item per layer, emitted as
            soon as all of that layer's parts have arrived.

        Raises:
            ValueError: if, once the stream is exhausted, some layer has only
                a subset of its parts. The lowest such layer is reported.
        """
        pending: dict[int, dict[str, torch.Tensor]] = defaultdict(dict)

        for name, tensor in weights:
            match = _LAYER_RE.match(name)
            part = suffix_to_part.get(match.group(2)) if match else None
            if part is None:
                yield name, tensor
                continue

            layer_idx = int(match.group(1))
            parts = pending[layer_idx]
            parts[part] = tensor
            if all(p in parts for p in part_order):
                # Remove the completed layer from the buffer so its parts are
                # not retained for the rest of the stream.
                del pending[layer_idx]
                fused = torch.cat([parts[p] for p in part_order], dim=0)
                yield f"layers.{layer_idx}.{fused_suffix}", fused

        # Anything still buffered never received all of its parts.
        for layer_idx in sorted(pending):
            missing = [p for p in part_order if p not in pending[layer_idx]]
            raise ValueError(f"Layer {layer_idx} missing {label} parts: {missing}")

    def _fuse_qkv_proj(self, weights: Iterable[WeightItem]) -> Iterable[WeightItem]:
        """Fuse q_proj, k_proj, v_proj into qkv_proj (concatenated as q, k, v)."""
        return self._fuse_layer_parts(
            weights,
            suffix_to_part={
                "self_attn.q_proj.weight": "q",
                "self_attn.k_proj.weight": "k",
                "self_attn.v_proj.weight": "v",
            },
            part_order=("q", "k", "v"),
            fused_suffix="self_attn.qkv_proj.weight",
            label="QKV",
        )

    def _fuse_gate_up_proj(self, weights: Iterable[WeightItem]) -> Iterable[WeightItem]:
        """Fuse gate_proj and up_proj into gate_up_proj (concatenated as gate, up)."""
        return self._fuse_layer_parts(
            weights,
            suffix_to_part={
                "mlp.gate_proj.weight": "gate",
                "mlp.up_proj.weight": "up",
            },
            part_order=("gate", "up"),
            fused_suffix="mlp.gate_up_proj.weight",
            label="MLP",
        )

    def load_weights(self, weights: Iterable[WeightItem]) -> set[str]:
        """Remap, fuse, and load weights using a lazy generator pipeline.

        Args:
            weights: HF-style (name, tensor) pairs, names prefixed "model.".

        Returns:
            The set of parameter names (vLLM naming) that were loaded.
            Names with no matching parameter in this model are skipped.
        """
        # Chain weight transformations (all lazy; nothing runs until iterated).
        weights = self.hf_to_vllm_mapper.apply(weights)
        weights = self._fuse_qkv_proj(weights)
        weights = self._fuse_gate_up_proj(weights)

        # Load weights directly into model parameters
        # (bypass parent's stacked_params_mapping)
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            if name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)

        return loaded_params