"docs/vscode:/vscode.git/clone" did not exist on "c9745ee082d94fb3324891f32f9c81ce0ca51cf5"
flash_mistral.py 22.1 KB
Newer Older
1
2
3
4
5
6
7
8
import math
import torch
import torch.distributed

import numpy as np

from dataclasses import dataclass
from opentelemetry import trace
9
from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig
OlivierDehaene's avatar
OlivierDehaene committed
10
from typing import Optional, Tuple, Type
11
12
13
14
15
16
17
18
19
20
21

from text_generation_server.pb import generate_pb2
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE
from text_generation_server.models.cache_manager import (
    get_cache_manager,
)
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
    FlashMistralForCausalLM,
    MistralConfig,
)
Nicolas Patry's avatar
Nicolas Patry committed
22
from text_generation_server.utils.speculate import get_speculate
23
24
25
26
27
28
29
30
31
32
33
34
35
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
    HeterogeneousNextTokenChooser,
    StoppingCriteria,
)

tracer = trace.get_tracer(__name__)

# Will be set in init
SLIDING_WINDOW: Optional[int] = None
SLIDING_WINDOW_BLOCKS: Optional[int] = None
Nicolas Patry's avatar
Nicolas Patry committed
36
from text_generation_server.utils.import_utils import SYSTEM
37

38
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
39

40

OlivierDehaene's avatar
OlivierDehaene committed
41
42
43
44
45
46
47
48
49
50
51
52
53
def set_sliding_window(sliding_window: int, sliding_window_blocks: int):
    """Record the sliding-window configuration in module-level state.

    Args:
        sliding_window: value stored in ``SLIDING_WINDOW``.
        sliding_window_blocks: value stored in ``SLIDING_WINDOW_BLOCKS``.
    """
    global SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS
    SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS = sliding_window, sliding_window_blocks


def get_sliding_windows() -> Tuple[Optional[int], Optional[int]]:
    """Return the module-level ``(SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS)`` pair.

    Both entries are ``None`` until ``set_sliding_window`` has been called,
    so callers must check for ``None`` before using them.
    """
    # Reading module globals needs no `global` declaration.
    return SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS


54
55
56
57
58
59
60
61
62
# Adds windowing logic to FlashCausalLMBatch
@dataclass
class FlashMistralBatch(FlashCausalLMBatch):
    """`FlashCausalLMBatch` that also carries sliding-window prefill indices."""

    # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
    # as we only keep SLIDING_WINDOW values instead of the whole tensor
    prefill_cache_indices: Optional[torch.Tensor] = None

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashMistralBatch":
        """Tokenize the requests of `pb` and build a batch from them.

        Delegates tokenization to `cls.batch_tokenized_inputs` and the actual
        batch construction to `cls.from_tokenized`.
        """
        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)

    @classmethod
    def from_tokenized(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        batch_tokenized_inputs,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashMistralBatch":
        """Build a prefill batch from already tokenized requests.

        Args:
            pb: protobuf batch; `pb.requests` is iterated in lockstep with
                `batch_tokenized_inputs`.
            tokenizer: used for its `bos_token_id` and to build stopping criteria.
            batch_tokenized_inputs: one sequence of token ids per request.
            dtype: forwarded to `HeterogeneousNextTokenChooser.from_pb`.
            device: device on which the batch tensors are created.

        Returns:
            A new batch whose `prefill_cache_indices` is a tensor when a sliding
            window is configured and `None` otherwise.
        """
        sliding_window, sliding_window_blocks = get_sliding_windows()
        # Does not depend on the request, so it is read once for the whole batch.
        speculative_length = get_speculate()

        position_ids = []
        cu_seqlen_prefill = [0]
        needed_blocks_slots = []
        start_slots = []
        slot_indices = []
        prefill_cache_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]
            # Drop a duplicated leading BOS token. The length guard keeps a
            # single-token prompt from raising IndexError on index 1.
            if (
                len(tokenized_input) > 1
                and tokenized_input[0] == tokenizer.bos_token_id
                and tokenized_input[1] == tokenizer.bos_token_id
            ):
                tokenized_input = tokenized_input[1:]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token does not have a past
            total_tokens = input_length + max_new_tokens - 1 + speculative_length

            # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            if sliding_window_blocks is not None:
                needed_blocks = min(needed_blocks, sliding_window_blocks)
            blocks += needed_blocks

            needed_blocks_slots.append((needed_blocks, total_tokens))
            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            # Create tensor to slice into the kv tensor in prefill
            if sliding_window is not None:
                # Only the last `sliding_window` positions of this request are kept.
                request_prefill_cache_indices = torch.arange(
                    cumulative_length + max(0, input_length - sliding_window),
                    cumulative_length + input_length,
                    dtype=torch.int64,
                )
                prefill_cache_indices.append(request_prefill_cache_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(
                max_length, input_length + max_new_tokens + speculative_length
            )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device, tokenizer
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
            if sliding_window is not None:
                prefill_cache_indices = torch.cat(prefill_cache_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]
            if sliding_window is not None:
                prefill_cache_indices = prefill_cache_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )

        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        prefill_cache_indices = (
            prefill_cache_indices.to(device) if sliding_window is not None else None
        )
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            # `.to` converts the concatenated tensor directly; wrapping an
            # existing tensor in `torch.tensor(...)` triggers a copy-construct warning.
            prefill_head_indices = torch.cat(prefill_head_indices).to(
                dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=needed_blocks_slots,
            block_tables=None,
            block_tables_tensor=None,
            slots=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            prefill_cache_indices=prefill_cache_indices,
            speculative_ids=None,
        )


OlivierDehaene's avatar
OlivierDehaene committed
308
class BaseFlashMistral(FlashCausalLM):
    """Shared loader and forward logic for Mistral-style flash-attention models."""

    def __init__(
        self,
        model_cls,
        model_id: str,
        config_cls=AutoConfig,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        tokenizer_class=AutoTokenizer,
    ):
        """Load tokenizer, config and weights, then build the model.

        Args:
            model_cls: callable invoked as `model_cls(prefix, config, weights)`.
            model_id: identifier passed to the tokenizer, config and weight loaders.
            config_cls: class whose `from_pretrained` yields the model config.
            revision: optional revision forwarded to every loader.
            quantize: stored on the config as `config.quantize`.
            speculator: stored on the config as `config.speculator`.
            dtype: parameter dtype; defaults to `torch.float16` when `None`.
            trust_remote_code: forwarded to the tokenizer and config loaders.
            tokenizer_class: class whose `from_pretrained` yields the tokenizer.

        Raises:
            NotImplementedError: when neither CUDA nor the "xpu" system is available.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        elif SYSTEM == "xpu":
            device = torch.device(f"xpu:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashMistral is only available on GPU")

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = config_cls.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        # Set context windows
        if getattr(config, "sliding_window", None) is not None:
            set_sliding_window(
                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
            )
        else:
            # Make sure the attribute exists so it can be forwarded below.
            config.sliding_window = None

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id, revision)

        prefix = ""
        model = model_cls(prefix, config, weights)

        # Captured CUDA graphs, keyed by (padded) batch size.
        self.cuda_graphs = {}

        torch.distributed.barrier(group=self.process_group)
        num_layers, num_kv_heads, head_size = self.get_layer_config(model)
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=num_layers,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
        )

    def get_layer_config(self, model) -> Tuple[int, int, int]:
        """Return `(num_layers, num_key_value_heads, head_size)` read from `model.model`."""
        return (
            len(model.model.layers),
            model.model.num_key_value_heads,
            model.model.head_size,
        )

    def max_past(self) -> Optional[int]:
        """Return `self.model.max_past`; `forward` treats `None` as "no limit"."""
        return self.model.max_past

    @property
    def batch_type(self) -> Type[FlashMistralBatch]:
        """Batch class used with this model."""
        return FlashMistralBatch

    def tunableop_warmup(self, seqlen: int):
        """Run one prefill-style forward pass over `seqlen` dummy tokens."""
        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
        slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
        kv_cache = get_cache_manager().kv_cache

        # Dummy value, some models (starcoder2) don't accept `None`.
        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)

        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=torch.tensor(
                [0, seqlen], device=self.device, dtype=torch.int32
            ),
            kv_cache=kv_cache,
            block_tables=None,
            input_lengths=input_lengths,
            slots=slots,
            max_s=seqlen,
            lm_head_indices=None,
            prefill_cache_indices=None,
        )

    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
        """Capture a decode CUDA graph for batch size `bs` into `self.cuda_graphs[bs]`.

        Args:
            bs: batch size of the static input buffers.
            max_s: value used for every entry of `input_lengths` and as `max_s`.
            max_bt: number of block-table columns per sequence.
        """
        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
        block_tables = (
            torch.arange(max_bt, dtype=torch.int32, device=self.device)
            .repeat(bs)
            .reshape((bs, max_bt))
        )
        kv_cache = get_cache_manager().kv_cache

        # Static buffers: `forward` copies real inputs into these before replay.
        self.cuda_graphs[bs] = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "kv_cache": kv_cache,
            "block_tables": block_tables,
            "slots": slots,
            "input_lengths": input_lengths,
        }
        graph = torch.cuda.CUDAGraph()
        self.cuda_graphs[bs]["graph"] = graph

        # The warmup pass and the captured pass use exactly the same arguments.
        forward_kwargs = dict(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            prefill_cache_indices=None,
            lm_head_indices=None,
        )

        torch.cuda.synchronize()
        # Run once outside to warmup
        self.model.forward(**forward_kwargs)
        torch.cuda.synchronize()

        with torch.cuda.graph(graph, pool=MEM_POOL):
            logits, speculative_logits = self.model.forward(**forward_kwargs)
            self.cuda_graphs[bs]["logits"] = logits
            self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
        torch.cuda.synchronize()

    def forward(
        self, batch: FlashMistralBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the model on `batch`, replaying a captured CUDA graph when possible.

        Prefill batches (`cu_seqlen_prefill` set), and batch sizes with no
        captured graph, go through a regular `self.model.forward` call. After
        that call `batch.prefill_cache_indices` is reset to `None`.

        Returns:
            `(logits, speculative_logits)`; the second item may be `None`.
        """
        # Model Forward
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = get_cache_manager().kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        max_s = batch.max_seqlen
        lm_head_indices = batch.prefill_head_indices

        if batch.speculative_ids is not None:
            # Expand every sequence into `speculative_length + 1` rows: the
            # current token followed by each speculated token.
            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids

        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)

        # Round the batch size up to the key a captured graph would be stored under.
        bs = input_ids.shape[0]
        padded_bs = bs
        if bs == 3:
            padded_bs = 4
        elif 3 < bs <= 8:
            padded_bs = 8
        elif bs > 8:
            padded_bs = (bs + 7) // 8 * 8

        # Try to find an associated cuda graph
        cuda_graph = self.cuda_graphs.get(padded_bs, None)

        if cu_seqlen_prefill is not None or cuda_graph is None:
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                prefill_cache_indices=batch.prefill_cache_indices,
                lm_head_indices=lm_head_indices,
            )
            # The indices are only needed for this call; drop them afterwards.
            if batch.prefill_cache_indices is not None:
                batch.prefill_cache_indices = None
            return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][
            : block_tables.shape[0], : block_tables.shape[1]
        ] = block_tables
        # Padding rows get slot -1 and input length 0.
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        # Replay the graph
        cuda_graph["graph"].replay()

        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits
OlivierDehaene's avatar
OlivierDehaene committed
587
588
589
590


class FlashMistral(BaseFlashMistral):
    """`BaseFlashMistral` bound to `MistralConfig` and `FlashMistralForCausalLM`."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Forward every argument to the base initializer with the Mistral classes."""
        super().__init__(
            model_cls=FlashMistralForCausalLM,
            config_cls=MistralConfig,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )