santacoder.py 2.16 KB
Newer Older
1
2
3
import torch
import torch.distributed

4
from typing import Optional, List
5
6
from transformers import AutoTokenizer, AutoModelForCausalLM

7
from text_generation_server.models import CausalLM
8
9
10
11
12
13
14
15
16

# Marker strings that SantaCoder.__init__ registers on the tokenizer as
# additional special tokens.
# NOTE(review): "FIM" presumably stands for fill-in-the-middle prompting —
# confirm against the model documentation.
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
FIM_PAD = "<fim-pad>"
# Registered as a special token and also installed as the tokenizer's pad token.
EOD = "<|endoftext|>"


class SantaCoder(CausalLM):
    """Causal language model wrapper for SantaCoder-style checkpoints.

    Construction loads a left-padding tokenizer and the model through the
    ``transformers`` Auto classes, registers the EOD/FIM marker strings as
    special tokens (EOD doubles as the pad token), and hands everything to
    the base initializer. Decoding keeps special tokens in the output.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer and model for ``model_id`` and initialise the base class.

        Args:
            model_id: Identifier passed to both ``from_pretrained`` calls.
            revision: Optional revision forwarded to ``from_pretrained``.
            quantize: Optional quantization scheme; only ``"bitsandbytes"``
                switches on 8-bit loading. Any truthy value is rejected
                when CUDA is unavailable.
            dtype: Desired torch dtype on CUDA (``torch.float16`` if omitted).
                Ignored on CPU, where ``torch.float32`` is always used.
            trust_remote_code: Forwarded to both ``from_pretrained`` calls.

        Raises:
            ValueError: If ``quantize`` is set and CUDA is not available.
        """
        on_gpu = torch.cuda.is_available()

        # Quantization is a GPU-only feature here: fail fast on CPU.
        if not on_gpu and quantize:
            raise ValueError("quantization is not available on CPU")

        if on_gpu:
            device = torch.device("cuda")
            if dtype is None:
                dtype = torch.float16
        else:
            # CPU always runs in full precision, regardless of the requested dtype.
            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        # Register the marker strings in a fixed order; EOD is also the pad token.
        marker_tokens = [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD]
        tokenizer.add_special_tokens(
            {
                "additional_special_tokens": marker_tokens,
                "pad_token": EOD,
            }
        )

        use_8bit = quantize == "bitsandbytes"
        loaded = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            load_in_8bit=use_8bit,
            trust_remote_code=trust_remote_code,
        )
        model = loaded.to(device)

        # super(CausalLM, self) starts method resolution *after* CausalLM, so
        # CausalLM.__init__ itself is bypassed and the next initializer in the
        # MRO receives the already-built model and tokenizer.
        # NOTE(review): presumably done because CausalLM.__init__ performs its
        # own model loading — confirm against the base class.
        super(CausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )

    def decode(self, generated_ids: List[int]) -> str:
        """Turn generated token ids into text, keeping special tokens.

        Args:
            generated_ids: Token ids to decode.

        Returns:
            The decoded string, with special tokens retained and no
            tokenization-space clean-up applied.
        """
        # Special tokens must survive decoding: they drive custom parsing
        # rules applied to the generated text.
        text = self.tokenizer.decode(
            generated_ids,
            skip_special_tokens=False,
            clean_up_tokenization_spaces=False,
        )
        return text