# qwen2_rm.py 4.33 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
# Adapted from
# https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
"""Inference-only Qwen2-RM model compatible with HuggingFace weights."""
9
10
from collections.abc import Iterable
from typing import Optional, Union
11
12
13
14

import torch
from torch import nn

15
from vllm.config import VllmConfig
16
17
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
18
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
19
from vllm.sequence import IntermediateTensors
20

21
22
from .interfaces import SupportsLoRA, SupportsPP
from .interfaces_base import default_pooling_type
23
from .qwen2 import Qwen2Model
24
from .utils import AutoWeightsLoader, maybe_prefix
25
26


27
class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
    """Shared backbone and scoring head for the Qwen2 reward models.

    Wraps a ``Qwen2Model`` decoder and projects its hidden states through a
    small MLP head (``hidden_size -> hidden_size -> ReLU -> num_labels``).
    Subclasses set ``hf_config.num_labels`` before calling ``__init__`` and
    are responsible for assigning ``self.pooler``.
    """

    is_pooling_model = True
    # Assigned by subclasses; declared here so the attribute is typed.
    pooler: Pooler

    # Fused parameter name -> the individual projection names packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the Qwen2 decoder and the scoring head.

        Args:
            vllm_config: Engine configuration. ``model_config.hf_config``
                must provide ``hidden_size`` and ``num_labels``.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config
        self.quant_config = quant_config

        self.model = Qwen2Model(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))
        # The head's parameters (and the hidden states fed to it in
        # ``forward``) use this dtype, which may differ from the backbone's.
        self.head_dtype = vllm_config.model_config.head_dtype

        self.score = nn.Sequential(
            ColumnParallelLinear(config.hidden_size,
                                 config.hidden_size,
                                 quant_config=quant_config,
                                 params_dtype=self.head_dtype,
                                 return_bias=False),
            nn.ReLU(),
            RowParallelLinear(config.hidden_size,
                              config.num_labels,
                              params_dtype=self.head_dtype,
                              quant_config=quant_config,
                              return_bias=False),
        )
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder and apply the scoring head.

        Returns:
            Scores in ``head_dtype`` whose trailing dimension is
            ``num_labels``. If the backbone returns ``IntermediateTensors``
            instead (presumably on a non-final pipeline-parallel rank), they
            are returned unchanged.
        """
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        # Intermediate activations must be forwarded to the next pipeline
        # stage untouched; they are not a tensor and cannot be scored here.
        if isinstance(hidden_states, IntermediateTensors):
            return hidden_states
        hidden_states = hidden_states.to(self.head_dtype)
        return self.score(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this module.

        Tensors under ``lm_head.`` are ignored, since this model defines no
        ``lm_head``.

        Returns:
            The set of loaded parameter names, as reported by
            ``AutoWeightsLoader``.
        """
        loader = AutoWeightsLoader(self,
                                   ignore_unexpected_prefixes=["lm_head."])
        return loader.load_weights(weights)
95
96


97
@default_pooling_type("ALL")
class Qwen2ForRewardModel(Qwen2RewardBaseModel):
    """Qwen2 reward model whose head emits a single score.

    Registered with the ``ALL`` default pooling type and exposes only the
    ``encode`` pooling task.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Force a single output label before the base class sizes the head.
        # NOTE: this mutates the shared hf_config object in place.
        vllm_config.model_config.hf_config.num_labels = 1
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler(
            {"encode": Pooler.for_encode(pooler_config)})
109
110


111
@default_pooling_type("STEP")
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
    """Qwen2 process reward model whose head emits two scores.

    Registered with the ``STEP`` default pooling type and exposes only the
    ``encode`` pooling task.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Force two output labels before the base class sizes the head
        # (presumably per-step negative/positive logits — confirm against
        # the checkpoint). NOTE: this mutates the shared hf_config in place.
        vllm_config.model_config.hf_config.num_labels = 2
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler(
            {"encode": Pooler.for_encode(pooler_config)})