"""
API and LLM wrapper classes for running LLMs locally

Usage:

import os
model_path = os.environ.get("ZH_MODEL_PATH")
model_name = "chatglm2"
colossal_api = ColossalAPI(model_name, model_path)
llm = ColossalLLM(n=1, api=colossal_api)
TEST_PROMPT_CHATGLM = "续写文章:惊蛰一过,春寒加剧。先是料料峭峭,继而雨季开始,"
logger.info(llm(TEST_PROMPT_CHATGLM, max_new_tokens=100), verbose=True)
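
For a model served through vLLM (a sketch; assumes a vLLM server is already running on
localhost:8077 and accepts standard vLLM sampling parameters such as max_tokens):

vllm_api = VllmAPI(host="localhost", port=8077)
vllm_llm = VllmLLM(n=1, api=vllm_api)
logger.info(vllm_llm(TEST_PROMPT_CHATGLM, max_tokens=100), verbose=True)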

"""
from typing import Any, List, Mapping, Optional

import torch
from colossalqa.local.utils import get_response, post_http_request
from colossalqa.mylogging import get_logger
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = get_logger()


class ColossalAPI:
    """
    API for calling LLM.generate
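
    Example (a sketch; assumes a local checkpoint, e.g. the chatglm2 path from the module docstring):
        api = ColossalAPI.get_api("chatglm2", model_path)
        api.generate("Hello", max_new_tokens=32)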
    """

    __instances = dict()

    def __init__(self, model_type: str, model_path: str, ckpt_path: Optional[str] = None) -> None:
        """
        Configure the model and tokenizer. Use ColossalAPI.get_api to reuse an
        existing instance for the same configuration.
        """
        if model_type + model_path + (ckpt_path or "") in ColossalAPI.__instances:
            return
        else:
            ColossalAPI.__instances[model_type + model_path + (ckpt_path or "")] = self
        self.model_type = model_type
        self.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True)

        if ckpt_path is not None:
            state_dict = torch.load(ckpt_path)
            self.model.load_state_dict(state_dict)
        self.model.to(torch.cuda.current_device())

        # Configure the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

        self.model.eval()

    @staticmethod
    def get_api(model_type: str, model_path: str, ckpt_path: Optional[str] = None) -> "ColossalAPI":
        """Return the cached ColossalAPI for this configuration, creating it if necessary."""
        if model_type + model_path + (ckpt_path or "") in ColossalAPI.__instances:
            return ColossalAPI.__instances[model_type + model_path + (ckpt_path or "")]
        else:
            return ColossalAPI(model_type, model_path, ckpt_path)

    def generate(self, input: str, **kwargs) -> str:
        """
        Generate response given the prompt
        Args:
            input: input string
            **kwargs: language model keyword type arguments, such as top_k, top_p, temperature, max_new_tokens...
        Returns:
            output: output string
        """
        if self.model_type in ["chatglm", "chatglm2"]:
            inputs = {
                k: v.to(torch.cuda.current_device()) for k, v in self.tokenizer(input, return_tensors="pt").items()
            }
        else:
            inputs = {
                "input_ids": self.tokenizer(input, return_tensors="pt")["input_ids"].to(torch.cuda.current_device())
            }

        output = self.model.generate(**inputs, **kwargs)
        output = output.cpu()
        prompt_len = inputs["input_ids"].size(1)
        response = output[0, prompt_len:]
        output = self.tokenizer.decode(response, skip_special_tokens=True)
        return output


class VllmAPI:
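    """
    Thin client for a model served over HTTP by a vLLM server.

    Example (a sketch; assumes a vLLM server is running on localhost:8077 and accepts
    standard vLLM sampling parameters such as max_tokens):
        vllm_api = VllmAPI(host="localhost", port=8077)
        vllm_api.generate("Hello", max_tokens=32)
    """
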
    def __init__(self, host: str = "localhost", port: int = 8077) -> None:
        # Configure the endpoint for a model served over HTTP (e.g. by a vLLM server)
        self.host = host
        self.port = port
        self.url = f"http://{self.host}:{self.port}/generate"

    def generate(self, input: str, **kwargs):
        output = get_response(post_http_request(input, self.url, **kwargs))[0]
        return output[len(input) :]


class ColossalLLM(LLM):
    """
    Langchain LLM wrapper for a local LLM
    """

    n: int
    api: Any
    kwargs = {"max_new_tokens": 100}

    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        logger.info(f"kwargs:{kwargs}\nstop:{stop}\nprompt:{prompt}", verbose=self.verbose)
        for k in self.kwargs:
            if k not in kwargs:
                kwargs[k] = self.kwargs[k]

        generate_args = {k: kwargs[k] for k in kwargs if k not in ["stop", "n"]}
        out = self.api.generate(prompt, **generate_args)
        if isinstance(stop, list) and len(stop) != 0:
            for stopping_words in stop:
                if stopping_words in out:
                    out = out.split(stopping_words)[0]
        logger.info(f"{prompt}{out}", verbose=self.verbose)
        return out

    @property
    def _identifying_params(self) -> Mapping[str, int]:
        """Get the identifying parameters."""
        return {"n": self.n}

    def get_token_ids(self, text: str) -> List[int]:
        """Return the ordered ids of the tokens in a text.

        Args:
            text: The string input to tokenize.

        Returns:
            A list of ids corresponding to the tokens in the text, in order they occur
                in the text.
        """
        # use the colossal llm's tokenizer instead of langchain's cached GPT2 tokenizer
        return self.api.tokenizer.encode(text)


class VllmLLM(LLM):
    """
    Langchain LLM wrapper for an LLM served through a vLLM HTTP endpoint
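
    Example (a sketch; assumes a vLLM server is reachable on localhost:8077):
        llm = VllmLLM(n=1, api=VllmAPI(host="localhost", port=8077))
        logger.info(llm("Hello", max_tokens=32), verbose=True)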
    """

    n: int
    api: Any
    kwargs = {"max_new_tokens": 100}

    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        for k in self.kwargs:
            if k not in kwargs:
                kwargs[k] = self.kwargs[k]
        logger.info(f"kwargs:{kwargs}\nstop:{stop}\nprompt:{prompt}", verbose=self.verbose)
        generate_args = {k: kwargs[k] for k in kwargs if k in ["n", "max_tokens", "temperature", "stream"]}
        out = self.api.generate(prompt, **generate_args)
        if isinstance(stop, list) and len(stop) != 0:
            for stopping_words in stop:
                if stopping_words in out:
                    out = out.split(stopping_words)[0]
        logger.info(f"{prompt}{out}", verbose=self.verbose)
        return out

    def set_host_port(self, host: str = "localhost", port: int = 8077, **kwargs) -> None:
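        """Point this LLM at a vLLM endpoint and set default generation kwargs (max_tokens defaults to 100)."""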
        if "max_tokens" not in kwargs:
            kwargs["max_tokens"] = 100
        self.kwargs = kwargs
        self.api = VllmAPI(host=host, port=port)

    @property
    def _identifying_params(self) -> Mapping[str, int]:
        """Get the identifying parameters."""
        return {"n": self.n}