causal_lm.py 32.3 KB
Newer Older
1
import torch
2
import time
3
import torch.distributed
4

5
from dataclasses import dataclass
6
from opentelemetry import trace
7
8
9
10
11
12
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    PreTrainedTokenizerBase,
)
13
from typing import Optional, Tuple, List, Type, Dict
14

15
16
17
18
19
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)
20
from text_generation_server.models import Model
Daniël de Kok's avatar
Daniël de Kok committed
21
from text_generation_server.utils.chunks import concat_text_chunks
22
from text_generation_server.utils.import_utils import SYSTEM
23
from text_generation_server.utils.quantization import get_loader
24
from text_generation_server.utils.tokens import batch_top_tokens
25
26
from text_generation_server.models.types import (
    Batch,
Nicolas Patry's avatar
Nicolas Patry committed
27
    Tokens,
28
29
30
31
32
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
33

34
35
# Module-level OpenTelemetry tracer, named after this module; used below to
# open spans around batch filtering, concatenation and token generation.
tracer = trace.get_tracer(instrumenting_module_name=__name__)

36
37

@dataclass
class CausalLMBatch(Batch):
    """Padded batch state for decoder-only (causal) language models.

    ``input_ids``/``position_ids`` hold what is fed to the next forward pass.
    ``attention_mask`` is allocated with ``padding_right_offset`` extra columns
    on the right, sized for the largest number of tokens any request may still
    generate, so the mask does not need to be reallocated every step.
    """

    batch_id: int
    requests: List[generate_pb2.Request]
    # Maps request id -> row index inside this batch
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    # Per-layer key/value cache; None until the batch has been prefilled
    past_key_values: Optional[List[Tuple]]

    # All tokens (prompt + generated) of each request, one [seq, 1] tensor each
    all_input_ids: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    # Offsets used for incremental detokenization
    prefix_offsets: List[int]
    read_offsets: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Metadata used for padding
    max_input_length: int
    padding_right_offset: int

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    # Past metadata
    # True: keys laid out [..., seq_len, head_dim];
    # False: BLOOM-style [..., head_dim, seq_len]
    keys_head_dim_last: bool = True

    def to_pb(self) -> generate_pb2.CachedBatch:
        """Serialize the batch summary into a ``CachedBatch`` protobuf."""
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "CausalLMBatch":
        """Build a new, not-yet-prefilled batch from a protobuf ``Batch``.

        All request inputs are tokenized together. They are padded to a common
        width and truncated to the largest ``truncate`` value. An attention mask
        is allocated with room for the largest ``max_new_tokens``, and
        per-request sampling and stopping helpers are created.
        ``past_key_values`` starts as ``None``.

        Args:
            pb: Incoming batch of requests.
            tokenizer: Tokenizer used for the inputs and the generation helpers.
            dtype: Accepted for interface compatibility; not used here.
            device: Device that receives every tensor of the batch.
        """
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            requests_idx_mapping[r.id] = i
            inputs.append(concat_text_chunks(r.input_chunks.chunks))

            next_token_choosers.append(
                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
            )
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        # Every row is padded to the same width, so all requests share the
        # same detokenization offsets.
        input_len = tokenized_inputs["input_ids"].shape[1]
        prefix_offsets = [input_len - 5] * len(pb.requests)
        read_offsets = [input_len] * len(pb.requests)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        # Plain Python int: keeps ``max_tokens`` an int (its declared type)
        # instead of a 0-d tensor that would be handed to the protobuf.
        max_input_length = input_lengths.max().item()

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]

        # Positions count only real tokens; padded slots get a dummy value of 1
        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=list(all_input_ids),
            input_lengths=input_lengths.tolist(),
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
        """Keep only ``request_ids``, shrinking per-request state in place.

        Tensors (including every past key/value) are re-sliced so that memory
        held by dropped requests can be released. Returns ``self``.

        Raises:
            ValueError: if ``request_ids`` is empty.
        """
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        if len(request_ids) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        requests = []
        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        max_input_length = 0

        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []

        total_remaining_decode_tokens = 0
        new_padding_right_offset = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            requests_idx_mapping[request_id] = i
            keep_indices.append(idx)

            requests.append(self.requests[idx])
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])
            all_input_ids.append(self.all_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(self.top_n_tokens[idx])
            remaining_decode_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )
            total_remaining_decode_tokens += remaining_decode_tokens
            new_padding_right_offset = max(
                new_padding_right_offset, remaining_decode_tokens
            )

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        input_ids = self.input_ids[keep_indices]
        position_ids = self.position_ids[keep_indices]
        # Drop left padding no longer needed by the kept rows and shrink the
        # right-side allocation to the new padding_right_offset
        self.attention_mask = self.attention_mask[
            keep_indices,
            -(self.padding_right_offset + max_input_length) : (
                self.attention_mask.shape[1] - self.padding_right_offset
            )
            + new_padding_right_offset,
        ]

        # Ensure that past_key_values tensors can be updated in-place
        # (isinstance also covers tuple subclasses, which are immutable too)
        if isinstance(self.past_key_values[0], tuple):
            self.past_key_values = [list(layer) for layer in self.past_key_values]

        # Update tensors in-place to allow incremental garbage collection
        past_kv_length = max_input_length - 1
        for layer in self.past_key_values:
            past_keys, past_values = layer
            if len(past_keys.shape) == 3:
                # Force past to be of dim [self_size, num_heads, ...] for easy indexing
                past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
                past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
            if self.keys_head_dim_last:
                layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
            else:
                layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
            del past_keys
            layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
            del past_values

        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
        max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens

        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        self.input_ids = input_ids
        self.position_ids = position_ids
        self.all_input_ids = all_input_ids
        self.input_lengths = input_lengths
        self.prefix_offsets = prefix_offsets
        self.read_offsets = read_offsets
        self.next_token_choosers = next_token_choosers
        self.stopping_criterias = stopping_criterias
        self.top_n_tokens = top_n_tokens
        self.top_n_tokens_tensor = top_n_tokens_tensor
        self.max_input_length = max_input_length
        self.padding_right_offset = new_padding_right_offset
        self.max_tokens = max_tokens

        return self

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        """Merge several prefilled batches into a single batch.

        Shorter batches are left-padded up to the largest ``max_input_length``.
        The attention mask is re-allocated with the largest
        ``padding_right_offset``. Each input batch's past key/value references
        are cleared layer by layer as they are copied.

        Raises:
            ValueError: if any batch has not been prefilled yet
                (``past_key_values`` is ``None``).
        """
        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            # We only concatenate batches that did at least one step. Checked
            # up front so a failure leaves every input batch untouched.
            if batch.past_key_values is None:
                raise ValueError("only concatenate prefilled batches")
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []
        max_tokens = 0

        # Batch tensors
        input_ids = None
        attention_mask = None
        position_ids = None
        past_key_values = []
        top_n_tokens_tensor = None

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for batch in batches:
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)
            top_n_tokens.extend(batch.top_n_tokens)

            # Offset each batch's mapping by the cumulative batch size. A fresh
            # dict is built so the first batch's mapping is not mutated.
            for k, v in batch.requests_idx_mapping.items():
                requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # Create empty tensor
            # input_ids is always of shape [batch_size, 1]
            # We do not need to pad it
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
            # Copy to correct indices
            input_ids[start_index:end_index] = batch.input_ids

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length + padding_right_offset),
                )

            if top_n_tokens_tensor is None:
                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
                    total_batch_size,
                )
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor

            # We need to slice the attention mask to remove padding from previous
            # steps and to remove unused allocated space. End indices are written
            # explicitly (not as ``-offset``) so an offset of 0 cannot collapse
            # into an empty ``:-0`` slice.
            left_offset = max_input_length - batch.max_input_length
            batch_left_offset = (
                batch.attention_mask.shape[1]
                - batch.max_input_length
                - batch.padding_right_offset
            )
            attention_mask[
                start_index:end_index,
                left_offset : attention_mask.shape[1] - padding_right_offset,
            ] = batch.attention_mask[
                :,
                batch_left_offset : batch.attention_mask.shape[1]
                - batch.padding_right_offset,
            ]

            # Create empty tensor
            # position_ids is always of shape [batch_size, 1]
            if position_ids is None:
                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
            position_ids[start_index:end_index] = batch.position_ids

            # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
            # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
            # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
            # And ensure that we can update tensors in-place
            if isinstance(batch.past_key_values[0], tuple):
                batch.past_key_values = [
                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
                    for layer in batch.past_key_values
                ]
            elif len(batch.past_key_values[0][0].shape) == 3:
                for layer in batch.past_key_values:
                    for k, t in enumerate(layer):
                        layer[k] = t.view(len(batch), -1, *t.shape[-2:])

            # Add eventual padding tokens that were added while concatenating
            max_tokens += batch.max_tokens + (
                max_input_length - batch.max_input_length
            ) * len(batch)

            start_index = end_index

        first_past_kvs = batches[0].past_key_values
        _, num_heads, _, head_dim = first_past_kvs[0][1].shape

        padded_past_values_shape = (
            total_batch_size,
            num_heads,
            max_input_length - 1,
            head_dim,
        )

        if batches[0].keys_head_dim_last:
            padded_past_keys_shape = padded_past_values_shape
        else:
            # seq_length is last for BLOOM
            padded_past_keys_shape = (
                total_batch_size,
                num_heads,
                head_dim,
                max_input_length - 1,
            )

        # Iterate over attention layers
        # Concatenate past key values layer by layer to allow incremental garbage collection
        for j in range(len(first_past_kvs)):
            padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
            start_index = 0
            for batch in batches:
                past_keys = batch.past_key_values[j][0]
                # Clear reference to the original tensor
                batch.past_key_values[j][0] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the keys to remove the padding from previous batches
                past_seq_len = batch.max_input_length - 1
                if batch.keys_head_dim_last:
                    padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
                        past_keys[:, :, -past_seq_len:, :]
                    )
                else:
                    # BLOOM case
                    padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
                        past_keys[:, :, :, -past_seq_len:]
                    )
                del past_keys

                start_index = end_index

            padded_past_values = first_past_kvs[j][1].new_zeros(
                padded_past_values_shape
            )
            start_index = 0
            for batch in batches:
                past_values = batch.past_key_values[j][1]
                # Clear reference to the original tensor
                batch.past_key_values[j][1] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the past values to remove the padding from previous batches
                past_seq_len = batch.max_input_length - 1
                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
                    past_values[:, :, -past_seq_len:, :]
                )
                del past_values

                # Update values
                start_index = end_index

            past_key_values.append([padded_past_keys, padded_past_values])

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            keys_head_dim_last=batches[0].keys_head_dim_last,
            max_tokens=max_tokens,
        )

    def __len__(self):
        """Number of requests currently in the batch."""
        return len(self.requests)

493

494
@dataclass
class CausalLMBatchKeysLast(CausalLMBatch):
    """``CausalLMBatch`` variant whose cached keys keep the sequence axis last.

    With ``keys_head_dim_last`` set to ``False``, ``filter`` and ``concatenate``
    slice past keys along their final dimension (``[..., head_dim, seq_len]``)
    instead of the third one.
    """

    # Keys stored as [..., head_dim, seq_len] instead of [..., seq_len, head_dim]
    keys_head_dim_last: bool = False


499
class CausalLM(Model):
500
501
502
    def __init__(
        self,
        model_id: str,
        model_class,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        default_dtype=torch.float16,
        trust_remote_code: bool = False,
        tokenizer_class=AutoTokenizer,
        config_class=AutoConfig,
        batch_class=CausalLMBatch,
    ):
        """Load ``model_id`` with a custom ``model_class`` on this rank's device.

        Args:
            model_id: Identifier passed to the tokenizer/config loaders and used
                to locate the ``.safetensors`` weight files.
            model_class: Callable invoked as ``model_class(prefix, config, weights)``.
            revision: Optional revision forwarded to every loader.
            quantize: Quantization mode. It is stored on ``self`` and ``config``
                and is used to pick the weights loader.
            speculator: Stored on ``config.speculator``.
            dtype: Explicit dtype; when ``None`` a per-device default is used.
            default_dtype: Default dtype on CUDA/XPU when ``dtype`` is ``None``.
            trust_remote_code: Forwarded to the tokenizer and config loaders.
            tokenizer_class: Class whose ``from_pretrained`` loads the tokenizer.
            config_class: Class whose ``from_pretrained`` loads the config.
            batch_class: Batch type reported by ``batch_type``.
        """
        self.quantize = quantize
        self.batch_class = batch_class
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = default_dtype if dtype is None else dtype
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device(f"xpu:{rank}")
            dtype = default_dtype if dtype is None else dtype
        elif SYSTEM == "ipex":
            device = torch.device("cpu")
            # Float16 doesn't exist on target.
            dtype = torch.bfloat16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = config_class.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        config.speculator = speculator
        if tokenizer.pad_token_id is None:
            # Some configs list several end-of-sequence ids; a pad id must be a
            # single int, so fall back to the first one.
            eos_token_id = config.eos_token_id
            if isinstance(eos_token_id, (list, tuple)):
                eos_token_id = eos_token_id[0] if eos_token_id else None
            if config.pad_token_id is not None:
                tokenizer.pad_token_id = config.pad_token_id
            elif eos_token_id is not None:
                tokenizer.pad_token_id = eos_token_id
            elif tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id

        torch.distributed.barrier(group=self.process_group)
        weights_loader = get_loader(
            quantize=quantize, model_id=model_id, revision=revision
        )
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device=device,
            dtype=dtype,
            process_group=self.process_group,
            weights_loader=weights_loader,
        )

        prefix = ""
        model = model_class(prefix, config, weights)

        torch.distributed.barrier(group=self.process_group)
        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @classmethod
    def fallback(
        cls,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Build a ``CausalLM`` backed by ``transformers.AutoModelForCausalLM``.

        ``__init__`` and distributed initialization are bypassed. When more
        than one accelerator is visible, the model is spread with
        ``device_map="auto"``. Otherwise it is moved to the single device,
        except under bitsandbytes quantization.

        Raises:
            RuntimeError: if a ``speculator`` is requested.
            ValueError: if ``quantize`` is set but no CUDA/XPU device exists.
        """
        if speculator:
            raise RuntimeError("Speculator decoding is not enabled for AutoModel")

        device_count = 0
        if torch.cuda.is_available():
            device = torch.device("cuda")
            device_count = torch.cuda.device_count()
            dtype = torch.float16 if dtype is None else dtype
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device("xpu")
            device_count = torch.xpu.device_count()
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map=("auto" if device_count > 1 else None),
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        )
        if device_count == 1 and quantize != "bitsandbytes":
            model = model.to(device)

        if tokenizer.pad_token_id is None:
            # Some configs list several end-of-sequence ids; a pad id must be a
            # single int, so fall back to the first one.
            eos_token_id = model.config.eos_token_id
            if isinstance(eos_token_id, (list, tuple)):
                eos_token_id = eos_token_id[0] if eos_token_id else None
            if model.config.pad_token_id is not None:
                tokenizer.pad_token_id = model.config.pad_token_id
            elif eos_token_id is not None:
                tokenizer.pad_token_id = eos_token_id
            elif tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                # NOTE(review): this grows the tokenizer vocabulary, but the
                # model's embeddings are not resized here. Confirm that padded
                # ids never reach the embedding lookup out of range.
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        # Bypass __init__ (which loads sharded weights) and initialise the base
        # Model directly with the transformers model.
        self = cls.__new__(
            cls,
        )
        self.batch_class = CausalLMBatch
        super().__init__(
            self,
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )
        self.quantize = quantize
        return self
654
655
656

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        """Batch class this model creates and consumes (``self.batch_class``)."""
        configured_batch_class = self.batch_class
        return configured_batch_class
658

659
    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ) -> Tuple[
        torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]
    ]:
        """Run one forward pass of the wrapped model with caching enabled.

        ``position_ids`` are only passed through when ``self.has_position_ids``
        is set. If the model returns a ``(outputs, speculative_logits)`` tuple,
        it is unpacked; otherwise ``speculative_logits`` is ``None``.

        Returns:
            ``(logits, speculative_logits, past_key_values)``.
        """
        # Model Forward
        model_inputs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
            return_dict=True,
        )
        if self.has_position_ids:
            model_inputs["position_ids"] = position_ids

        result = self.model.forward(**model_inputs)

        speculative_logits = None
        if isinstance(result, tuple):
            result, speculative_logits = result
        return result.logits, speculative_logits, result.past_key_values
681

682
    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]:
        """Run one forward step over `batch` and choose the next token per request.

        The batch is fed through `self.forward`. Then, for every request:
        a new token is picked with its `NextTokenChooser`, its
        `StoppingCriteria` is evaluated, and, for the requests handled by this
        rank, a `Generation` is built. That `Generation` describes the new
        token, the optional prefill logprobs, the optional top-n alternatives
        and, once the request stops, the final decoded text.

        Args:
            batch: the batch to advance. Its per-request state
                (`input_ids`, `all_input_ids`, `input_lengths`, decode offsets,
                token choosers, `max_input_length`) is updated in place.

        Returns:
            `(generations, next_batch, (forward_ns, decode_ns))`, where:
            - `generations` holds the `Generation`s emitted by this rank;
            - `next_batch` is the updated `batch`, or `None` when every request
              in it has stopped;
            - `forward_ns` is the wall-clock time up to the end of the forward
              pass and top-n extraction;
            - `decode_ns` is the time spent on token selection, decoding and
              batch bookkeeping afterwards.
        """
        start = time.time_ns()

        # The trailing `padding_right_offset` columns of the mask are reserved
        # for tokens not generated yet (one is filled per step below), so only
        # the columns before them are passed to the model.
        # NOTE(review): `[:, : -offset]` is an empty slice when the offset is 0;
        # this relies on callers keeping `padding_right_offset > 0` while the
        # batch is still running.
        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]

        # Speculative logits are ignored: speculation is not active for causal LMs.
        logits, _speculative_logits, past = self.forward(
            batch.input_ids,
            attention_mask,
            batch.position_ids,
            batch.past_key_values,
        )

        generations: List[Generation] = []
        # Cleared as soon as one request still needs more tokens.
        stopped = True

        # Speculation is not active for causal LMs: one token is accepted per
        # request per step.
        accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens,
            batch.top_n_tokens_tensor,
            torch.log_softmax(logits[:, -1], -1),
            accepted_ids,
        )

        start_decode = time.time_ns()

        # Walk every request of the batch together with its own slice of logits.
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.top_n_tokens,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
            top_n_tokens,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Select the next token from the logits of the last position.
            next_token_id, logprobs = next_token_chooser(
                all_input_ids.view(1, -1), request_logits[-1:, :]
            )

            # Append the new token to this request's token history.
            all_input_ids = torch.cat([all_input_ids, next_token_id])
            new_input_length = input_length + 1

            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            # One host-side copy of the id, reused for the special-token check
            # and for advancing the grammar (each `.item()` is a device sync).
            next_token_id_int = next_token_id_squeezed.item()
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_input_ids[:, 0], prefix_offset, read_offset
            )

            stop, reason = stopping_criteria(
                next_token_id_squeezed,
                next_token_text,
            )
            if not stop:
                stopped = False

            # Shard generations: each rank only reports the requests it owns.
            # All generations will be appended in the rust sharded client.
            if i % self.world_size == self.rank:
                if stop:
                    # The offsets point at the start of the `current_tokens`
                    # generated tokens; presumably this makes decode_token
                    # return just the generated text (decode_token is defined
                    # elsewhere).
                    output_text, _, _ = self.decode_token(
                        all_input_ids[:, 0],
                        prefix_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens
                        - 1,
                        read_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
                    )
                    # Only sampling choosers carry a seed.
                    if isinstance(next_token_chooser.choice, Sampling):
                        seed = next_token_chooser.choice.seed
                    else:
                        seed = None

                    generated_text = GeneratedText(
                        output_text, stopping_criteria.current_tokens, reason, seed
                    )
                else:
                    generated_text = None

                # Prefill logprobs are reported once, with the request's first
                # generated token, and only when the request asked for them.
                if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                    # Remove generated token to only have prefill and add nan
                    # for the first prompt token (it has no preceding logits).
                    prefill_logprobs = [float("nan")] + torch.log_softmax(
                        request_logits, -1
                    ).gather(1, all_input_ids[1:]).squeeze(1)[
                        -new_input_length:-1
                    ].tolist()
                    prefill_token_ids = all_input_ids[-new_input_length:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = Tokens(
                        prefill_token_ids,
                        prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    # One `Tokens` per entry returned by batch_top_tokens for
                    # this request (presumably one per accepted position).
                    top_tokens = []
                    for step_token_ids, step_token_logprobs in zip(
                        top_token_ids, top_token_logprobs
                    ):
                        step_texts = self.tokenizer.batch_decode(
                            step_token_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        step_is_special = [
                            token_id in self.all_special_ids
                            for token_id in step_token_ids
                        ]
                        top_tokens.append(
                            Tokens(
                                step_token_ids,
                                step_token_logprobs,
                                step_texts,
                                step_is_special,
                            )
                        )
                else:
                    top_tokens = None

                generations.append(
                    Generation(
                        request.id,
                        prefill_tokens,
                        Tokens(
                            [next_token_id_squeezed],
                            [next_token_logprob],
                            [next_token_text],
                            [next_token_id_int in self.all_special_ids],
                        ),
                        generated_text,
                        top_tokens,
                    )
                )

            # Write this request's new state back into the batch.
            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
                next_token_id_int
            )
            batch.input_ids[i, 0] = next_token_id
            batch.all_input_ids[i] = all_input_ids
            batch.input_lengths[i] = new_input_length
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.max_input_length = max(batch.max_input_length, new_input_length)

        if stopped:
            # We finished all generations in the batch; there is no next batch.
            next_batch = None
        else:
            # Keep only the freshly written token column as the next input.
            batch.input_ids = batch.input_ids[:, :1]
            # Unmask the first reserved column (the token just added) and
            # shrink the reserved right padding by one.
            batch.attention_mask[:, -batch.padding_right_offset] = 1
            batch.padding_right_offset -= 1
            # Advance each sequence's position by one.
            batch.position_ids = batch.position_ids[:, -1:] + 1
            # Carry the `past` returned by this forward pass into the next step.
            batch.past_key_values = past
            next_batch = batch

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, next_batch, (forward_ns, decode_ns)