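"""RW (RefinedWeb / Falcon) checkpoints served through the generic
`transformers` AutoModel path of text_generation_server."""
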
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Optional, Tuple

from text_generation_server.models import CausalLM


class RW(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
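        # The generic AutoModel fallback has no speculative-decoding path, so
        # refuse Medusa-style speculators up front.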
        if speculator:
            raise RuntimeError("Medusa decoding is not enabled for AutoModel")

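        # Pick the device and a default dtype: float16 on GPU, float32 on CPU
        # (where quantized loading is also unsupported).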
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

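        # Pad and truncate on the left so the most recent context stays adjacent
        # to the tokens being generated in batched decoding.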
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
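        # Let accelerate shard the weights across GPUs (device_map="auto") only
        # when more than one is visible; a single GPU gets the explicit .cuda()
        # move below.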
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map=(
                "auto"
                if torch.cuda.is_available() and torch.cuda.device_count() > 1
                else None
            ),
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        )
        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
            model = model.cuda()

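        # Guarantee a usable pad token: prefer the model config's pad id, then
        # its EOS, then the tokenizer's EOS, and only then mint a new [PAD].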
        if tokenizer.pad_token_id is None:
            if model.config.pad_token_id is not None:
                tokenizer.pad_token_id = model.config.pad_token_id
            elif model.config.eos_token_id is not None:
                tokenizer.pad_token_id = model.config.eos_token_id
            elif tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})

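        # Deliberately skip CausalLM.__init__ (which does its own model loading)
        # and initialize the shared base class with the objects built above.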
        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
    ) -> Tuple[
        torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]
    ]:
        # Model Forward. `position_ids` is accepted to match the shared CausalLM
        # interface but is not forwarded to the underlying AutoModel.
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )

        # A plain AutoModelForCausalLM returns a single ModelOutput; unpacking it
        # into two names would iterate over its string keys. Guard the tuple case
        # so models that do return (outputs, speculative_logits) still work.
        if isinstance(outputs, tuple):
            outputs, speculative_logits = outputs
        else:
            speculative_logits = None

        return outputs.logits, speculative_logits, outputs.past_key_values
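
# A minimal usage sketch, assuming this module is importable from a
# text-generation-inference checkout and that the base Model class exposes
# `tokenizer` and `device` attributes; "tiiuae/falcon-rw-1b" is only an
# illustrative checkpoint:
#
#   rw = RW(model_id="tiiuae/falcon-rw-1b")
#   enc = rw.tokenizer("Hello", return_tensors="pt").to(rw.device)
#   logits, speculative_logits, past = rw.forward(
#       enc["input_ids"], enc["attention_mask"], position_ids=None
#   )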