flash_mistral.py
import math
import torch
import torch.distributed

import numpy as np

from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig
from typing import Optional, Tuple, Type

from text_generation_server.pb import generate_pb2
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.flash_causal_lm import (
    FlashCausalLMBatch,
    BLOCK_SIZE,
)
from text_generation_server.models.cache_manager import (
    get_cache_manager,
)
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
    FlashMistralForCausalLM,
    MistralConfig,
)
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
    HeterogeneousNextTokenChooser,
    StoppingCriteria,
)

tracer = trace.get_tracer(__name__)

# Will be set in init
SLIDING_WINDOW: Optional[int] = None
SLIDING_WINDOW_BLOCKS: Optional[int] = None

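# Memory pool shared by all captured CUDA graphs (passed to torch.cuda.graph during capture)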
MEM_POOL = torch.cuda.graph_pool_handle()


def set_sliding_window(sliding_window: int, sliding_window_blocks: int):
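    """Store the model's sliding window size, in tokens and in KV-cache blocks, as module globals."""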
    global SLIDING_WINDOW
    global SLIDING_WINDOW_BLOCKS
    SLIDING_WINDOW = sliding_window
    SLIDING_WINDOW_BLOCKS = sliding_window_blocks


def get_sliding_windows() -> Tuple[int, int]:
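    """Return the (sliding_window, sliding_window_blocks) pair set during model initialization."""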
    global SLIDING_WINDOW
    global SLIDING_WINDOW_BLOCKS
    return SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS


# Adds windowing logic to FlashCausalLMBatch
@dataclass
class FlashMistralBatch(FlashCausalLMBatch):
    # Prefill cache indices are used to slice into the kv tensor before caching it into the paged
    # attention buffers, as we only keep SLIDING_WINDOW values instead of the whole tensor
    prefill_cache_indices: Optional[torch.Tensor] = None

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)

    @classmethod
    def from_tokenized(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        batch_tokenized_inputs,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        sliding_window, sliding_window_blocks = get_sliding_windows()

        position_ids = []
        cu_seqlen_prefill = [0]
        needed_blocks_slots = []
        start_slots = []
        slot_indices = []
        prefill_cache_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token does not have a past
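            # get_speculate() extra slots are reserved so speculative tokens can be cached too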
            speculative_length = get_speculate()
            total_tokens = input_length + max_new_tokens - 1 + speculative_length

            # Needed blocks cannot go over SLIDING_WINDOW_BLOCKS
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            if sliding_window_blocks is not None:
                needed_blocks = min(needed_blocks, sliding_window_blocks)
            blocks += needed_blocks

            needed_blocks_slots.append((needed_blocks, total_tokens))
            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            # Create tensor to slice into the kv tensor in prefill
            if sliding_window is not None:
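                # Only the last sliding_window positions of the prompt are written to the paged KV cache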
                request_prefill_cache_indices = torch.arange(
                    cumulative_length + max(0, input_length - sliding_window),
                    cumulative_length + input_length,
                    dtype=torch.int64,
                )
                prefill_cache_indices.append(request_prefill_cache_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update running totals before processing the next request
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(
                max_length, input_length + max_new_tokens + speculative_length
            )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device, tokenizer
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
            if sliding_window is not None:
                prefill_cache_indices = torch.cat(prefill_cache_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]
            if sliding_window is not None:
                prefill_cache_indices = prefill_cache_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )

        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        prefill_cache_indices = (
            prefill_cache_indices.to(device) if sliding_window is not None else None
        )
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            prefill_head_indices = torch.tensor(
                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=needed_blocks_slots,
            block_tables=None,
            block_tables_tensor=None,
            slots=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            prefill_cache_indices=prefill_cache_indices,
            speculative_ids=None,
        )


class BaseFlashMistral(FlashCausalLM):
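    """Shared base for Mistral-style flash models: sliding-window KV cache and CUDA-graph decode path."""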
    def __init__(
        self,
        model_cls,
        model_id: str,
        config_cls=AutoConfig,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        use_medusa: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        tokenizer_class=AutoTokenizer,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashMistral is only available on GPU")

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = config_cls.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.use_medusa = use_medusa

        # Set context windows
        if getattr(config, "sliding_window", None) is not None:
            set_sliding_window(
                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
            )
        else:
            config.sliding_window = None

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id, revision)

        prefix = ""
        model = model_cls(prefix, config, weights)

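        # One captured CUDA graph (and its static input tensors) per padded decode batch size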
        self.cuda_graphs = {}

        torch.distributed.barrier(group=self.process_group)
        num_layers, num_kv_heads, head_size = self.get_layer_config(model)
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=num_layers,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
        )

    def get_layer_config(self, model) -> Tuple[int, int, int]:
        return (
            len(model.model.layers),
            model.model.num_key_value_heads,
            model.model.head_size,
        )

    def max_past(self) -> int:
        return self.model.max_past

    @property
    def batch_type(self) -> Type[FlashMistralBatch]:
        return FlashMistralBatch

    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
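        """Capture a CUDA graph of the decode forward pass for batch size bs, using static placeholder tensors."""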
        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
        block_tables = (
            torch.arange(max_bt, dtype=torch.int32, device=self.device)
            .repeat(bs)
            .reshape((bs, max_bt))
        )
        kv_cache = get_cache_manager().kv_cache

        self.cuda_graphs[bs] = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "kv_cache": kv_cache,
            "block_tables": block_tables,
            "slots": slots,
            "input_lengths": input_lengths,
        }
        graph = torch.cuda.CUDAGraph()
        self.cuda_graphs[bs]["graph"] = graph

        torch.cuda.synchronize()
        # Run once outside to warmup
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            prefill_cache_indices=None,
            lm_head_indices=None,
        )
        torch.cuda.synchronize()

        with torch.cuda.graph(graph, pool=MEM_POOL):
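            # Capture the decode forward pass; forward() later overwrites these static tensors and replays the graph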
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=None,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                prefill_cache_indices=None,
                lm_head_indices=None,
            )
            self.cuda_graphs[bs]["logits"] = logits
            self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
        torch.cuda.synchronize()

    def forward(
        self, batch: FlashMistralBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Model Forward
        if batch.speculative_ids is not None:
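            # Speculative decoding: verify the last token plus speculative_length draft tokens in one
            # forward pass, so every per-token tensor is expanded to B * (speculative_length + 1) rows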
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = get_cache_manager().kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Copy each request's block table for every speculated position
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids
        else:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = get_cache_manager().kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)

        bs = input_ids.shape[0]
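        # Round the batch size up to a size a CUDA graph may have been captured for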
        padded_bs = bs
        if bs == 3:
            padded_bs = 4
        elif 3 < bs <= 8:
            padded_bs = 8
        elif bs > 8:
            padded_bs = (bs + 7) // 8 * 8

        # Try to find a CUDA graph captured for this padded batch size
        cuda_graph = self.cuda_graphs.get(padded_bs, None)

        if cu_seqlen_prefill is not None or cuda_graph is None:
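            # Prefill, or no captured graph for this padded batch size: run the model eagerly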
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                prefill_cache_indices=batch.prefill_cache_indices,
                lm_head_indices=lm_head_indices,
            )
            if batch.prefill_cache_indices is not None:
                batch.prefill_cache_indices = None
            return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][
            : block_tables.shape[0], : block_tables.shape[1]
        ] = block_tables
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        # Replay the graph
        cuda_graph["graph"].replay()

        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits


class FlashMistral(BaseFlashMistral):
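    """Mistral served with flash attention and a paged, sliding-window KV cache."""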
    def __init__(
OlivierDehaene's avatar
OlivierDehaene committed
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        use_medusa: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        super(FlashMistral, self).__init__(
            config_cls=MistralConfig,
            model_cls=FlashMistralForCausalLM,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )