qwen2_rm.py 4.13 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
# Adapted from
# https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
"""Inference-only Qwen2-RM model compatible with HuggingFace weights."""
9
10
from collections.abc import Iterable
from typing import Optional, Union
11
12
13
14

import torch
from torch import nn

15
from vllm.config import VllmConfig
16
17
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
18
from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler
19
from vllm.sequence import IntermediateTensors
20

21
from .interfaces import SupportsLoRA, SupportsPP
22
from .qwen2 import Qwen2Model
23
from .utils import AutoWeightsLoader, maybe_prefix
24
25


26
class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
    """Shared implementation of the Qwen2 reward-model variants.

    Runs a :class:`Qwen2Model` backbone and feeds its hidden states through a
    small MLP head (``self.score``) whose output width is
    ``hf_config.num_labels``. Subclasses set ``num_labels`` on the HF config
    before calling this ``__init__`` and install ``self.pooler``.
    """

    is_pooling_model = True
    pooler: SimplePooler

    # Each fused module name mapped to the individual projection names it
    # combines (q/k/v -> qkv_proj, gate/up -> gate_up_proj).
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone and the reward head.

        Args:
            vllm_config: Engine configuration. ``model_config.hf_config`` must
                already carry ``num_labels``, which sizes the head's output.
            prefix: Parameter-name prefix for this module. The backbone's
                prefix is derived from it via ``maybe_prefix(prefix, "model")``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config

        self.model = Qwen2Model(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))

        # Reward head: hidden_size -> hidden_size, ReLU, hidden_size ->
        # num_labels, built as a column-parallel / row-parallel linear pair.
        # return_bias=False makes each layer yield a single tensor so that
        # nn.Sequential can chain them (presumably they would otherwise return
        # an (output, bias) pair -- confirm against vLLM's linear layers).
        self.score = nn.Sequential(
            ColumnParallelLinear(config.hidden_size,
                                 config.hidden_size,
                                 quant_config=quant_config,
                                 return_bias=False),
            nn.ReLU(),
            RowParallelLinear(config.hidden_size,
                              config.num_labels,
                              quant_config=quant_config,
                              return_bias=False),
        )

        # Reuse the backbone's factory for empty intermediate tensors
        # (needed for pipeline-parallel support, as declared by SupportsPP).
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's input embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and, when it yields hidden states, the reward head.

        Args:
            input_ids: Token ids passed through to the backbone.
            positions: Position ids passed through to the backbone.
            intermediate_tensors: Activations received from a previous
                pipeline stage, if any.
            inputs_embeds: Optional precomputed input embeddings.

        Returns:
            ``self.score(hidden_states)`` when the backbone returns a tensor.
            Otherwise, the backbone's ``IntermediateTensors`` unchanged, so
            they can be handed to the next pipeline stage.
        """
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        if isinstance(hidden_states, IntermediateTensors):
            # With pipeline parallelism the backbone can return intermediate
            # activations instead of final hidden states (see the return
            # annotation). The linear layers in ``self.score`` cannot consume
            # those, so pass them through untouched.
            return hidden_states
        logits = self.score(hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint ``(name, tensor)`` pairs into this module.

        Tensors whose names start with ``lm_head.`` are ignored rather than
        reported as unexpected, since this model defines no ``lm_head``.

        Returns:
            The name set reported by ``AutoWeightsLoader.load_weights``.
        """
        loader = AutoWeightsLoader(self,
                                   ignore_unexpected_prefixes=["lm_head."])
        return loader.load_weights(weights)


class Qwen2ForRewardModel(Qwen2RewardBaseModel):
    """Qwen2 reward model whose score head produces a single output.

    The pooler defaults to ``PoolingType.ALL`` with normalization and softmax
    both disabled. ``Pooler.from_config_with_defaults`` receives the model's
    ``pooler_config`` too, which presumably may override these defaults.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # The base class sizes its score head from ``num_labels``, so the
        # label count must be fixed before delegating to it.
        hf_config = vllm_config.model_config.hf_config
        hf_config.num_labels = 1
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        self.pooler = Pooler.from_config_with_defaults(
            vllm_config.model_config.pooler_config,
            pooling_type=PoolingType.ALL,
            softmax=False,
            normalize=False,
        )


class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
    """Qwen2 process reward model whose score head produces two outputs.

    The pooler defaults to ``PoolingType.STEP`` keyed on token id 151651,
    with softmax enabled and normalization disabled. ``pooler_config`` is
    also passed to ``Pooler.from_config_with_defaults`` and presumably may
    override these defaults.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # The base class sizes its score head from ``num_labels``, so the
        # label count must be fixed before delegating to it.
        hf_config = vllm_config.model_config.hf_config
        hf_config.num_labels = 2
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # NOTE(review): presumably the tokenizer id of the step-separator tag
        # used by Qwen2.5-Math PRM checkpoints -- confirm against the
        # tokenizer before reusing with other checkpoints.
        step_tag_id = 151651
        self.pooler = Pooler.from_config_with_defaults(
            vllm_config.model_config.pooler_config,
            pooling_type=PoolingType.STEP,
            softmax=True,
            normalize=False,
            step_tag_id=step_tag_id,
        )