llama_classification.py

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Iterable, Optional, Tuple

import torch
import tqdm
from torch import nn
from transformers import LlamaConfig
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitProcessorOutput
from sglang.srt.managers.controller.model_runner import InputMetadata
from sglang.srt.models.llama2 import LlamaModel


class LlamaForClassification(nn.Module):
    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = LlamaModel(config, quant_config=quant_config)

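        # Instead of an LM head, project the pooled hidden state to
        # `classification_out_size` class logits.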
        self.classification_head = nn.Linear(
            config.hidden_size, config.classification_out_size
        )
        self.eos_token_id = config.eos_token_id

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> LogitProcessorOutput:
        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
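        # Pool one hidden state per sequence by selecting the positions of the
        # EOS token, then project the pooled states to class logits.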
        is_eos_token = input_ids == self.eos_token_id
        hidden_states = hidden_states[is_eos_token]
        scores = self.classification_head(hidden_states)

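        # Fallback: if any sequence lacks an EOS token, the pooled batch size
        # no longer matches, so return uniform dummy scores instead.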
        if scores.shape[0] != input_metadata.batch_size:
            print("Warning: the EOS tokens are missing in some sentences.")
            scores = torch.ones(
                (input_metadata.batch_size, self.config.classification_out_size)
            ).to(input_ids.device)

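        # Pack the classification scores into the standard logits-processor
        # output structure; token-level logprob fields are filled with dummies.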
        return LogitProcessorOutput(
            next_token_logits=scores,
            next_token_logprobs=scores,
            normalized_prompt_logprobs=scores,
            input_token_logprobs=torch.ones_like(input_ids),
            input_top_logprobs=None,
            output_top_logprobs=None,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
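        # Show a progress bar only on tensor-parallel rank 0; the total is an
        # estimate since fused parameters map to multiple checkpoint tensors.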
        if get_tensor_model_parallel_rank() == 0:
            weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
        for name, loaded_weight in weights:
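            # Skip tensors with no counterpart in the classification model
            # (rotary embedding caches, projector weights, the original LM head).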
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if "lm_head" in name:
                continue

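            # Fused parameters (qkv_proj, gate_up_proj) are loaded shard by
            # shard via the mapping above; all other weights fall through to
            # the default loader in the else branch.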
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name.startswith("model.vision_tower") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name.startswith("model.vision_tower") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


EntryClass = LlamaForClassification
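
# Minimal, self-contained sketch (illustrative only, not used by the SGLang
# runtime; it needs only the torch imports above): demonstrates the EOS-token
# pooling performed in forward(). Token ids, hidden size, and class count are
# made-up values for demonstration.
if __name__ == "__main__":
    eos_token_id = 2
    hidden_size, num_classes = 8, 3
    head = nn.Linear(hidden_size, num_classes)

    # Two sequences packed into one flat token stream, each ending with EOS.
    input_ids = torch.tensor([5, 6, eos_token_id, 7, 8, 9, eos_token_id])
    hidden_states = torch.randn(input_ids.shape[0], hidden_size)

    # Select one hidden state per sequence (at its EOS position) and classify.
    pooled = hidden_states[input_ids == eos_token_id]
    scores = head(pooled)
    print(scores.shape)  # torch.Size([2, 3])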