import torch
import torch.distributed

from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, Tuple, Dict, List

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.flash_causal_lm import set_sliding_window
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
    FlashMistralForCausalLM,
    MistralConfig,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)
from text_generation_server.utils.import_utils import SYSTEM

tracer = trace.get_tracer(__name__)


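# Projection layers that LoRA adapters can target, and the subset whose base weights
# are row-sharded under tensor parallelism (see is_row_parallel below), which affects
# how adapter weights are split across ranks.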
ADAPTER_LAYERS = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}


class BaseFlashMistral(FlashCausalLM):
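    """Flash-attention causal LM for Mistral-style architectures.

    Handles distributed setup, device and dtype selection, weight loading,
    sliding-window configuration and LoRA adapter layer mapping. Concrete
    subclasses such as FlashMistral below only supply the config and model
    classes.
    """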
    def __init__(
        self,
        model_cls,
        model_id: str,
        config_cls=AutoConfig,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        tokenizer_class=AutoTokenizer,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
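        # Select the execution device and default dtype: CUDA when available,
        # otherwise an Intel XPU (or CPU) through IPEX; other backends are unsupported.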
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        elif SYSTEM == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                device = torch.device(f"xpu:{rank}")
                dtype = torch.float16 if dtype is None else dtype
            else:
                device = torch.device("cpu")
                dtype = torch.bfloat16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashMistral is only available on GPU")

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = config_cls.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        # Set context windows
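        # Mistral-style configs expose sliding_window; register it globally so the
        # attention kernels and KV-cache handling can respect the window size.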
        if getattr(config, "sliding_window", None) is not None:
            set_sliding_window(config.sliding_window)
        else:
            config.sliding_window = None

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
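        # Quantized (GPTQ/AWQ/Marlin) checkpoints carry extra metadata such as bits
        # and group size that the weight loader needs before layers are built.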
        if config.quantize in ["gptq", "awq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        prefix = ""
        model = model_cls(prefix, config, weights)

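        # CUDA graphs captured during warmup, keyed by batch size.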
        self.cuda_graphs = {}

        torch.distributed.barrier(group=self.process_group)
        num_layers, num_kv_heads, head_size = self.get_layer_config(model)
        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            num_layers=num_layers,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
        )

    def get_layer_config(self, model) -> Tuple[int, int, int]:
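        """Return (num_layers, num_kv_heads, head_size) used to size the KV cache."""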
        return (
            len(model.model.layers),
            model.model.num_key_value_heads,
            model.model.head_size,
        )

    @property
    def supports_adapter_loading(self) -> bool:
        return True

    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
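        """Map each (layer index, adapter target) pair to the weight prefix and
        the module that the adapter weights should be applied to."""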
        layer_weights = {}

        prefix = "model.layers"

        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
        # that have a language_model inside of the larger model.
        if hasattr(self.model, "language_model"):
            _model = self.model.language_model
        elif hasattr(self.model, "text_model"):
            _model = self.model.text_model
        else:
            _model = self.model

        for i, layer in enumerate(_model.model.layers):
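            # q_proj, k_proj and v_proj all resolve to the fused query_key_value
            # projection used by the flash-attention implementation.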
            layer_weights[(i, "q_proj")] = (
                f"{prefix}.{i}.self_attn.q_proj",
                layer.self_attn.query_key_value,
            )
            layer_weights[(i, "k_proj")] = (
                f"{prefix}.{i}.self_attn.k_proj",
                layer.self_attn.query_key_value,
            )
            layer_weights[(i, "v_proj")] = (
                f"{prefix}.{i}.self_attn.v_proj",
                layer.self_attn.query_key_value,
            )
            layer_weights[(i, "o_proj")] = (
                f"{prefix}.{i}.self_attn.o_proj",
                layer.self_attn.o_proj,
            )

            # TODO: this is a hack to avoid the gate_proj for
            # FlashStarcoder2, which doesn't have these layers
            if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
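                # gate_proj and up_proj are fused into a single gate_up_proj module,
                # mirroring the fused q/k/v handling above.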
                layer_weights[(i, "gate_proj")] = (
                    f"{prefix}.{i}.mlp.gate_proj",
                    layer.mlp.gate_up_proj,
                )
                layer_weights[(i, "up_proj")] = (
                    f"{prefix}.{i}.mlp.up_proj",
                    layer.mlp.gate_up_proj,
                )
                layer_weights[(i, "down_proj")] = (
                    f"{prefix}.{i}.mlp.down_proj",
                    layer.mlp.down_proj,
                )

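        # lm_head is shared across the model, so register it once under layer 0.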
        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
        return layer_weights

    @property
    def adapter_layers(self) -> List[str]:
        return ADAPTER_LAYERS

    @property
    def default_traced_adapter_layers(self) -> List[str]:
        return ["q_proj", "v_proj"]

    def get_num_layers_for_type(self, layer_type: str) -> int:
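        # lm_head exists once; every other adapter target exists once per decoder layer.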
        return 1 if layer_type == "lm_head" else len(self.model.model.layers)

    def is_row_parallel(self, layer_type: str) -> bool:
        return layer_type in ROW_PARALLEL


class FlashMistral(BaseFlashMistral):
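    """BaseFlashMistral specialised with MistralConfig and FlashMistralForCausalLM."""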
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        super(FlashMistral, self).__init__(
            config_cls=MistralConfig,
            model_cls=FlashMistralForCausalLM,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )