Commit fa1f5b7d authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'origin/new' into v0.5.0-dtk24.04.1

parents 9c9467da 68291efd
......@@ -2,13 +2,38 @@ from vllm.sampling_params import SamplingParams
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
import asyncio
from vllm.utils import FlexibleArgumentParser
from transformers import AutoTokenizer, AutoModel
from transformers import AutoTokenizer
import logging
import argparse
import sys
vllm_logger = logging.getLogger("vllm")
vllm_logger.setLevel(logging.WARNING)
class FlexibleArgumentParser(argparse.ArgumentParser):
    """ArgumentParser that accepts underscores in long option names.

    Every ``--long_option`` (with or without an ``=value`` suffix) is
    rewritten to ``--long-option`` before parsing, so options registered
    with dashes can also be spelled with underscores on the command line.
    Only the option name is rewritten; values are left untouched.
    """

    def parse_args(self, args=None, namespace=None):
        """Normalize long option names, then parse the arguments.

        Args:
            args: Argument strings to parse. Defaults to ``sys.argv[1:]``
                when ``None``.
            namespace: Optional object to populate; passed through unchanged.

        Returns:
            The namespace produced by ``ArgumentParser.parse_args``.
        """
        if args is None:
            args = sys.argv[1:]

        processed_args = []
        options_ended = False
        for arg in args:
            if options_ended or not arg.startswith('--'):
                # Short options, values, and anything after a bare "--"
                # are passed through verbatim.
                processed_args.append(arg)
            elif arg == '--':
                # argparse treats everything after "--" as positional, so
                # those strings must not be rewritten as option names.
                options_ended = True
                processed_args.append(arg)
            else:
                # Split off an optional "=value" so that underscores inside
                # the value are preserved.
                key, sep, value = arg.partition('=')
                processed_args.append(key.replace('_', '-') + sep + value)
        return super().parse_args(processed_args, namespace)
parser = FlexibleArgumentParser()
parser.add_argument('--template', type=str, help="Path to template")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
......@@ -18,9 +43,17 @@ args = parser.parse_args()
# {"role": "user", "content": "I'd like to show off how chat templating works!"},
# ]
# tokenizer = AutoTokenizer.from_pretrained("/models/llama2/Llama-2-7b-chat-hf")
# aaaa = tokenizer.chat_template
# print(aaaa)
tokenizer = AutoTokenizer.from_pretrained(args.model)
# Override the tokenizer's chat template only when --template was supplied.
# The previous try/finally called f.close() even when open() had failed
# (or args.template was None), which raised NameError because `f` was never
# bound; the context manager closes the file only if it was actually opened.
if args.template:
    try:
        with open(args.template, 'r', encoding='utf-8') as template_file:
            tokenizer.chat_template = template_file.read()
    except (OSError, UnicodeError) as e:
        # Best-effort: report the problem and keep the tokenizer's own template.
        print('except:', e)
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
......@@ -37,17 +70,15 @@ def build_prompt(history):
return prompt
history = "<s>[INST] Hello, how are you? [/INST] I'm doing great. How can I help you today?</s>"
history = []
while True:
query = input("\n用户:")
if query.strip() == "stop":
break
query = history + "<s>[INST] " + query + " [/INST]"
history.append({"role": "user", "content": query})
new_query = tokenizer.apply_chat_template(history, tokenize=False)
example_input = {
"prompt": query,
"prompt": new_query,
"stream": False,
"temperature": 0.0,
"request_id": 0,
......@@ -61,19 +92,19 @@ while True:
start = 0
end = 0
last = ""
response = ""
# Drain `results_generator`, echoing newly generated text to stdout as it
# arrives and recording the latest full text in module-level variables.
# NOTE(review): this view shows both `last` and `response` bookkeeping; they
# look like the old and new sides of a diff -- confirm which one is live.
async def process_results():
async for output in results_generator:
# State lives in module-level globals so the caller can read it after
# asyncio.run() returns. `end` is declared here but never assigned below.
global end
global start
global last
global response
# Print only the text past the offset already shown; presumably
# output.outputs[0].text is cumulative across iterations -- TODO confirm.
print(output.outputs[0].text[start:], end="", flush=True)
length = len(output.outputs[0].text)
# Advance the offset so the next iteration prints only the new suffix.
start = length
# Keep the most recent full text for the caller.
last = output.outputs[0].text
response = output.outputs[0].text
asyncio.run(process_results())
history += "<s>[INST] " + query + " [/INST]" + last + "</s>"
history.append({"role": "assistant", "content": response})
print()
#print(history)
{% if messages[0]['role'] == 'system' %}
{% set system_message = '<<SYS>>\n' + messages[0]['content'] | trim + '\n<</SYS>>\n\n' %}
{% set messages = messages[1:] %}
{% else %}
{% set system_message = '' %}
{% endif %}
{% for message in messages %}
{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{% endif %}
{% if loop.index0 == 0 %}
{% set content = system_message + message['content'] %}
{% else %}
{% set content = message['content'] %}
{% endif %}
{% if message['role'] == 'user' %}
{{ bos_token + '[INST] ' + content | trim + ' [/INST]' }}
{% elif message['role'] == 'assistant' %}
{{ ' ' + content | trim + ' ' + eos_token }}
{% endif %}
{% endfor %}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment