Unverified commit 009075d8 authored by WRH, committed by GitHub

[Feature] Add a torch client (#19)

* draft torch client

* deal with space of tokenizer

* support tensor parallel

* fix

* fix

* move folder

* move instruction to readme

* move to torch/

* rename client to chat

* very bad response

* stash

* rename streamer

* support internlm

* change default args

* remove test

* improve instructions

* remove module docstring

* decrease header level of torch model
parent 76ae8627
@@ -11,3 +11,4 @@ dist/
examples/cpp/llama/*.csv
*.npy
*.weight
*.pyc
@@ -122,6 +122,29 @@ python3 lmdeploy.app {server_ip_addresss}:33337 internlm
For deploying other supported models, such as LLaMA and Vicuna, you can find the guide [here](docs/en/serving.md)
### Inference with PyTorch
#### Single GPU
```shell
python3 -m lmdeploy.torch.chat $NAME_OR_PATH_TO_HF_MODEL \
    --max_new_tokens 64 \
    --temperature 0.8 \
    --top_p 0.95 \
    --seed 0
```
#### Tensor Parallel with DeepSpeed
```shell
deepspeed --module --num_gpus 2 lmdeploy.torch.chat \
    $NAME_OR_PATH_TO_HF_MODEL \
    --max_new_tokens 64 \
    --temperature 0.8 \
    --top_p 0.95 \
    --seed 0
```
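In both modes the client is interactive: type a prompt and press Enter twice to submit it, adjust generation parameters at runtime with `config set key=value` (for example `config set temperature=1.0`), and type `exit` to quit.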
## Quantization
In fp16 mode, int8 quantization of the kv_cache can be enabled, so a single card can serve more users.
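Conceptually, kv_cache int8 quantization stores the fp16 key/value blocks as int8 values plus a scale and dequantizes them on the fly, roughly halving the memory each cached token needs so that more concurrent sessions fit on one card. The snippet below is only an illustrative plain-PyTorch sketch of that idea, not lmdeploy's actual implementation; the helpers `quantize_kv` and `dequantize_kv` are hypothetical names.

```python
import torch


def quantize_kv(kv: torch.Tensor):
    """Symmetric per-tensor int8 quantization of an fp16 kv block.

    Returns the int8 tensor and the fp32 scale needed to recover it.
    """
    scale = kv.abs().amax().float().clamp(min=1e-8) / 127.0
    q = (kv.float() / scale).round().clamp(-127, 127).to(torch.int8)
    return q, scale


def dequantize_kv(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Recover an fp16 approximation of the original kv block."""
    return (q.float() * scale).to(torch.float16)


if __name__ == '__main__':
    kv = torch.randn(2, 8, 128, dtype=torch.float16)  # [heads, seq, head_dim]
    q, scale = quantize_kv(kv)
    rebuilt = dequantize_kv(q, scale)
    print('bytes per element:', q.element_size(), 'vs', kv.element_size())
    print('max abs error:', (rebuilt - kv).abs().max().item())
```

Per-channel scales (or an asymmetric zero-point) lower the quantization error at the cost of a little extra bookkeeping.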
......
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import os
import warnings
import fire
import torch
try:
import deepspeed
_is_deepspeed_available = True
except ImportError:
_is_deepspeed_available = False
try:
from transformers import (AutoModelForCausalLM, AutoTokenizer,
GenerationConfig)
from .utils import get_utils
_is_transformers_available = True
except ImportError:
_is_transformers_available = False
def input_prompt():
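    """Read a multi-line prompt from stdin, terminated by an empty line."""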
print('\ndouble enter to end input >>> ', end='')
sentinel = '' # ends when this string is seen
return '\n'.join(iter(input, sentinel))
def init_model(
model_path: str,
tokenizer_path: str,
use_fast_tokenizer=True,
local_rank=0,
world_size=1,
):
"""Note:
If the model is converted from new version of transformers,
use_fast_tokenizer should be True.
If using depodaca/llama-xb-hf, use_fast_tokenizer should be False.
"""
if not _is_transformers_available:
raise ImportError('transformers is not installed.\n'
'Please install with `pip install transformers`.\n')
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path,
use_fast=use_fast_tokenizer,
trust_remote_code=True)
torch.set_default_device(local_rank)
model = AutoModelForCausalLM.from_pretrained(model_path,
torch_dtype=torch.float16,
trust_remote_code=True)
    if not _is_deepspeed_available:
        warnings.warn('deepspeed is not installed, '
                      'using the plain huggingface model.')
else:
model = deepspeed.init_inference(
model=model, # Transformers models
mp_size=world_size, # Number of GPU
dtype=torch.float16, # dtype of the weights (fp16)
replace_with_kernel_inject=True,
# replace the model with the kernel injector
max_out_tokens=2048,
)
# print(f"model is loaded on device {model.device}")
return tokenizer, model
def main(
model_path: str,
tokenizer_path: str = None,
max_new_tokens: int = 64,
temperature: float = 0.8,
top_p: float = 0.95,
seed: int = 0,
use_fast_tokenizer: bool = True,
):
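    """Chat with a HuggingFace model in the terminal.

    When launched through DeepSpeed with more than one GPU, the model is
    sharded across the participating ranks (tensor parallelism).
    """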
torch.manual_seed(seed)
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
if not tokenizer_path:
tokenizer_path = model_path
tokenizer, model = init_model(
model_path,
tokenizer_path,
use_fast_tokenizer=use_fast_tokenizer,
local_rank=local_rank,
world_size=world_size,
)
gen_config = GenerationConfig(
max_new_tokens=max_new_tokens,
do_sample=temperature > 0,
temperature=temperature,
top_p=top_p,
)
Decorator, Streamer, stop_criteria = get_utils(model)
# warmup
warmup_config = GenerationConfig(
max_new_tokens=1,
do_sample=temperature > 0,
temperature=temperature,
top_p=top_p,
)
model.generate(torch.tensor([[1]]), warmup_config)
# print("READY ...")
_on_master = local_rank == 0
_is_dist = world_size > 1
while True:
# Receive prompt on master
if _on_master:
prompt = input_prompt()
else:
prompt = None
# Broadcast prompt to all workers
if _is_dist:
prompt = [prompt]
torch.distributed.broadcast_object_list(prompt, src=0)
prompt = prompt[0]
if prompt == 'exit':
exit(0)
# Re-config during runtime
if prompt.startswith('config set'):
try:
keqv = prompt.split()[-1]
k, v = keqv.split('=')
                v = eval(v)  # interpret the value as a Python literal
                gen_config.__setattr__(k, v)
                print(f'Worker {local_rank} set {k} to {repr(v)}')
            except Exception:
                print('invalid config command, '
                      'expected "config set key=value"')
else:
if _on_master:
streamer = Streamer(tokenizer)
else:
streamer = None
prompt = Decorator.decorate(prompt)
ids = tokenizer.encode(prompt, return_tensors='pt')
model.generate(ids,
gen_config,
streamer=streamer,
stopping_criteria=stop_criteria)
if __name__ == '__main__':
fire.Fire(main)
# Copyright (c) OpenMMLab. All rights reserved.
import re
from transformers import (PreTrainedTokenizerFast, StoppingCriteria,
StoppingCriteriaList)
from transformers.generation.streamers import BaseStreamer
def get_utils(model):
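    """Pick the prompt decorator, streamer class and stopping criteria that
    match the model, unwrapping DeepSpeed's InferenceEngine if present."""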
name = model.__class__.__name__
if name == 'InferenceEngine':
name = model.module.__class__.__name__
if name == 'InternLMForCausalLM':
stop_criteria = InternLMStoppingCriteria()
stop_criteria = StoppingCriteriaList([stop_criteria])
return InternLMDecorator, InternLMStreamer, stop_criteria
else:
return BaseDecorator, DecodeOutputStreamer, None
class DecodeOutputStreamer(BaseStreamer):
"""Output generated tokens to shell."""
def __init__(self, tokenizer, skip_prompt=True) -> None:
super().__init__()
self.tokenizer = tokenizer
self.skip_prompt = skip_prompt
self.gen_len = 0
if isinstance(tokenizer, PreTrainedTokenizerFast):
self.decode = self._decode_with_raw_id
self.hex_regex = re.compile(r'^<0x([0-9ABCDEF]+)>$')
else:
self.decode = self._decode_fallback
def _decode_with_raw_id(self, value):
tok = self.tokenizer._convert_id_to_token(value)
if tok.startswith('▁'): # sentencepiece
space = ' '
tok = tok[1:]
else:
space = ''
if res := self.hex_regex.match(tok):
tok = chr(int(res.group(1), 16))
if tok == '</s>':
tok = '\n'
return space + tok
def _decode_fallback(self, value):
tok = self.tokenizer.decode(value,
skip_special_tokens=False,
clean_up_tokenization_spaces=False)
return tok + ' '
def put(self, value):
if self.gen_len == 0 and self.skip_prompt:
pass
else:
tok = self.decode(value[0])
print(tok, end='', flush=True)
self.gen_len += 1
def end(self):
print('\n')
class InternLMStreamer(DecodeOutputStreamer):
"""Output generated tokens to shell."""
def __init__(self, tokenizer, skip_prompt=True) -> None:
        BaseStreamer.__init__(self)
self.tokenizer = tokenizer
self.skip_prompt = skip_prompt
self.gen_len = 0
self.hex_regex = re.compile(r'^<0x([0-9ABCDEF]+)>$')
def decode(self, value):
tok = self.tokenizer.decode(value)
if res := self.hex_regex.match(tok):
tok = chr(int(res.group(1), 16))
if tok == '</s>' or tok == '<eoa>':
tok = '\n'
return tok
class BaseDecorator:
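    """Default decorator that passes prompts and outputs through unchanged."""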
@classmethod
def decorate(cls, prompt):
return prompt
@classmethod
def extract(cls, gen_out):
return gen_out
class InternLMDecorator(BaseDecorator):
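    """Wrap prompts in InternLM's chat markup and extract the bot reply."""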
regex = re.compile(r'<\|Bot\|>:(.*)')
@classmethod
def decorate(cls, prompt):
return f'<|User|>:{prompt}<eoh>'
@classmethod
def extract(cls, gen_out):
return cls.regex.search(gen_out).group(1)
class InternLMStoppingCriteria(StoppingCriteria):
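    """Stop generation once the last generated token is one of InternLM's
    end-of-response token ids."""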
def __call__(self, input_ids, *args, **kwargs) -> bool:
return input_ids[0, -1] in [2, 103028]