qwen2_rm.py 4.12 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
# Adapted from
# https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
"""Inference-only Qwen2-RM model compatible with HuggingFace weights."""
9

10
11
from collections.abc import Iterable
from typing import Optional, Union
12
13
14
15

import torch
from torch import nn

16
from vllm.config import VllmConfig
17
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
18
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
19
from vllm.sequence import IntermediateTensors
20

21
22
from .interfaces import SupportsLoRA, SupportsPP
from .interfaces_base import default_pooling_type
23
from .qwen2 import Qwen2Model
24
from .utils import AutoWeightsLoader, maybe_prefix
25
26


27
class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
    """Qwen2 decoder followed by a two-layer MLP score head.

    The head maps each hidden state through ``hidden_size -> hidden_size``
    (ReLU) ``-> config.num_labels``. Subclasses set ``num_labels`` on the HF
    config before calling this constructor and assign ``self.pooler``.
    """

    is_pooling_model = True
    pooler: Pooler  # assigned by subclasses after construction

    # Fused projection name -> checkpoint sub-projections packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config

        self.model = Qwen2Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        # Parameter dtype of the score head; forward() casts hidden states
        # to this dtype before scoring.
        self.head_dtype = vllm_config.model_config.head_dtype

        self.score = nn.Sequential(
            ColumnParallelLinear(
                config.hidden_size,
                config.hidden_size,
                quant_config=quant_config,
                params_dtype=self.head_dtype,
                return_bias=False,
            ),
            nn.ReLU(),
            RowParallelLinear(
                config.hidden_size,
                config.num_labels,
                params_dtype=self.head_dtype,
                quant_config=quant_config,
                return_bias=False,
            ),
        )
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the underlying Qwen2 decoder."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder and, when it yields hidden states, the score head.

        Returns:
            The decoder's ``IntermediateTensors`` unchanged when it returns
            them (pipeline-parallel handoff to the next stage); otherwise the
            score-head output computed in ``self.head_dtype``.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        # Under pipeline parallelism the decoder may return IntermediateTensors
        # instead of a tensor; those have no ``.to`` and must not be scored,
        # only forwarded to the next stage.
        if isinstance(hidden_states, IntermediateTensors):
            return hidden_states
        hidden_states = hidden_states.to(self.head_dtype)
        logits = self.score(hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, returning the names of loaded parameters.

        Checkpoint tensors under ``lm_head.`` are skipped: this model defines
        no ``lm_head`` module.
        """
        loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["lm_head."])
        return loader.load_weights(weights)
99
100


101
@default_pooling_type("ALL")
class Qwen2ForRewardModel(Qwen2RewardBaseModel):
    """Qwen2 reward model whose score head emits a single value.

    Forces ``num_labels = 1`` on the HF config and exposes an ``"encode"``
    pooler. The decorator registers ``"ALL"`` as the default pooling type.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # The base constructor sizes the score head from ``num_labels``, so it
        # has to be fixed before ``super().__init__`` runs.
        hf_config = vllm_config.model_config.hf_config
        hf_config.num_labels = 1
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        pooling_cfg = vllm_config.model_config.pooler_config
        assert pooling_cfg is not None
        encode_pooler = Pooler.for_encode(pooling_cfg)
        self.pooler = DispatchPooler({"encode": encode_pooler})
113
114


115
@default_pooling_type("STEP")
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
    """Qwen2 reward model whose score head emits two values.

    Forces ``num_labels = 2`` on the HF config and exposes an ``"encode"``
    pooler. The decorator registers ``"STEP"`` as the default pooling type.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Must precede the base constructor, which reads ``num_labels`` to
        # build the final score projection.
        vllm_config.model_config.hf_config.num_labels = 2
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        cfg = vllm_config.model_config.pooler_config
        assert cfg is not None
        poolers = {"encode": Pooler.for_encode(cfg)}
        self.pooler = DispatchPooler(poolers)