"git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "bbd72a6c979d387a9c5670c3dcbb157434f5cb5b"
Unverified Commit ffe4ba9c authored by Chen Xin, committed by GitHub

Fix crash and remove `sys_instruct` from `chat.py` and `client.py` (#591)

* fix crash

* update profile_generation.py

* format

* use self.bos_id

* remove sys_instruct
parent af2f072e
@@ -30,7 +30,7 @@ pip install nvidia-ml-py
 ```bash
 python profile_generation.py \
  --model-path /path/to/your/model \
- --concurrency 1 8 --prompt-tokens 0 512 --completion-tokens 2048 512
+ --concurrency 1 8 --prompt-tokens 1 512 --completion-tokens 2048 512
 ```
 ## profile serving
...
@@ -90,7 +90,7 @@ def warmup(model,
 def profile_throughput(model_path: str,
                        concurrency: int = 1,
-                       input_seqlen: int = 0,
+                       input_seqlen: int = 1,
                        output_seqlen: int = 512,
                        test_round: int = 10,
                        tp: int = 1):
@@ -99,8 +99,10 @@ def profile_throughput(model_path: str,
     tm_model = TurboMind(model_path=model_path, tp=tp)
     # make up a prompt that can be tokenized into {input_seqlen} tokens
-    prompt = '' if input_seqlen == 0 else 'hi' + ' hi' * (input_seqlen - 1)
+    assert input_seqlen > 0, 'input_seqlen should > 0'
+    prompt = 'hi'
     input_ids = tokenizer.encode(prompt)
+    input_ids = input_ids * input_seqlen
     warmup(tm_model, concurrency, input_ids, output_seqlen)
...
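The prompt is now built by tiling token ids rather than concatenating words in text, and a zero-length prompt is rejected up front. A minimal sketch of the idea, with `encode` as a hypothetical stand-in for the real tokenizer (if a real tokenizer prepends special tokens such as BOS, those get tiled too, so the resulting length is a multiple of `len(encode('hi'))` rather than exactly `input_seqlen`):

```python
# Sketch of the new prompt construction in profile_throughput.
def encode(prompt: str) -> list[int]:
    # hypothetical stand-in: pretend 'hi' maps to a single token id
    return [9086]

input_seqlen = 512
assert input_seqlen > 0, 'input_seqlen should > 0'

input_ids = encode('hi')
input_ids = input_ids * input_seqlen  # tile ids instead of repeating text
assert len(input_ids) == input_seqlen
```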
@@ -20,7 +20,6 @@ def input_prompt(model_name):
 def main(tritonserver_addr: str,
          session_id: int = 1,
          cap: str = 'chat',
-         sys_instruct: str = None,
          stream_output: bool = True,
          **kwargs):
     """An example to communicate with inference server through the command line
@@ -32,13 +31,11 @@ def main(tritonserver_addr: str,
         session_id (int): the identical id of a session
         cap (str): the capability of a model. For example, codellama has
             the ability among ['completion', 'infill', 'instruct', 'python']
-        sys_instruct (str): the content of 'system' role, which is used by
-            conversational model
         stream_output (bool): indicator for streaming output or not
         **kwargs (dict): other arguments for initializing model's chat template
     """
     log_level = os.environ.get('SERVICE_LOG_LEVEL', 'WARNING')
-    kwargs.update(capability=cap, system=sys_instruct)
+    kwargs.update(capability=cap)
     chatbot = Chatbot(tritonserver_addr,
                       log_level=log_level,
                       display=stream_output,
...
@@ -459,6 +459,10 @@ class Chatbot:
             session.sequence_length = 0
         input_ids, input_lengths = self.preprocess(prompt)
+        # will crash if last_token_id == eos_id and send empty input_ids
+        if sequence_end and request_output_len == 0:
+            input_ids = np.array([[self.bos_id]], dtype=np.uint32)
+            input_lengths = np.array([[1]], dtype=np.uint32)
         input_tokens = input_lengths.squeeze()
         if self.profile_generation:
            yield StatusCode.TRITON_STREAM_ING, \
...
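This hunk is the actual crash fix: per the added comment, the server crashes when it is sent an empty `input_ids` tensor, which can happen when a request only ends a session (`sequence_end=True` with `request_output_len == 0`). The patch substitutes a single BOS token in that case. A runnable sketch of the guard in isolation, with `bos_id` as an illustrative value (the real id comes from the model's tokenizer):

```python
import numpy as np

bos_id = 1  # illustrative; actual value is model-dependent

def guard_inputs(input_ids, input_lengths, sequence_end, request_output_len):
    """Mimic the added check: never send an empty input_ids tensor."""
    if sequence_end and request_output_len == 0:
        input_ids = np.array([[bos_id]], dtype=np.uint32)
        input_lengths = np.array([[1]], dtype=np.uint32)
    return input_ids, input_lengths

# An end-of-session request with nothing to generate gets one BOS token
# instead of an empty array.
ids, lens = guard_inputs(np.empty((1, 0), dtype=np.uint32),
                         np.array([[0]], dtype=np.uint32),
                         sequence_end=True, request_output_len=0)
print(ids, lens)  # [[1]] [[1]]
```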
@@ -73,7 +73,6 @@ def get_gen_param(cap,
 def main(model_path,
          session_id: int = 1,
          cap: str = 'chat',
-         sys_instruct: str = None,
          tp=1,
          stream_output=True,
          **kwargs):
@@ -85,8 +84,6 @@ def main(model_path,
         session_id (int): the identical id of a session
         cap (str): the capability of a model. For example, codellama has
             the ability among ['completion', 'infilling', 'chat', 'python']
-        sys_instruct (str): the content of 'system' role, which is used by
-            conversational model
         tp (int): GPU number used in tensor parallelism
         stream_output (bool): indicator for streaming output or not
         **kwarg (dict): other arguments for initializing model's chat template
@@ -100,9 +97,7 @@ def main(model_path,
     step = 0
     seed = random.getrandbits(64)
     model_name = tm_model.model_name
-    model = MODELS.get(model_name)(capability=cap, **kwargs) \
-        if sys_instruct is None else MODELS.get(model_name)(
-            capability=cap, system=sys_instruct, **kwargs)
+    model = MODELS.get(model_name)(capability=cap, **kwargs)
     print(f'session {session_id}')
     while True:
...
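With `sys_instruct` removed, the two-way conditional collapses: every extra keyword flows through `**kwargs` into the chat-template constructor, so a system prompt can presumably still reach the template if its constructor accepts a `system` keyword, as the removed branch implied. A self-contained sketch of that forwarding pattern (`ChatTemplate` below is a hypothetical stand-in for `MODELS.get(model_name)`):

```python
# Sketch of the simplified forwarding: no special-casing for `system`.
class ChatTemplate:
    def __init__(self, capability='chat', system=None, **kwargs):
        self.capability = capability
        self.system = system

def build_model(cap='chat', **kwargs):
    # mirrors: model = MODELS.get(model_name)(capability=cap, **kwargs)
    return ChatTemplate(capability=cap, **kwargs)

m = build_model(cap='chat', system='You are a helpful assistant')
print(m.system)  # the system prompt still reaches the template via **kwargs
```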