# qwen2_rm.py
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
# Adapted from
# https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
"""Inference-only Qwen2-RM model compatible with HuggingFace weights."""
9

10
from collections.abc import Iterable
11
12
13
14

import torch
from torch import nn

15
from vllm.config import VllmConfig
16
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
17
18
from vllm.model_executor.layers.pooler import Pooler
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
19
from vllm.sequence import IntermediateTensors
20

21
22
from .interfaces import SupportsLoRA, SupportsPP
from .interfaces_base import default_pooling_type
23
from .qwen2 import Qwen2Model
24
from .utils import AutoWeightsLoader, maybe_prefix
25
26


27
class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
    """Qwen2 backbone followed by a two-layer MLP scoring head.

    The head width is taken from ``hf_config.num_labels``, which concrete
    subclasses set before calling this constructor. Subclasses are also
    responsible for assigning ``self.pooler``.
    """

    is_pooling_model = True
    pooler: Pooler

    # Fused module name -> the individual projection names packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config

        self.model = Qwen2Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        # dtype of the scoring head; forward() casts hidden states to it.
        self.head_dtype = vllm_config.model_config.head_dtype

        # Scoring MLP: hidden_size -> hidden_size -> ReLU -> num_labels,
        # with the first projection column-parallel and the second
        # row-parallel.
        self.score = nn.Sequential(
            ColumnParallelLinear(
                config.hidden_size,
                config.hidden_size,
                quant_config=quant_config,
                params_dtype=self.head_dtype,
                return_bias=False,
            ),
            nn.ReLU(),
            RowParallelLinear(
                config.hidden_size,
                config.num_labels,
                params_dtype=self.head_dtype,
                quant_config=quant_config,
                return_bias=False,
            ),
        )
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's input embeddings for ``input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and score its hidden states.

        Returns the output of ``self.score`` (computed in ``head_dtype``).
        If the backbone returns ``IntermediateTensors`` instead of a tensor,
        that object is returned unchanged.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        # The backbone yields IntermediateTensors rather than a tensor when
        # it is not producing final hidden states (presumably a non-last
        # pipeline-parallel stage). Those have no `.to()` and must be passed
        # on untouched instead of going through the scoring head.
        if isinstance(hidden_states, IntermediateTensors):
            return hidden_states
        hidden_states = hidden_states.to(self.head_dtype)
        return self.score(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this module.

        Checkpoint entries whose names start with ``lm_head.`` are passed to
        ``AutoWeightsLoader`` as ignorable unexpected prefixes. Returns
        whatever ``AutoWeightsLoader.load_weights`` returns (a set of names,
        per the annotation).
        """
        loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["lm_head."])
        return loader.load_weights(weights)
97
98


99
@default_pooling_type(tok_pooling_type="ALL")
class Qwen2ForRewardModel(Qwen2RewardBaseModel):
    """Qwen2 reward model with a single-label score head.

    Default token pooling type is ``"ALL"``; the pooler is built with
    ``pooler_for_token_classify``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        hf_config = vllm_config.model_config.hf_config
        # The base constructor sizes the score head from num_labels, so it
        # must be set first. Note: this mutates the shared hf_config in place.
        hf_config.num_labels = 1
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        pooling_cfg = vllm_config.model_config.pooler_config
        assert pooling_cfg is not None
        self.pooler = pooler_for_token_classify(pooling_cfg)
109
110


111
@default_pooling_type(tok_pooling_type="STEP")
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
    """Qwen2 process reward model with a two-label score head.

    Default token pooling type is ``"STEP"``; the pooler is built with
    ``pooler_for_token_classify``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        hf_config = vllm_config.model_config.hf_config
        # The base constructor sizes the score head from num_labels, so it
        # must be set first. Note: this mutates the shared hf_config in place.
        hf_config.num_labels = 2
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        pooling_cfg = vllm_config.model_config.pooler_config
        assert pooling_cfg is not None
        self.pooler = pooler_for_token_classify(pooling_cfg)