"tests/vscode:/vscode.git/clone" did not exist on "fa183ee8eabe115215fd5a18e6b6cea836261ab4"
flash_mistral.py 6.68 KB
Newer Older
1
2
3
4
import torch
import torch.distributed

from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, Tuple, Dict, List

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.flash_causal_lm import set_sliding_window
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
    FlashMistralForCausalLM,
    MistralConfig,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)
from text_generation_server.utils.import_utils import SYSTEM

tracer = trace.get_tracer(__name__)


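# Linear projections that LoRA adapters may target, and the layers whose
# weights are sharded row-wise under tensor parallelism (adapter weights for
# those must be split along the same dimension).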
ADAPTER_LAYERS = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}


class BaseFlashMistral(FlashCausalLM):
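    """Shared loading logic for flash-attention Mistral-style models.

    Loads the tokenizer, config and safetensors weights (tensor-parallel aware),
    configures sliding-window attention when the config defines one, and exposes
    the LoRA adapter layer mapping used by the adapter machinery.
    """
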
    def __init__(
        self,
        model_cls,
        model_id: str,
        config_cls=AutoConfig,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        tokenizer_class=AutoTokenizer,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
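        # Pick device and default dtype: CUDA uses float16 unless overridden;
        # the ipex path runs on Intel XPU (float16) or falls back to CPU (bfloat16).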
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        elif SYSTEM == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                device = torch.device(f"xpu:{rank}")
                dtype = torch.float16 if dtype is None else dtype
            else:
                device = torch.device("cpu")
                dtype = torch.bfloat16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashMistral is only available on GPU")

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = config_cls.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        # Set context windows
        if getattr(config, "sliding_window", None) is not None:
            set_sliding_window(config.sliding_window)
        else:
            config.sliding_window = None

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
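        # Pre-quantized checkpoints (GPTQ/AWQ/Marlin) ship quantization
        # parameters (e.g. bits, group size) that must be read from the repo.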
        if config.quantize in ["gptq", "awq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        prefix = ""
        model = model_cls(prefix, config, weights)

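        # Lazily-populated cache of captured CUDA graphs, keyed by batch size.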
        self.cuda_graphs = {}

        torch.distributed.barrier(group=self.process_group)
        num_layers, num_kv_heads, head_size = self.get_layer_config(model)
        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            num_layers=num_layers,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
        )

    def get_layer_config(self, model) -> Tuple[int, int, int]:
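        """Return (num_layers, num_kv_heads, head_size) for the loaded model."""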
        return (
            len(model.model.layers),
            model.model.num_key_value_heads,
            model.model.head_size,
        )

    @property
    def supports_adapter_loading(self) -> bool:
        return True

    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
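        """Map (layer index, adapter target name) to the weight name and target module."""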
        layer_weights = {}

        prefix = "model.layers"

        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
        # that have a language_model inside of the larger model.
        if hasattr(self.model, "language_model"):
            _model = self.model.language_model
        elif hasattr(self.model, "text_model"):
            _model = self.model.text_model
        else:
            _model = self.model

        for i, layer in enumerate(_model.model.layers):
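            # Note: q_proj, k_proj and v_proj all map to the same fused
            # query_key_value module.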
            layer_weights[(i, "q_proj")] = (
                f"{prefix}.{i}.self_attn.q_proj",
                layer.self_attn.query_key_value,
            )
            layer_weights[(i, "k_proj")] = (
                f"{prefix}.{i}.self_attn.k_proj",
                layer.self_attn.query_key_value,
            )
            layer_weights[(i, "v_proj")] = (
                f"{prefix}.{i}.self_attn.v_proj",
                layer.self_attn.query_key_value,
            )
            layer_weights[(i, "o_proj")] = (
                f"{prefix}.{i}.self_attn.o_proj",
                layer.self_attn.o_proj,
            )

            # TODO: this is a hack to avoid the gate_proj for
            # FlashStarcoder2, which doesn't have these layers
            if hasattr(layer.mlp, "gate_up_proj"):
                layer_weights[(i, "gate_proj")] = (
                    f"{prefix}.{i}.mlp.gate_proj",
                    layer.mlp.gate_up_proj,
                )
                layer_weights[(i, "up_proj")] = (
                    f"{prefix}.{i}.mlp.up_proj",
                    layer.mlp.gate_up_proj,
                )
                layer_weights[(i, "down_proj")] = (
                    f"{prefix}.{i}.mlp.down_proj",
                    layer.mlp.down_proj,
                )

        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
        return layer_weights

    @property
    def adapter_layers(self) -> List[str]:
        return ADAPTER_LAYERS

    @property
    def default_traced_adapter_layers(self) -> List[str]:
        return ["q_proj", "v_proj"]

    def get_num_layers_for_type(self, layer_type: str) -> int:
        return 1 if layer_type == "lm_head" else len(self.model.model.layers)

    def is_row_parallel(self, layer_type: str) -> bool:
        return layer_type in ROW_PARALLEL


class FlashMistral(BaseFlashMistral):
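    """Flash-attention Mistral, wired to MistralConfig and FlashMistralForCausalLM."""
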
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        super(FlashMistral, self).__init__(
            config_cls=MistralConfig,
            model_cls=FlashMistralForCausalLM,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
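

# Usage sketch (illustrative only, not executed by this module). Constructing the
# model requires torch.distributed to be initializable, e.g. RANK / WORLD_SIZE /
# MASTER_ADDR / MASTER_PORT set by the launcher:
#
#     model = FlashMistral(
#         model_id="mistralai/Mistral-7B-Instruct-v0.2",
#         quantize=None,
#         dtype=torch.float16,
#     )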