qwen2_rm.py 3.91 KB
Newer Older
1
2
3
4
5
# Adapted from
# https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
"""Inference-only Qwen2-RM model compatible with HuggingFace weights."""
6
from typing import Iterable, List, Optional, Set, Tuple, Union
7
8
9
10
11

import torch
from torch import nn

from vllm.attention import AttentionMetadata
12
from vllm.config import VllmConfig
13
14
15
16
17
18
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput

19
from .interfaces import SupportsLoRA, SupportsPP
20
from .qwen2 import Qwen2Model
21
from .utils import AutoWeightsLoader, maybe_prefix
22
23
24
25
26
27
28
29
30
31
32
33
34


class ReLU(nn.Module):
    """Apply ReLU to the first element of an ``(output, bias)`` pair.

    The parallel linear layers used in the scoring head return a 2-tuple, so a
    plain ``nn.ReLU`` cannot sit between them inside ``nn.Sequential``. This
    wrapper unpacks the pair, drops the second element and activates the first.
    """

    def __init__(self):
        super().__init__()
        self.activation = nn.ReLU()

    def forward(self, input):
        # Exactly two elements are expected; the second (bias) is discarded.
        hidden, _discarded_bias = input
        activated = self.activation(hidden)
        return activated


35
class Qwen2ForRewardModel(nn.Module, SupportsLoRA, SupportsPP):
    """Qwen2 backbone followed by a two-layer MLP scoring head.

    Every hidden state produced by the backbone is mapped through
    ``hidden_size -> hidden_size -> ReLU -> 1``, giving one reward value per
    row of hidden states. Pooling of those values is delegated to a
    ``Pooler`` whose defaults are ``PoolingType.ALL`` with no normalization
    and no softmax (the pooler config may override these defaults).
    """

    # Fused projection modules -> the separate checkpoint modules packed into
    # them.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # No embedding modules are registered for LoRA on this model.
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, scoring head and pooler.

        Args:
            vllm_config: Engine configuration. The HF model config,
                quantization config, LoRA config and pooler config are all
                read from it.
            prefix: Parameter-name prefix for this module. The backbone is
                registered under ``maybe_prefix(prefix, "model")``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        pooler_config = vllm_config.model_config.pooler_config

        self.config = config
        self.lora_config = lora_config
        self.quant_config = quant_config

        self.model = Qwen2Model(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))

        # Each parallel linear layer yields an ``(output, bias)`` pair: the
        # local ``ReLU`` wrapper unpacks the first one, ``forward`` unpacks
        # the last one.
        self.score = nn.Sequential(
            ColumnParallelLinear(config.hidden_size,
                                 config.hidden_size,
                                 quant_config=quant_config),
            ReLU(),
            RowParallelLinear(config.hidden_size, 1,
                              quant_config=quant_config),
        )
        self._pooler = Pooler.from_config_with_defaults(
            pooler_config,
            pooling_type=PoolingType.ALL,
            normalize=False,
            softmax=False)

        # Pipeline-parallel hook: delegate creation of empty intermediate
        # tensors to the backbone.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and score its hidden states.

        All arguments are forwarded to the Qwen2 backbone unchanged.

        Returns:
            The scoring head's output (one value per row of hidden states),
            or, if the backbone returns ``IntermediateTensors`` (expected on
            non-final pipeline-parallel ranks), those tensors unchanged.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        # The scoring head only applies to real hidden states; intermediate
        # pipeline tensors must be handed on to the next stage untouched.
        if isinstance(hidden_states, IntermediateTensors):
            return hidden_states
        logits, _ = self.score(hidden_states)
        return logits

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Pool the per-token scores using the configured ``Pooler``."""
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint ``(name, tensor)`` pairs into this module.

        Checkpoint entries prefixed with ``lm_head.`` are ignored: this model
        defines no LM head.

        Returns:
            The set of names reported by ``AutoWeightsLoader`` (presumably the
            parameters that were loaded).
        """
        loader = AutoWeightsLoader(self,
                                   ignore_unexpected_prefixes=["lm_head."])
        return loader.load_weights(weights)