"vscode:/vscode.git/clone" did not exist on "a362340b33258eae0f48504be09659e2e9dcd035"
chat.py 6.8 KB
Newer Older
WRH's avatar
WRH committed
1
# Copyright (c) OpenMMLab. All rights reserved.
"""Chat through command line.

This submodule allows the user to chat with a language model through the
command line, and optionally to accelerate the model using backends such as
DeepSpeed.

Example 1: Chat with the default settings

```shell
python -m lmdeploy.pytorch.chat $PATH_TO_HF_MODEL
```

Example 2: Disable sampling

```shell
python -m lmdeploy.pytorch.chat \
    $PATH_TO_LLAMA_MODEL_IN_HF_FORMAT \
    --temperature 0
```

Example 3: Accelerate with DeepSpeed inference

```shell
python -m lmdeploy.pytorch.chat \
    $PATH_TO_LLAMA_MODEL_IN_HF_FORMAT \
    --accel deepspeed
```

Note: to use DeepSpeed, you need to install it first. To accelerate InternLM,
    you need a customized version:
    https://github.com/wangruohui/DeepSpeed/tree/support_internlm_0.10.0

Example 4: Run the model with tensor parallelism on 2 GPUs

```shell
deepspeed --module --num_gpus 2 lmdeploy.pytorch.chat \
    $PATH_TO_LLAMA_MODEL_IN_HF_FORMAT \
    --accel deepspeed
```

This module also allows the following control commands to change
generation behavior during chat.

- `exit`: terminate and exit chat
- `config set key=value`: change generation config `key` to `value`,
    e.g. `config set temperature=0` disables sampling for the following chats
- `clear`: clear chat history
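
For instance, the following commands typed during a chat disable sampling
for subsequent replies and then clear the accumulated history:

```text
config set temperature=0
clear
```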
"""

import itertools
import logging
from typing import Optional

import fire
import torch
from transformers import GenerationConfig, PreTrainedModel

from .adapters import init_adapter
from .dist import get_local_rank, get_rank, get_world_size
from .model import accel_model, init_model
from .session import BasicSessionManagerWithHistory
from .utils import BasicStreamer, TerminalIO, control

logger = logging.getLogger(__name__)


def set_logging(log_file: Optional[str], debug: bool):
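    """Configure file logging; non-zero ranks write to a rank-suffixed file."""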
    torch.set_printoptions(linewidth=120)
    level = logging.DEBUG if debug else logging.INFO
    log_file = log_file or 'chat.log'
    if (r := get_rank()) != 0:
        log_file = log_file + f'.{r}'
    logging.basicConfig(level=level,
                        format=('%(filename)s: '
                                '%(levelname)s: '
                                '%(funcName)s(): '
                                '%(lineno)d:\t'
                                '%(message)s'),
                        filename=log_file,
                        filemode='w')
    print(f'Worker {get_rank()} logging to {log_file}')


def main(
    model_path: str,
    tokenizer_path: Optional[str] = None,
    accel: Optional[str] = None,
    max_new_tokens: int = 128,
    temperature: float = 0.8,
    top_p: float = 0.95,
    seed: int = 0,
    use_fast_tokenizer: bool = True,
    max_alloc: int = 2048,
    max_session_len: Optional[int] = None,
    log_file: Optional[str] = None,
    debug: bool = False,
    adapter: Optional[str] = None,
):
    """Chat with a model through the terminal.

    Args:
        model_path (str): Path to model.
        tokenizer_path (str): Path to tokenizer.
        accel (str): Model accelerator.
        max_new_tokens (int): Maximum number of tokens to generate.
        temperature (float): Temperature for sampling.
        top_p (float): Top p for sampling.
        seed (int): Random seed.
        use_fast_tokenizer (bool): Whether to use fast tokenizer.
            This argument is passed directly to transformers' ``AutoTokenizer.from_pretrained``.
            Generally, users should use the fast tokenizer.
            If the fast tokenizer raises an error, try forcing a slow one.
        max_alloc (int): Maximum memory to allocate (for deepspeed).
        max_session_len (int): Maximum number of tokens allowed for all chat sessions.
            This includes both the history and the current session.
        log_file (str): Path to log file.
        debug (bool): Whether to enable debug mode.
        adapter (str): Force a specific adapter.
            Generally, users should not need this argument because the adapter is selected
            based on the model type. It is only required when that is impossible, e.g. LLaMA 1
            and LLaMA 2 cannot be distinguished by the `LlamaForCausalLM` class alone.
            Currently, only "llama1" is accepted, for LLaMA 1 models.
    """  # noqa: E501
    set_logging(log_file, debug)

    # workers should sync in sampling
    torch.manual_seed(seed)

    local_rank = get_local_rank()
    world_size = get_world_size()
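    # local_rank selects the GPU used by this worker; world_size is the
    # tensor-parallel degree handed to accel_model below.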

    # Init model and tokenizer
    if not tokenizer_path:
        tokenizer_path = model_path

    model, tokenizer = init_model(
        model_path,
        tokenizer_path,
        use_fast_tokenizer=use_fast_tokenizer,
    )

    # Init adapter based on model and tokenizer
    adapter = init_adapter(model, tokenizer, adapter)

    # Accelerate model
    model: PreTrainedModel = accel_model(model,
                                         accel,
                                         max_alloc=max_alloc,
                                         tp_size=world_size)

    # warmup
    warmup_config = GenerationConfig(
        max_new_tokens=1,
        do_sample=temperature > 0,
        temperature=temperature,
        top_p=top_p,
    )
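    # A single dummy-token generation warms the model up (e.g. it triggers CUDA
    # kernel compilation and memory allocation) before the first real request.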
    model.generate(torch.tensor([[6]], device=get_local_rank()), warmup_config)

    gen_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,
        temperature=temperature,
        top_p=top_p,
    )

    # Session manager handling history
    max_session_len = max_alloc if max_session_len is None else max_session_len
    sm = BasicSessionManagerWithHistory(max_session_len=max_session_len,
                                        start_ids=adapter.start_ids,
                                        sep_ids=adapter.sep_ids)
    io = TerminalIO()
    streamer = BasicStreamer(adapter.decode, io.output)

    for r in itertools.count(1):
        # User input from IO
        logger.info(f'Round {r}')

        prompt: str = io.input()
        logger.info(f'User input: {prompt}')

        # Allow user to change config during runtime or exit
        if control(prompt, gen_config, sm):
            continue

        # Tokenize and apply model specific templates
        input_ids = adapter.encode_and_decorate(prompt)
        logger.info(f'Input ids:\n{input_ids}')

        # Prepend chat history (tensor concatenation)
        input_ids = sm.prepend_history(input_ids)
        logger.info(f'Input ids with history:\n{input_ids}')

        # Generate
        input_ids = input_ids.cuda(local_rank)
        # the returned tensor contains both the input and the generated tokens
        output = model.generate(input_ids,
                                gen_config,
                                streamer=streamer,
                                stopping_criteria=adapter.stopping_criteria)
        logger.info(f'Output:\n{output}')

        # Save output into session manager and maybe trim some history
        sm.add_to_history(output)


def cli():
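    """Entry point for ``python -m lmdeploy.pytorch.chat``."""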
    fire.Fire(main)


if __name__ == '__main__':
    cli()