llama_embedding.py
from typing import Iterable, Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig

from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.model_executor.model_runner import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaModel
from sglang.srt.utils import add_prefix


class LlamaEmbeddingModel(nn.Module):
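    """Llama backbone wrapped with a pooling head for embedding requests.

    The causal-LM head is dropped; the transformer runs as usual and the last
    token's hidden state is pooled into a normalized embedding vector.
    """
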
    def __init__(
        self,
        config: LlamaConfig,
        quant_config=None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.model = LlamaModel(
            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
        )
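        # Embedding pooler: take the hidden state of the last token in each
        # sequence and normalize it into the final embedding vector.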
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
        get_embedding: bool = True,
    ) -> EmbeddingPoolerOutput:
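        """Run the transformer stack and return pooled, normalized embeddings."""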
        assert (
            get_embedding
        ), "LlamaEmbeddingModel / MistralModel is only used for embedding"
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        return self.pooler(hidden_states, forward_batch)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
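        """Load checkpoint tensors into the backbone.

        Separate q/k/v and gate/up shards from the checkpoint are fused into
        the stacked qkv_proj and gate_up_proj parameters used at runtime.
        """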
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.model.named_parameters())

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name or "projector" in name:
                return
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                return
            if name.startswith("model.vision_tower") and name not in params_dict:
                continue

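            # Try to map this tensor onto one of the fused parameters
            # (q/k/v -> qkv_proj, gate/up -> gate_up_proj).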
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
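                # No fused mapping matched; load the tensor into the
                # parameter of the same name.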
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


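# Mistral shares the Llama architecture, so the embedding wrapper is reused
# unchanged. EntryClass is what sglang's model loader scans to register the
# architectures implemented in this module.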
class MistralModel(LlamaEmbeddingModel):
    pass


EntryClass = [LlamaEmbeddingModel, MistralModel]