"test/torchscript_consistency_impl.py" did not exist on "a9c4d0a8b0cd94a3a812786d5354314a6c081700"
flash_causal_lm.py 30.4 KB
Newer Older
1
2
3
import torch
import torch.distributed

import numpy as np

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
from typing import Optional, Tuple, List, Type, Union, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser

tracer = trace.get_tracer(__name__)


@dataclass
class FlashCausalLMBatch(Batch):
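    """State of one batch for flash-attention models.

    Sequences are not padded: input_ids / position_ids are flattened into a single
    1D tensor and per-sequence boundaries are carried in start_seq / end_seq (and
    their *_prefill / *_q variants). The key/value cache is one contiguous,
    pre-allocated buffer with room for each request's prompt plus the tokens it may
    still generate; past_present_indices says where each newly computed KV entry
    must be written inside that buffer.
    """
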
    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor

    # Indices to copy present to the correct indices in the pre-allocated past key values
    past_present_indices: torch.Tensor

    # tensor of length b holding starting offset of each sequence
    start_seq: torch.Tensor
    # tensor of length b holding ending offset of each sequence
    end_seq: torch.Tensor
    # tensor of length b holding starting offset of each sequence, only used in prefill
    start_seq_prefill: Optional[torch.Tensor]
    # tensor of length b holding ending offset of each sequence, only used in prefill
    end_seq_prefill: Optional[torch.Tensor]
    # tensor of length b holding starting offset of each query sequence, only used in decode
    start_seq_q: Optional[torch.Tensor]
    # tensor of length b holding ending offset of each query sequence, only used in decode
    end_seq_q: Optional[torch.Tensor]

    # past key values, only used in decode
    past_key_values: Optional[torch.Tensor]
    max_seqlen: int

    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    prefill_next_token_indices: Optional[torch.Tensor]
    prefill_cu_outlens: Optional[List[int]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        batch_inputs = []
        max_truncation = 0
        for r in pb.requests:
            batch_inputs.append(r.inputs)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]

        position_ids = []
        past_present_indices = []
        start_seq = []
        end_seq = []
        start_seq_prefill = []
        end_seq_prefill = []
        max_seqlen = 0

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        max_length = 0

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]

            input_length = len(tokenized_input)
            max_seqlen = max(max_seqlen, input_length)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            start_seq_prefill.append(cumulative_length)
            end_seq_prefill.append(cumulative_length + input_length)
            start_seq.append(cumulative_max_length)
            end_seq.append(cumulative_max_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            request_past_present_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            past_present_indices.append(request_past_present_indices)

            # Update
            # Remove one as the first token does not have a past
            cumulative_length += input_length
            cumulative_max_length += input_length + max_new_tokens - 1
            max_length = max(max_length, input_length + max_new_tokens)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device
        )

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )
        start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
        end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)

            past_present_indices = np.concatenate(past_present_indices, dtype=np.int64)

            start_seq_prefill = torch.tensor(
                start_seq_prefill, device=device, dtype=torch.int32
            )
            end_seq_prefill = torch.tensor(
                end_seq_prefill, device=device, dtype=torch.int32
            )
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]

            past_present_indices = past_present_indices[0]

            start_seq_prefill = start_seq
            end_seq_prefill = end_seq

        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
        past_present_indices = torch.tensor(
            past_present_indices, device=device, dtype=torch.int64
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = end_seq_prefill - 1
        elif no_prefill_logprobs:
            prefill_head_indices = end_seq_prefill - 1
            prefill_next_token_indices = None
        else:
            prefill_head_indices = torch.tensor(
                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            past_present_indices=past_present_indices,
            start_seq=start_seq,
            end_seq=end_seq,
            start_seq_prefill=start_seq_prefill,
            end_seq_prefill=end_seq_prefill,
            start_seq_q=None,
            end_seq_q=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            past_key_values=None,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            max_tokens=cumulative_max_length,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
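        """Drop finished requests, keeping only `request_ids`.

        Per-request Python lists are rebuilt, GPU tensors are gathered with `indices`,
        and a boolean mask over the pre-allocated KV buffer (`past_indices`) keeps only
        the cache slots of the remaining requests.
        """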
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.input_ids.device

        # Cumulative length
        cumulative_max_length = 0

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # past indices to keep
        past_indices = torch.zeros(
            self.past_key_values.shape[0], dtype=torch.bool, device=device
        )

        # Create on CPU to only move to GPU once instead of at every copy
        start_seq = torch.empty(len(request_ids), dtype=torch.int32)
        end_seq = torch.empty(len(request_ids), dtype=torch.int32)
        start_seq_q = self.start_seq_q[: len(request_ids)]
        end_seq_q = self.end_seq_q[: len(request_ids)]
        max_seqlen = 0

        requests = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        stopping_criterias = []

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]
            max_seqlen = max(max_seqlen, request_input_length)

            all_input_ids.append(self.all_input_ids[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

            # Copy to tensor (CPU)
            start_seq[i] = cumulative_max_length
            end_seq[i] = cumulative_max_length + request_input_length

            # Set slice
            past_indices[
                self.start_seq[idx] : self.end_seq[idx] + remaining_tokens - 1
            ] = True

            cumulative_max_length += request_input_length + remaining_tokens - 1

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        past_key_values = self.past_key_values[past_indices]

        # Move to GPU now that we have the whole tensor
        start_seq = start_seq.to(device)
        end_seq = end_seq.to(device)
        past_present_indices = end_seq - 1

        return FlashCausalLMBatch(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            past_present_indices=past_present_indices,
            start_seq=start_seq,
            end_seq=end_seq,
            start_seq_prefill=None,
            end_seq_prefill=None,
            start_seq_q=start_seq_q,
            end_seq_q=end_seq_q,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            max_tokens=cumulative_max_length,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
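        """Merge already-prefilled batches into a single decode batch.

        Each batch's start_seq / end_seq are shifted by the cumulative `max_tokens` so
        they index into the concatenated KV buffer, past_key_values are concatenated
        along dim 0, and all_input_ids_tensor is re-padded to the new maximum length.
        """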
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        total_batch_size = sum([len(b) for b in batches])

        dtype = batches[0].past_key_values.dtype
        device = batches[0].input_ids.device

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        start_seq = batches[0].start_seq.new_empty(total_batch_size)
        end_seq = batches[0].end_seq.new_empty(total_batch_size)
        start_seq_q = torch.arange(
            0, total_batch_size, device=device, dtype=torch.int32
        )
        end_seq_q = start_seq_q + 1
        max_seqlen = 0
        past_key_values = []

        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        stopping_criterias = []

        # Cumulative length
        cumulative_batch_size = 0
        max_tokens = 0
        max_length = 0

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            start_seq[start_index:end_index] = batch.start_seq + max_tokens
            end_seq[start_index:end_index] = batch.end_seq + max_tokens

            max_seqlen = max(max_seqlen, batch.max_seqlen)

            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            stopping_criterias.extend(batch.stopping_criterias)
            past_key_values.append(batch.past_key_values)

            # Update
            cumulative_batch_size += len(batch)
            max_tokens += batch.max_tokens
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        batch.input_lengths, batch.stopping_criterias
                    )
                ),
            )

        past_key_values = torch.cat(past_key_values, dim=0)
        past_present_indices = end_seq - 1

        all_input_ids_tensor = torch.zeros(
            (total_batch_size, max_length), dtype=torch.int64, device=device
        )

        cumulative_batch_size = 0
        for i, batch in enumerate(batches):
            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            cumulative_batch_size += len(batch)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype=dtype, device=device
        )

        return FlashCausalLMBatch(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            past_present_indices=past_present_indices,
            start_seq=start_seq,
            end_seq=end_seq,
            start_seq_prefill=None,
            end_seq_prefill=None,
            start_seq_q=start_seq_q,
            end_seq_q=end_seq_q,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    def __len__(self):
        return len(self.requests)


class FlashCausalLM(Model):
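    """Causal LM served with flash-attention custom modeling code.

    Rough usage sketch (hypothetical driver code: the real loop lives in the gRPC
    server / router, and the model class name below is only illustrative):

        model = FlashCausalLM(SomeFlashModelForCausalLM, "org/model-id")
        batch = FlashCausalLMBatch.from_pb(pb, model.tokenizer, model.dtype, model.device)
        generations, batch = model.generate_token(batch)       # prefill
        while batch is not None:
            generations, batch = model.generate_token(batch)   # decode, one token per step
    """
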
    def __init__(
        self,
        model_cls: Type[PreTrainedModel],
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16
        else:
            raise NotImplementedError("FlashCausalLM is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        model = model_cls.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        ).to(device)

        super(FlashCausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch

    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        start_seq: torch.Tensor,
        end_seq: torch.Tensor,
        start_seq_q: Optional[torch.Tensor],
        end_seq_q: Optional[torch.Tensor],
        max_s: int,
        past_present_indices: torch.Tensor,
        past_key_values: Optional[torch.Tensor] = None,
        pre_allocate_past_size: Optional[int] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
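        """Run the flash-attention model on a flattened batch.

        `pre_allocate_past_size` is only passed during prefill so the model can
        allocate the full KV buffer up front; `lm_head_indices` restricts the LM head
        to the positions whose logits are actually needed.
        """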
        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            start_seq=start_seq,
            end_seq=end_seq,
            start_seq_q=start_seq_q,
            end_seq_q=end_seq_q,
            max_s=max_s,
            past_present_indices=past_present_indices,
            past_key_values=past_key_values,
            pre_allocate_past_size=pre_allocate_past_size,
            lm_head_indices=lm_head_indices,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
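        """Run one model step over the whole batch.

        The first call (past_key_values is still None) is the prefill pass over the
        full prompts; every later call decodes exactly one token per sequence.
        GPU <-> CPU syncs (`.tolist()`) are delayed as long as possible.
        """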
        prefill = batch.past_key_values is None
        prefill_logprobs = batch.prefill_next_token_indices is not None

        if prefill:
            # Ask to pre-allocate kv to its max size
            # == Sum over batch size (number of tokens + max_new_tokens) - batch size
            pre_allocate_past_size = batch.max_tokens
            start_seq = batch.start_seq_prefill
            end_seq = batch.end_seq_prefill
        else:
            pre_allocate_past_size = None
            start_seq = batch.start_seq
            end_seq = batch.end_seq

        out, present = self.forward(
            batch.input_ids,
            batch.position_ids,
            start_seq,
            end_seq,
            batch.start_seq_q,
            batch.end_seq_q,
            batch.max_seqlen,
            batch.past_present_indices,
            batch.past_key_values,
            pre_allocate_past_size,
            batch.prefill_head_indices,
        )

        if prefill:
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
        else:
            next_token_logits = out

        next_input_ids, next_token_logprobs = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
        )
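        # One chosen token per sequence plus the log-prob assigned to it; both stay on
        # the GPU until the `.tolist()` sync further down.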

        if prefill:
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            # Create batch.start_seq_q and batch.end_seq_q for decode
            batch.start_seq_q = torch.arange(
                0, len(batch), device=self.device, dtype=torch.int32
            )
            batch.end_seq_q = batch.start_seq_q + 1
            next_position_ids = batch.position_ids.new_empty(len(batch))
            # We do not need start_seq_prefill and end_seq_prefill anymore
            batch.start_seq_prefill = None
            batch.end_seq_prefill = None
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.input_lengths,
            batch.all_input_ids,
        )

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        for i, (
            input_length,
            all_input_ids,
        ) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if prefill_logprobs:
                    if len(batch) > 1:
                        prefill_tokens_indices[
                            out_start_index : out_end_index - 1
                        ] = batch.input_ids[start_index + 1 : start_index + out_length]
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1 : start_index + out_length
                        ]

            batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]

            cumulative_length += input_length

        # Set values in batch
        batch.input_ids = next_input_ids
        batch.position_ids = next_position_ids + 1
        batch.past_present_indices = batch.end_seq
        batch.end_seq = batch.end_seq + 1

        if prefill and prefill_logprobs:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = batch.input_ids.tolist()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.all_input_ids_tensor,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            next_token_ids,
            next_token_logprobs,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            all_input_ids_tensor,
            do_sample,
            seed,
            next_token_id,
            next_token_logprob,
        ) in enumerate(iterator):
            # Append next token to all tokens
            all_input_ids.append(next_token_id)

            # Generated token
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_input_ids,
                prefix_offset,
                read_offset,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id,
                next_token_text,
            )

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text = self.decode(
                        all_input_ids[-stopping_criteria.current_tokens :]
                    )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

                    # Remove generated token to only have prefill and add nan for first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        out_start_index : out_end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = PrefillTokens(
                        prefill_token_ids, request_prefill_logprobs, prefill_texts
                    )
                else:
                    prefill_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    next_token_id,
                    next_token_logprob,
                    next_token_text,
                    next_token_id in self.all_special_ids,
                    generated_text,
                )

                generations.append(generation)

            new_input_length = input_length + 1

            # Update values
            batch.input_lengths[i] = new_input_length
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None
        batch.max_seqlen = batch.max_seqlen + 1
        batch.past_key_values = present

        # No need to return a batch if we know that all requests stopped
        return generations, batch if not stopped else None