import torch
import torch.distributed

from typing import Any, Optional, Type

from transformers import (
    PreTrainedTokenizerBase,
)

from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2


class BloomCausalLMBatch(CausalLMBatch):
    """`CausalLMBatch` specialization for BLOOM.

    Identical to the base batch except that it marks the past-key tensors
    as NOT stored head-dim-last (``keys_head_dim_last = False``), matching
    BLOOM's key-cache layout — TODO confirm against the model implementation.
    """

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "CausalLMBatch":
        """Deserialize a protobuf batch, then flag BLOOM's key layout."""
        # All tokenization/tensor construction is delegated to the parent.
        bloom_batch = super().from_pb(
            pb=pb,
            tokenizer=tokenizer,
            dtype=dtype,
            device=device,
        )
        bloom_batch.keys_head_dim_last = False
        return bloom_batch


class BLOOMSharded(CausalLM):
    """Sharded BLOOM causal language model.

    Differs from the generic :class:`CausalLM` in two ways: it consumes
    :class:`BloomCausalLMBatch` (past keys flagged as not head-dim-last)
    and its forward pass also surfaces speculative logits.
    """

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        """Batch class used to deserialize requests for this model."""
        return BloomCausalLMBatch

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor,
        # Fixed: was a bare `Optional` with no type argument, which is not a
        # valid annotation. Presumably the HF KV-cache structure (nested
        # tuples of tensors) — TODO confirm against self.model's signature.
        past_key_values: Optional[Any] = None,
    ):
        """Run one forward pass through the wrapped model.

        Returns:
            Tuple of ``(logits, speculative_logits, past_key_values)`` as
            produced by ``self.model.forward``.
        """
        outputs, speculative_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,  # keep the KV cache for incremental decoding
        )

        logits = outputs.logits
        return logits, speculative_logits, outputs.past_key_values