# Copyright (c) OpenMMLab. All rights reserved.
"""Chat through command line.

This submodule allows users to chat with a language model through the command
line, and optionally accelerate the model using backends like deepspeed.

Example 1: Chat with default settings

```shell
python -m lmdeploy.pytorch.chat $PATH_TO_HF_MODEL
```

Example 2: Disable sampling

```shell
python -m lmdeploy.pytorch.chat \
    $PATH_TO_LLAMA_MODEL_IN_HF_FORMAT \
    --temperature 0
```

Example 3: Accelerate with deepspeed inference

```shell
python -m lmdeploy.pytorch.chat \
    $PATH_TO_LLAMA_MODEL_IN_HF_FORMAT \
    --accel deepspeed
```

Note: to use deepspeed, you need to install deepspeed,
    and if you hope to accelerate InternLM, you need a customized version:
    https://github.com/wangruohui/DeepSpeed/tree/support_internlm_0.10.0

Example 4: Run the model with tensor parallelism on 2 GPUs

```shell
deepspeed --module --num_gpus 2 lmdeploy.pytorch.chat \
    $PATH_TO_LLAMA_MODEL_IN_HF_FORMAT \
    --accel deepspeed
```
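
Example 5: Launch the chat loop from Python instead of the shell

A minimal sketch: the model path below is only a placeholder, and any of
``main``'s keyword arguments (documented below) may be passed.

```python
from lmdeploy.pytorch.chat import main

main('/path/to/hf/model', max_new_tokens=64, temperature=0.7)
```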

This module also allows the following control commands to change
generation behaviors during chat.

- `exit`: terminate and exit chat
- `config set key=value`: change generation config `key` to `value`,
    e.g. `config set temperature=0` disables sampling for the following chats
- `clear`: clear chat history
"""

import itertools
import logging
from typing import Optional

import torch
from transformers import GenerationConfig, PreTrainedModel

from .adapters import init_adapter
from .dist import get_local_rank, get_rank, get_world_size
from .model import accel_model, init_model
from .session import BasicSessionManagerWithHistory
from .utils import BasicStreamer, TerminalIO, control

logger = logging.getLogger(__name__)


def set_logging(log_file: str, debug: bool):
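    """Set up logging to a per-rank log file.

    Rank 0 writes to ``log_file`` (``chat.log`` by default); every other rank
    appends its rank to the file name so that workers do not clobber each
    other's logs.
    """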
    torch.set_printoptions(linewidth=120)
    level = logging.DEBUG if debug else logging.INFO
    log_file = log_file or 'chat.log'
    if (r := get_rank()) != 0:
        log_file = log_file + f'.{r}'
    logging.basicConfig(level=level,
                        format=('%(filename)s: '
                                '%(levelname)s: '
                                '%(funcName)s(): '
                                '%(lineno)d:\t'
                                '%(message)s'),
                        filename=log_file,
                        filemode='w')
    print(f'Worker {get_rank()} logging to {log_file}')


def main(
    model_path: str,
    tokenizer_path: Optional[str] = None,
    accel: Optional[str] = None,
    max_new_tokens: int = 128,
    temperature: float = 0.8,
    top_p: float = 0.95,
    seed: int = 0,
    use_fast_tokenizer: bool = True,
    max_alloc: int = 2048,
    max_session_len: Optional[int] = None,
    log_file: Optional[str] = None,
    debug: bool = False,
    adapter: Optional[str] = None,
):
    """Chat with model through terminal.

    Args:
        model_path (str): Path to model.
        tokenizer_path (str): Path to tokenizer.
        accel (str): Model accelerator.
        max_new_tokens (int): Maximum number of tokens to generate.
        temperature (float): Temperature for sampling.
        top_p (float): Top p for sampling.
        seed (int): Random seed.
        use_fast_tokenizer (bool): Whether to use fast tokenizer.
            This argument is passed directly to transformers' ``AutoTokenizer.from_pretrained``.
            Generally, users should choose fast tokenizers.
            But if the fast tokenizer raises an error, try forcing a slow one.
        max_alloc (int): Maximum memory to allocate (for deepspeed).
        max_session_len (int): Maximum number of tokens allowed for all chat sessions.
            This includes both the history and the current session.
        log_file (str): Path to log file.
        debug (bool): Whether to enable debug mode.
        adapter (str): Force the use of a specific adapter.
            Generally, users should not use this argument because the adapter is selected
            based on the type of the model. It is only required when automatic selection is
            impossible, e.g. when distinguishing llama 1/2, which share the same
            `LlamaForCausalLM` class.
            Currently, only "llama1" is accepted for llama1 models.
    """  # noqa: E501
    set_logging(log_file, debug)

    # all workers should share the same seed so that sampling stays in sync
    torch.manual_seed(seed)

    local_rank = get_local_rank()
    world_size = get_world_size()

    # Init model and tokenizer
    if not tokenizer_path:
        tokenizer_path = model_path

    model, tokenizer = init_model(
        model_path,
        tokenizer_path,
        use_fast_tokenizer=use_fast_tokenizer,
    )

    # Init adapter based on model and tokenizer
    adapter = init_adapter(model, tokenizer, adapter)

    # Accelerate the model with the requested backend;
    # the tensor-parallel size equals the world size
    model: PreTrainedModel = accel_model(model,
                                         accel,
                                         max_alloc=max_alloc,
                                         tp_size=world_size)

    # Warm up the model with a single dummy token so that one-time
    # initialization cost is not paid on the first real prompt
    warmup_config = GenerationConfig(
        max_new_tokens=1,
        do_sample=temperature > 0,
        temperature=temperature,
        top_p=top_p,
    )
    model.generate(torch.tensor([[6]], device=get_local_rank()), warmup_config)

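    # Generation config for the chat loop; a temperature of 0 sets
    # do_sample to False (greedy decoding), which is what Example 2
    # in the module docstring relies on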
    gen_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,
        temperature=temperature,
        top_p=top_p,
    )

    # Session manager handling history
    max_session_len = max_alloc if max_session_len is None else max_session_len
    sm = BasicSessionManagerWithHistory(max_session_len=max_session_len,
                                        start_ids=adapter.start_ids,
                                        sep_ids=adapter.sep_ids)
    io = TerminalIO()
    streamer = BasicStreamer(adapter.decode, io.output)

    for r in itertools.count(1):
        # User input from IO
        logger.info(f'Round {r}')

        prompt: str = io.input()
        logger.info(f'User input: {prompt}')

        # Allow user to change config during runtime or exit
        if control(prompt, gen_config, sm):
            continue

        # Tokenize and apply model specific templates
        input_ids = adapter.encode_and_decorate(prompt)
        logger.info(f'Input ids:\n{input_ids}')

        # Prepend chat history (tensor concatenation)
        input_ids = sm.prepend_history(input_ids)
        logger.info(f'Input ids with history:\n{input_ids}')

        # Generate
        input_ids = input_ids.cuda(local_rank)
        # the returned tensor contains both the input and the generated tokens
        output = model.generate(input_ids,
                                gen_config,
                                streamer=streamer,
                                stopping_criteria=adapter.stopping_criteria)
        logger.info(f'Output:\n{output}')

        # Save output into session manager and maybe trim some history
        sm.add_to_history(output)


def cli():
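    """Command line entry point exposing ``main``'s arguments via fire."""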
    import fire

    fire.Fire(main)


if __name__ == '__main__':
    cli()