"vscode:/vscode.git/clone" did not exist on "84725ec7e3195e0d9ef2dece5ff4f8d8db5fb472"
causal_lm.py 29 KB
Newer Older
1
import torch
2
import time
3

4
from dataclasses import dataclass
5
from opentelemetry import trace
6
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
7
from typing import Optional, Tuple, List, Type, Dict
8

9
from text_generation_server.models import Model
Daniël de Kok's avatar
Daniël de Kok committed
10
from text_generation_server.utils.chunks import concat_text_chunks
11
from text_generation_server.utils.tokens import batch_top_tokens
12
13
from text_generation_server.models.types import (
    Batch,
Nicolas Patry's avatar
Nicolas Patry committed
14
    Tokens,
15
16
17
18
19
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
20

21
22
# Module-level OpenTelemetry tracer; the batch and generation methods below use it
# (via ``start_as_current_span``) to wrap their work in named spans.
tracer = trace.get_tracer(__name__)

23
24

@dataclass
class CausalLMBatch(Batch):
    """Padded batch state carried between decoding steps of a causal LM.

    The batch is built from a protobuf message (``from_pb``), shrunk when
    requests finish (``filter``) and merged with other batches
    (``concatenate``).
    """

    batch_id: int
    requests: List[generate_pb2.Request]
    # Maps a request id to the row index of that request inside the batch
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    # ``None`` until the batch went through a first forward pass
    past_key_values: Optional[List[Tuple]]

    # All tokens
    all_input_ids: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    # Per-request offsets handed to ``decode_token`` for incremental decoding
    prefix_offsets: List[int]
    read_offsets: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Metadata used for padding
    max_input_length: int
    # Number of attention-mask columns still reserved on the right for new tokens
    padding_right_offset: int

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    # Past metadata
    # When False, the sequence axis is the last axis of the cached keys (BLOOM layout)
    keys_head_dim_last: bool = True

60
61
    def to_pb(self) -> generate_pb2.CachedBatch:
        """Summarize this batch as a ``CachedBatch`` protobuf message.

        Returns:
            A ``CachedBatch`` carrying the batch id, the ids of the requests
            still in the batch, the batch size and its token budget.
        """
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "CausalLMBatch":
        """Build a batch from a protobuf ``Batch`` message.

        All prompts are tokenized in a single padded call, and the attention
        mask is allocated once with room for the longest prompt plus the
        largest ``max_new_tokens`` of any request.

        Args:
            pb: Protobuf batch holding the requests to run.
            tokenizer: Tokenizer used for the prompts and handed to the
                per-request chooser / stopping-criteria factories.
            dtype: Not used by this method; kept for interface compatibility.
            device: Device the tokenized tensors are moved to.

        Returns:
            A new batch with ``past_key_values`` set to ``None``.
        """
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            requests_idx_mapping[r.id] = i
            inputs.append(concat_text_chunks(r.input_chunks.chunks))

            next_token_choosers.append(
                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
            )
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        # Every request shares the same padded width, so the decoding offsets
        # are identical for all of them and only need to be computed once.
        input_len = tokenized_inputs["input_ids"].shape[1]
        prefix_offsets = [input_len - 5] * len(inputs)
        read_offsets = [input_len] * len(inputs)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        # Convert to a Python int right away so that everything derived from
        # it (notably ``max_tokens``) is a plain int rather than a 0-dim tensor.
        max_input_length = input_lengths.max().item()

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]

        # Positions count attended tokens only; padded slots get a dummy value of 1
        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        # One [padded_length, 1] column tensor per request
        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=list(all_input_ids),
            input_lengths=input_lengths.tolist(),
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

160
    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
        """Keep only the given requests, updating this batch in place.

        Args:
            request_ids: Ids of the requests to keep; rows are reordered to
                follow the order of this list.

        Returns:
            ``self``, after dropping every request not listed.

        Raises:
            ValueError: If ``request_ids`` is empty.
        """
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        if len(request_ids) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        requests = []
        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        max_input_length = 0

        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []

        total_remaining_decode_tokens = 0
        new_padding_right_offset = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            requests_idx_mapping[request_id] = i
            keep_indices.append(idx)

            requests.append(self.requests[idx])
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])
            all_input_ids.append(self.all_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(self.top_n_tokens[idx])
            remaining_decode_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )
            total_remaining_decode_tokens += remaining_decode_tokens
            new_padding_right_offset = max(
                new_padding_right_offset, remaining_decode_tokens
            )

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        input_ids = self.input_ids[keep_indices]
        position_ids = self.position_ids[keep_indices]
        # Keep the last ``max_input_length`` used columns plus only as many
        # reserved right-hand columns as the remaining requests still need.
        self.attention_mask = self.attention_mask[
            keep_indices,
            -(self.padding_right_offset + max_input_length) : (
                self.attention_mask.shape[1] - self.padding_right_offset
            )
            + new_padding_right_offset,
        ]

        # Ensure that past_key_values tensors can be updated in-place
        if isinstance(self.past_key_values[0], tuple):
            self.past_key_values = [list(layer) for layer in self.past_key_values]

        # Update tensors in-place to allow incremental garbage collection
        # NOTE: ``len(self)`` below must still be the pre-filter batch size, so
        # ``self.requests`` is only reassigned after this loop.
        past_kv_length = max_input_length - 1
        for layer in self.past_key_values:
            past_keys, past_values = layer
            if len(past_keys.shape) == 3:
                # Force past to be of dim [self_size, num_heads, ...] for easy indexing
                past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
                past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
            if self.keys_head_dim_last:
                layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
            else:
                layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
            del past_keys
            layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
            del past_values

        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
        max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens

        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        self.input_ids = input_ids
        self.position_ids = position_ids
        self.all_input_ids = all_input_ids
        self.input_lengths = input_lengths
        self.prefix_offsets = prefix_offsets
        self.read_offsets = read_offsets
        self.next_token_choosers = next_token_choosers
        self.stopping_criterias = stopping_criterias
        self.top_n_tokens = top_n_tokens
        self.top_n_tokens_tensor = top_n_tokens_tensor
        self.max_input_length = max_input_length
        self.padding_right_offset = new_padding_right_offset
        self.max_tokens = max_tokens

        return self
262

263
    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        """Merge several prefilled batches into a single batch.

        Each batch's attention mask and past key/values are copied into newly
        allocated, zero-initialized tensors sized for the longest sequence of
        any batch; shorter batches are aligned to the right of the sequence
        axis. The past key/value references of the input batches are cleared
        as they are copied, so the inputs must not be reused afterwards.

        Args:
            batches: Batches to merge; each must already hold
                ``past_key_values``.

        Returns:
            A new batch that reuses the id and ``keys_head_dim_last`` flag of
            the first input batch.

        Raises:
            ValueError: If a batch has no ``past_key_values``.
        """
        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []
        max_tokens = 0

        # Batch tensors
        input_ids = None
        attention_mask = None
        position_ids = None
        past_key_values = []
        top_n_tokens_tensor = None

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for batch in batches:
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)
            top_n_tokens.extend(batch.top_n_tokens)

            # We need to offset the mapping for each batch by the cumulative batch size.
            # Always writing into a fresh dict (the offset is 0 for the first batch)
            # avoids mutating the first batch's own mapping.
            for k, v in batch.requests_idx_mapping.items():
                requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # We only concatenate batches that did at least one step
            if batch.past_key_values is None:
                raise ValueError("only concatenate prefilled batches")

            # Create empty tensor
            # input_ids is always of shape [batch_size, 1]
            # We do not need to pad it
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
            # Copy to correct indices
            input_ids[start_index:end_index] = batch.input_ids

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length + padding_right_offset),
                )

            if top_n_tokens_tensor is None:
                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
                    total_batch_size,
                )
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor

            # We need to slice the attention mask to remove padding from previous steps
            # and to remove unused allocated space
            left_offset = max_input_length - batch.max_input_length
            batch_left_offset = (
                batch.attention_mask.shape[1]
                - batch.max_input_length
                - batch.padding_right_offset
            )
            attention_mask[
                start_index:end_index,
                left_offset:-padding_right_offset,
            ] = batch.attention_mask[
                :,
                batch_left_offset : -batch.padding_right_offset,
            ]

            # Create empty tensor
            # position_ids is always of shape [batch_size, 1]
            if position_ids is None:
                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
            position_ids[start_index:end_index] = batch.position_ids

            # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
            # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
            # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
            # And ensure that we can update tensors in-place
            if isinstance(batch.past_key_values[0], tuple):
                batch.past_key_values = [
                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
                    for layer in batch.past_key_values
                ]
            elif len(batch.past_key_values[0][0].shape) == 3:
                for layer in batch.past_key_values:
                    for k, t in enumerate(layer):
                        layer[k] = t.view(len(batch), -1, *t.shape[-2:])

            # Add eventual padding tokens that were added while concatenating
            max_tokens += batch.max_tokens + (
                max_input_length - batch.max_input_length
            ) * len(batch)

            start_index = end_index

        first_past_kvs = batches[0].past_key_values
        _, num_heads, _, head_dim = first_past_kvs[0][1].shape

        padded_past_values_shape = (
            total_batch_size,
            num_heads,
            max_input_length - 1,
            head_dim,
        )

        if batches[0].keys_head_dim_last:
            padded_past_keys_shape = padded_past_values_shape
        else:
            # seq_length is last for BLOOM
            padded_past_keys_shape = (
                total_batch_size,
                num_heads,
                head_dim,
                max_input_length - 1,
            )

        # Iterate over attention layers
        # Concatenate past key values layer by layer to allow incremental garbage collection
        for j in range(len(first_past_kvs)):
            padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
            start_index = 0
            for batch in batches:
                past_keys = batch.past_key_values[j][0]
                # Clear reference to the original tensor
                batch.past_key_values[j][0] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the keys to remove the padding from previous batches
                past_seq_len = batch.max_input_length - 1
                if batch.keys_head_dim_last:
                    padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
                        past_keys[:, :, -past_seq_len:, :]
                    )
                else:
                    # BLOOM case
                    padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
                        past_keys[:, :, :, -past_seq_len:]
                    )
                del past_keys

                start_index = end_index

            padded_past_values = first_past_kvs[j][1].new_zeros(
                padded_past_values_shape
            )
            start_index = 0
            for batch in batches:
                past_values = batch.past_key_values[j][1]
                # Clear reference to the original tensor
                batch.past_key_values[j][1] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the past values to remove the padding from previous batches
                past_seq_len = batch.max_input_length - 1
                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
                    past_values[:, :, -past_seq_len:, :]
                )
                del past_values

                # Update values
                start_index = end_index

            past_key_values.append([padded_past_keys, padded_past_values])

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            keys_head_dim_last=batches[0].keys_head_dim_last,
            max_tokens=max_tokens,
        )
476

477
478
479
    def __len__(self):
        """Return how many requests the batch currently holds."""
        request_count = len(self.requests)
        return request_count

480
481

class CausalLM(Model):
    """Causal language model loaded through ``AutoModelForCausalLM``.

    Batches handled by this model are padded (``requires_padding=True``).
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load the tokenizer and model, then initialise the base ``Model``.

        Args:
            model_id: Model identifier passed to ``from_pretrained``.
            revision: Optional revision passed to ``from_pretrained``.
            quantize: Quantization scheme; ``"bitsandbytes"`` enables 8-bit
                loading. Any value is rejected when CUDA is unavailable.
            speculator: Must be falsy; speculative decoding is not supported.
            dtype: Model dtype. Defaults to ``float16`` when CUDA is
                available and ``float32`` otherwise.
            trust_remote_code: Forwarded to ``from_pretrained``.

        Raises:
            RuntimeError: If ``speculator`` is set.
            ValueError: If ``quantize`` is set while CUDA is unavailable.
        """
        if speculator:
            raise RuntimeError("Speculator decoding is not enabled for AutoModel")

        cuda_available = torch.cuda.is_available()
        if cuda_available:
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        # Pad and truncate on the left so the most recent tokens stay at the
        # right edge of every row.
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map=(
                "auto" if cuda_available and torch.cuda.device_count() > 1 else None
            ),
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        )
        # With a single GPU no device_map is used, so move the model
        # explicitly (skipped for bitsandbytes-loaded models).
        if (
            cuda_available
            and torch.cuda.device_count() == 1
            and quantize != "bitsandbytes"
        ):
            model = model.cuda()

        # Make sure a padding token exists; fall back through the model
        # config and the tokenizer's EOS token before adding a new one.
        if tokenizer.pad_token_id is None:
            if model.config.pad_token_id is not None:
                tokenizer.pad_token_id = model.config.pad_token_id
            elif model.config.eos_token_id is not None:
                tokenizer.pad_token_id = model.config.eos_token_id
            elif tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        """Batch class used with this model (``CausalLMBatch``)."""
        batch_cls = CausalLMBatch
        return batch_cls
552

553
554
    def decode(self, generated_ids: List[int]) -> str:
        """Decode token ids to text with the model's tokenizer.

        Special tokens are skipped and tokenization spaces are left as-is.
        """
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

558
    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ) -> Tuple[
        torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]
    ]:
        """Run one forward pass of the underlying model with caching enabled.

        Args:
            input_ids: Token ids to feed to the model.
            attention_mask: Attention mask matching the attended sequence.
            position_ids: Only forwarded when ``self.has_position_ids`` is set.
            past_key_values: Cached key/values from a previous step, if any.

        Returns:
            ``(logits, speculative_logits, past_key_values)``.
            ``speculative_logits`` is ``None`` unless the model returned a
            ``(outputs, speculative_logits)`` tuple.
        """
        # Model Forward
        kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": True,
            "return_dict": True,
        }
        if self.has_position_ids:
            kwargs["position_ids"] = position_ids

        outputs = self.model.forward(**kwargs)
        if isinstance(outputs, tuple):
            outputs, speculative_logits = outputs
        else:
            speculative_logits = None
        return outputs.logits, speculative_logits, outputs.past_key_values
580

581
    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]:
        """Run one decoding step for every request in ``batch``.

        The batch is updated in place with the newly chosen tokens.

        Args:
            batch: Batch to advance by one token.

        Returns:
            ``(generations, next_batch, (forward_ns, decode_ns))``.
            ``generations`` only holds the requests assigned to this rank;
            ``next_batch`` is ``None`` once every request has stopped,
            otherwise it is ``batch`` itself. The last element holds the time
            spent before and after the start of token post-processing, in
            nanoseconds.
        """
        start = time.time_ns()
        # slice the attention mask to the correct shape
        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]

        logits, speculative_logits, past = self.forward(
            batch.input_ids,
            attention_mask,
            batch.position_ids,
            batch.past_key_values,
        )

        # Results
        generations: List[Generation] = []
        stopped = True

        # Speculation is not active for causal
        accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens,
            batch.top_n_tokens_tensor,
            torch.log_softmax(logits[:, -1], -1),
            accepted_ids,
        )

        start_decode = time.time_ns()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.top_n_tokens,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
            top_n_tokens,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_input_ids.view(1, -1), request_logits[-1:, :]
            )

            # Append next token to all tokens
            all_input_ids = torch.cat([all_input_ids, next_token_id])
            new_input_length = input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_input_ids[:, 0], prefix_offset, read_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_squeezed,
                next_token_text,
            )

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text, _, _ = self.decode_token(
                        all_input_ids[:, 0],
                        prefix_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens
                        - 1,
                        read_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
                    )
                    # Get seed
                    if isinstance(next_token_chooser.choice, Sampling):
                        seed = next_token_chooser.choice.seed
                    else:
                        seed = None

                    generated_text = GeneratedText(
                        output_text, stopping_criteria.current_tokens, reason, seed
                    )
                else:
                    generated_text = None

                # Prefill
                if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                    # Remove generated token to only have prefill and add nan for first prompt token
                    prefill_logprobs = [float("nan")] + torch.log_softmax(
                        request_logits, -1
                    ).gather(1, all_input_ids[1:]).squeeze(1)[
                        -new_input_length:-1
                    ].tolist()
                    prefill_token_ids = all_input_ids[-new_input_length:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = Tokens(
                        prefill_token_ids,
                        prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    top_tokens = []
                    # Distinct names here so the per-request values are not shadowed
                    for step_top_ids, step_top_logprobs in zip(
                        top_token_ids, top_token_logprobs
                    ):
                        toptoken_texts = self.tokenizer.batch_decode(
                            step_top_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        special_toptokens = [
                            token_id in self.all_special_ids
                            for token_id in step_top_ids
                        ]
                        top_tokens.append(
                            Tokens(
                                step_top_ids,
                                step_top_logprobs,
                                toptoken_texts,
                                special_toptokens,
                            )
                        )
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        [next_token_id_squeezed],
                        [next_token_logprob],
                        [next_token_text],
                        [next_token_id_squeezed.item() in self.all_special_ids],
                    ),
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # Update values
            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
                next_token_id_squeezed.item()
            )
            batch.input_ids[i, 0] = next_token_id
            batch.all_input_ids[i] = all_input_ids
            batch.input_lengths[i] = new_input_length
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.max_input_length = max(batch.max_input_length, new_input_length)

        forward_ns = start_decode - start

        # We finished all generations in the batch; there is no next batch
        if stopped:
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)

        # Slice unused values from prefill
        batch.input_ids = batch.input_ids[:, :1]

        # Update attention_mask as we added a new token to input_ids
        batch.attention_mask[:, -batch.padding_right_offset] = 1
        # Decrease right offset
        batch.padding_right_offset -= 1

        # Update position_ids
        batch.position_ids = batch.position_ids[:, -1:] + 1

        # Update past key values
        batch.past_key_values = past

        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)