# flash_causal_lm.py
import torch
import torch.distributed

from torch.nn import functional as F

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
from typing import Optional, Tuple, List, Type, Union, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
    NextTokenChooser,
    StoppingCriteria,
    Sampling,
)

tracer = trace.get_tracer(__name__)


@dataclass
class FlashCausalLMBatch(Batch):
    """State of a batch of generation requests processed by ``FlashCausalLM``.

    Per-request values are stored in parallel lists; ``requests_idx_mapping``
    gives the list index of each request id.  ``input_ids`` are not padded to
    a common length: ``cu_seqlens`` stores the running total of the input
    lengths instead, so ``cu_seqlens[i + 1] - cu_seqlens[i]`` is the input
    length of request ``i``.
    """

    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: List[torch.Tensor]
    # One tensor per request when built by `from_pb`; `generate_token`
    # overwrites each entry with a plain int position.
    position_ids: List[Union[torch.Tensor, int]]
    # cumulative sequence lengths
    cu_seqlens: List[int]
    max_seqlen: int
    # None until the first forward pass has been run on the batch
    past_key_values: Optional[List[torch.Tensor]]

    # All tokens
    all_input_ids: List[List[int]]
    # Prompt ids right-padded with `max_new_tokens` slots per request
    all_input_ids_tensor: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    offsets: List[Optional[int]]
    token_offsets: List[Optional[int]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    def to_pb(self) -> generate_pb2.Batch:
        """Return the protobuf description (id, requests, size) of this batch."""
        return generate_pb2.Batch(
            id=self.batch_id, requests=self.requests, size=len(self)
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        """Build a fresh batch (no past key values) from a protobuf batch.

        Args:
            pb: protobuf batch whose ``requests`` are tokenized one by one.
            tokenizer: tokenizer applied to each request's ``inputs``, with
                truncation to the request's ``truncate`` value.
            device: device on which the per-request tensors are created.

        Returns:
            A batch with ``past_key_values=None`` and ``offsets`` /
            ``token_offsets`` initialised to ``None`` for every request.
        """
        input_ids = []
        position_ids = []
        cu_seqlens = [0]
        max_seqlen = 0

        input_lengths = []
        offsets = []
        token_offsets = []
        all_input_ids = []
        all_input_ids_tensor = []
        requests_idx_mapping = {}

        next_token_choosers = []
        stopping_criterias = []

        # Cumulative length
        cumulative_length = 0

        # Parse batch
        for i, r in enumerate(pb.requests):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenizer(
                r.inputs, truncation=True, max_length=r.truncate
            )["input_ids"]

            input_length = len(tokenized_input)
            max_seqlen = max(max_seqlen, input_length)
            input_lengths.append(input_length)

            offsets.append(None)
            token_offsets.append(None)
            # Keep the python list of ids before it is turned into a tensor
            all_input_ids.append(tokenized_input)

            tokenized_input = torch.tensor(tokenized_input, device=device)
            input_ids.append(tokenized_input)

            # Position ids
            position_ids.append(
                torch.arange(0, input_length, dtype=torch.int32, device=device)
            )

            # Add cumulative lengths of all previous inputs
            cu_seqlens.append(cumulative_length + input_length)

            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            # Reserve one slot per token that may still be generated
            all_input_ids_tensor.append(
                F.pad(tokenized_input, (0, stopping_criteria.max_new_tokens))
            )

            # Update
            cumulative_length += input_length

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            past_key_values=None,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, requests: List[generate_pb2.Request]) -> "FlashCausalLMBatch":
        """Return a batch restricted to ``requests``, in the order given.

        Args:
            requests: requests to keep; each id must be present in
                ``requests_idx_mapping``.

        Returns:
            ``self`` when ``len(requests) == len(self)``; otherwise a new
            batch whose lists, ``cu_seqlens`` and ``max_seqlen`` are rebuilt
            from the kept requests.

        Raises:
            ValueError: if ``requests`` is empty.
        """
        if len(requests) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(requests) == len(self):
            return self

        # Cumulative length
        cumulative_length = 0

        # New values after filtering
        requests_idx_mapping = {}

        input_ids = []
        position_ids = []
        cu_seqlens = [0]
        max_seqlen = 0
        past_key_values = []

        all_input_ids = []
        all_input_ids_tensor = []

        input_lengths = []
        offsets = []
        token_offsets = []

        next_token_choosers = []
        stopping_criterias = []

        for i, r in enumerate(requests):
            # Index of the request in the current (unfiltered) batch
            idx = self.requests_idx_mapping[r.id]
            requests_idx_mapping[r.id] = i

            # Get length
            request_input_length = self.input_lengths[idx]

            input_ids.append(self.input_ids[idx])
            position_ids.append(self.position_ids[idx])
            cu_seqlens.append(cumulative_length + request_input_length)
            max_seqlen = max(max_seqlen, request_input_length)
            past_key_values.append(self.past_key_values[idx])

            all_input_ids.append(self.all_input_ids[idx])
            all_input_ids_tensor.append(self.all_input_ids_tensor[idx])

            input_lengths.append(request_input_length)
            offsets.append(self.offsets[idx])
            token_offsets.append(self.token_offsets[idx])

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criterias.append(self.stopping_criterias[idx])

            cumulative_length += request_input_length

        return FlashCausalLMBatch(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        """Merge several batches into one, keeping the id of the first batch.

        Request indices and cumulative sequence lengths of each batch are
        shifted by the totals of the batches that precede it.  The input
        batches' ``requests_idx_mapping`` dicts are left untouched.

        Args:
            batches: batches to merge, in order; each one must already hold
                a ``past_key_values`` list.

        Returns:
            A new batch holding the requests of every input batch.
        """
        # Batch attributes
        requests = []
        # Always a fresh dict: aliasing the first batch's mapping would make
        # the loop below write the other batches' entries into it.
        requests_idx_mapping = {}

        input_ids = []
        position_ids = []
        cu_seqlens = [0]
        max_seqlen = 0
        past_key_values = []

        all_input_ids = []
        all_input_ids_tensor = []

        input_lengths = []
        offsets = []
        token_offsets = []

        next_token_choosers = []
        stopping_criterias = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_length = 0

        for batch in batches:
            requests.extend(batch.requests)

            # We need to offset the mapping for each batch by the cumulative
            # batch size (0 for the first batch)
            for k, v in batch.requests_idx_mapping.items():
                requests_idx_mapping[k] = v + cumulative_batch_size

            input_ids.extend(batch.input_ids)
            position_ids.extend(batch.position_ids)
            # Add cumulative lengths of all previous inputs
            cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
            max_seqlen = max(max_seqlen, batch.max_seqlen)
            past_key_values.extend(batch.past_key_values)

            all_input_ids.extend(batch.all_input_ids)
            all_input_ids_tensor.extend(batch.all_input_ids_tensor)

            input_lengths.extend(batch.input_lengths)
            offsets.extend(batch.offsets)
            token_offsets.extend(batch.token_offsets)

            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Update
            cumulative_length += batch.cu_seqlens[-1]
            cumulative_batch_size += len(batch)

        return FlashCausalLMBatch(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
        )

    def __len__(self):
        """Return the number of requests in the batch."""
        return len(self.requests)


class FlashCausalLM(Model):
    """Causal language model wrapper that generates over ``FlashCausalLMBatch``."""

    def __init__(
        self,
        model_cls: Type[PreTrainedModel],
        model_id: str,
        revision: Optional[str] = None,
        quantize: bool = False,
        decode_buffer: int = 3,
    ):
        """Load the tokenizer and the model on the CUDA device.

        Args:
            model_cls: model class whose ``from_pretrained`` is used to load
                the weights.
            model_id: identifier passed to ``from_pretrained`` for both the
                tokenizer and the model.
            revision: optional revision passed to ``from_pretrained``.
            quantize: forwarded to ``from_pretrained`` as ``load_in_8bit``.
            decode_buffer: forwarded unchanged to ``Model.__init__``.

        Raises:
            NotImplementedError: if CUDA is not available.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            raise NotImplementedError("FlashCausalLM is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        self.model = (
            model_cls.from_pretrained(
                model_id,
                revision=revision,
                torch_dtype=dtype,
                load_in_8bit=quantize,
            )
            .eval()
            .to(device)
        )

        super(FlashCausalLM, self).__init__(
            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        """Batch class this model operates on."""
        return FlashCausalLMBatch

    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
        """Decode token ids to text, skipping special tokens.

        Tokenization-space clean-up is disabled, as it is for the prefill
        texts produced in ``generate_token``.
        """
        # The keyword is spelled `clean_up_tokenization_spaces`; the previous
        # spelling (`cleanup_...`) did not match the one used for
        # `batch_decode` below.
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_s: int,
        past_key_values: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the wrapped model's ``forward`` with the given arguments.

        Returns:
            Whatever ``self.model.forward`` returns; ``generate_token``
            unpacks it as ``(out, present)``.
        """
        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_s=max_s,
            past_key_values=past_key_values,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
        """Generate one token for every request of ``batch``.

        The batch is updated in place (input ids, positions, lengths,
        offsets, past key values and cumulative sequence lengths).

        Args:
            batch: batch to advance; ``past_key_values is None`` marks a
                batch that has not been through a forward pass yet.

        Returns:
            ``(generations, next_batch)`` where ``generations`` holds one
            ``Generation`` per request and ``next_batch`` is ``batch``, or
            ``None`` when every request met its stopping criteria.
        """
        # Shortcut when batch_size == 1
        if len(batch) == 1:
            input_ids = batch.input_ids[0].view(-1)
            past_key_values = (
                batch.past_key_values[0] if batch.past_key_values is not None else None
            )
        else:
            # Concatenate tensors
            input_ids = torch.cat(batch.input_ids).view(-1)
            past_key_values = (
                torch.cat(batch.past_key_values, dim=1)
                if batch.past_key_values is not None
                else None
            )

        # Concatenate when prefill, torch.tensor when decode
        position_ids = (
            torch.tensor(batch.position_ids, device=self.device)
            if batch.past_key_values is not None
            else torch.cat(batch.position_ids)
        )
        cu_seqlens = torch.tensor(
            batch.cu_seqlens, device=self.device, dtype=torch.int32
        )

        out, present = self.forward(
            input_ids,
            position_ids,
            cu_seqlens,
            batch.max_seqlen,
            past_key_values,
        )

        # Initialize past_key_values in prefill
        if batch.past_key_values is None:
            batch.past_key_values = [None] * len(batch)

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.offsets,
            batch.token_offsets,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.all_input_ids_tensor,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            offset,
            token_offset,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
            all_input_ids_tensor,
        ) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            # A request that has not produced any token yet is in prefill
            prefill = stopping_criteria.current_tokens == 0
            if prefill:
                # Prefill mode
                # out is of shape [cumulative_sequence_lengths, vocab_size]
                logits = out[start_index:end_index]
            else:
                # Decode mode
                # out is of shape [batch_size, vocab_size]
                logits = out[i].unsqueeze(0)

            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_input_ids_tensor[None, :input_length], logits
            )
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_id_item = next_token_id_squeezed.item()

            # Append next token to all tokens
            all_input_ids.append(next_token_id_item)
            all_input_ids_tensor[input_length] = next_token_id_item

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id_item]
            next_token_text, offset, token_offset = self.decode_token(
                all_input_ids,
                offset,
                token_offset,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_item,
                next_token_text,
            )

            if stop:
                # Decode generated tokens
                output_text = self.decode(
                    all_input_ids[-stopping_criteria.current_tokens :]
                )
                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )

                # CAUTION: generation will be stopped so no need to pad
                # This will make the next forward crash if the request does not get filtered
                new_input_length = input_length
                past = present[:, start_index:end_index]
            else:
                stopped = False
                generated_text = None

                # Pad present for next iter attention
                # (four (left, right) pairs: one zero slot is appended at the
                # end of the 4th-from-last dimension — presumably the
                # sequence axis sliced just below; TODO confirm)
                new_input_length = input_length + 1
                past = torch.nn.functional.pad(
                    present[:, start_index:end_index], (0, 0, 0, 0, 0, 0, 0, 1)
                )

            # Prefill
            if prefill:
                # Add nan for the first prompt token, which has no logprob.
                # Row j of `logprobs` scores the token at position j + 1, so
                # gathering with ids[1:input_length] yields exactly the
                # input_length - 1 prompt logprobs; the generated token is
                # already excluded by the index, so no further slicing is
                # needed (an extra [:-1] would drop the last prompt token and
                # leave the list shorter than `prefill_token_ids`).
                prefill_logprobs = [float("nan")] + logprobs.gather(
                    1, all_input_ids_tensor[1:input_length].unsqueeze(1)
                ).squeeze(1).tolist()
                # Remove generated token to only have prefill
                prefill_token_ids = all_input_ids[:-1]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = PrefillTokens(
                    prefill_token_ids, prefill_logprobs, prefill_texts
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_item,
                next_token_logprob,
                next_token_text,
                next_token_id_item in self.all_special_ids,
                generated_text,
            )

            generations.append(generation)
            cumulative_length += input_length

            # Update values
            batch.input_ids[i] = next_token_id
            batch.position_ids[i] = input_length
            batch.input_lengths[i] = new_input_length
            batch.offsets[i] = offset
            batch.token_offsets[i] = token_offset
            batch.all_input_ids[i] = all_input_ids
            batch.all_input_ids_tensor[i] = all_input_ids_tensor
            batch.max_seqlen = max(batch.max_seqlen, new_input_length)
            batch.past_key_values[i] = past
            # Cumulative sum
            batch.cu_seqlens[(i + 1)] = batch.cu_seqlens[i] + new_input_length

        # No need to return a batch if we know that all requests stopped
        return generations, (batch if not stopped else None)