Unverified Commit c51020cf authored by Lianmin Zheng, committed by GitHub

Fix the chat template for llava-v1.6-34b & format code (#177)

parent 50afed4e
"""Public API"""
import re
from typing import Callable, List, Optional, Union
......
......@@ -19,7 +19,9 @@ class RuntimeEndpoint(BaseBackend):
self.base_url = base_url
self.auth_token = auth_token
res = http_request(self.base_url + "/get_model_info", auth_token=self.auth_token)
res = http_request(
self.base_url + "/get_model_info", auth_token=self.auth_token
)
assert res.status_code == 200
self.model_info = res.json()
......@@ -37,7 +39,7 @@ class RuntimeEndpoint(BaseBackend):
res = http_request(
self.base_url + "/generate",
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token
auth_token=self.auth_token,
)
assert res.status_code == 200
......@@ -45,14 +47,16 @@ class RuntimeEndpoint(BaseBackend):
res = http_request(
self.base_url + "/generate",
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token
auth_token=self.auth_token,
)
assert res.status_code == 200
def fill_image(self, s: StreamExecutor):
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
res = http_request(
self.base_url + "/generate", json=data, auth_token=self.auth_token
)
assert res.status_code == 200
def generate(
......@@ -82,7 +86,9 @@ class RuntimeEndpoint(BaseBackend):
self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
res = http_request(
self.base_url + "/generate", json=data, auth_token=self.auth_token
)
obj = res.json()
comp = obj["text"]
return comp, obj["meta_info"]
......@@ -115,7 +121,12 @@ class RuntimeEndpoint(BaseBackend):
data["stream"] = True
self._add_images(s, data)
response = http_request(self.base_url + "/generate", json=data, stream=True, auth_token=self.auth_token)
response = http_request(
self.base_url + "/generate",
json=data,
stream=True,
auth_token=self.auth_token,
)
pos = 0
incomplete_text = ""
......@@ -145,7 +156,9 @@ class RuntimeEndpoint(BaseBackend):
# Cache common prefix
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
res = http_request(
self.base_url + "/generate", json=data, auth_token=self.auth_token
)
assert res.status_code == 200
prompt_len = res.json()["meta_info"]["prompt_tokens"]
......@@ -157,7 +170,9 @@ class RuntimeEndpoint(BaseBackend):
"logprob_start_len": max(prompt_len - 2, 0),
}
self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
res = http_request(
self.base_url + "/generate", json=data, auth_token=self.auth_token
)
assert res.status_code == 200
obj = res.json()
normalized_prompt_logprob = [
......@@ -172,7 +187,7 @@ class RuntimeEndpoint(BaseBackend):
res = http_request(
self.base_url + "/concate_and_append_request",
json={"src_rids": src_rids, "dst_rid": dst_rid},
auth_token=self.auth_token
auth_token=self.auth_token,
)
assert res.status_code == 200
......
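The `max_new_tokens: 0` requests in the hunks above are how the endpoint caches a prompt prefix (and later reads back metadata such as `prompt_tokens`) without generating any text. Below is a minimal standalone sketch of that call pattern, assuming a running SRT server; `cache_prefix` is a made-up helper for illustration, not part of `RuntimeEndpoint`.

import requests

def cache_prefix(base_url, prefix_str, auth_token=None):
    # Ask the runtime to process the prefix but generate nothing, which lets
    # the server cache the prefix and return metadata such as "prompt_tokens".
    headers = {"Content-Type": "application/json"}
    if auth_token is not None:
        # Header key mirrors the client code above.
        headers["Authentication"] = f"Bearer {auth_token}"
    res = requests.post(
        base_url + "/generate",
        json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
        headers=headers,
    )
    assert res.status_code == 200
    return res.json()["meta_info"]

# Example usage (assumes a server at this address):
# info = cache_prefix("http://localhost:30000", "You are a helpful assistant.")
# print(info["prompt_tokens"])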
......@@ -116,6 +116,21 @@ register_chat_template(
)
register_chat_template(
ChatTemplate(
name="chatml-llava",
default_system_prompt="Answer the questions.",
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "\n<|im_end|>\n"),
"user": ("<|im_start|>user\n", "\n<|im_end|>\n"),
"assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
},
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",),
image_token=" <image>\n",
)
)
register_chat_template(
ChatTemplate(
name="vicuna_v1.1",
......@@ -168,7 +183,7 @@ register_chat_template(
def match_vicuna(model_path: str):
if "vicuna" in model_path.lower():
return get_chat_template("vicuna_v1.1")
if "llava" in model_path.lower():
if "llava-v1.5" in model_path.lower():
return get_chat_template("vicuna_v1.1")
......@@ -192,6 +207,8 @@ def match_chat_ml(model_path: str):
return get_chat_template("chatml")
if "qwen" in model_path and "chat" in model_path:
return get_chat_template("chatml")
if "llava-v1.6-34b" in model_path:
return get_chat_template("chatml-llava")
@register_chat_template_matching_function
......
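For reference, this is roughly the prompt that the new "chatml-llava" template produces for llava-v1.6-34b. The renderer below just replays the registered role prefixes and suffixes; `render_chatml_llava` is an illustrative helper, not sglang's own formatting code.

ROLE_PREFIX_AND_SUFFIX = {
    "system": ("<|im_start|>system\n", "\n<|im_end|>\n"),
    "user": ("<|im_start|>user\n", "\n<|im_end|>\n"),
    "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
}

def render_chatml_llava(messages, default_system_prompt="Answer the questions."):
    # Prepend the default system prompt, wrap each turn in its role markers,
    # then leave the assistant turn open so the model continues from there.
    if not messages or messages[0][0] != "system":
        messages = [("system", default_system_prompt)] + list(messages)
    pieces = []
    for role, content in messages:
        prefix, suffix = ROLE_PREFIX_AND_SUFFIX[role]
        pieces.append(prefix + content + suffix)
    pieces.append(ROLE_PREFIX_AND_SUFFIX["assistant"][0])
    return "".join(pieces)

# The image token registered above (" <image>\n") is substituted where the
# image appears in the user turn.
print(render_chatml_llava([("user", " <image>\nWhat is shown in the picture?")]))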
......@@ -74,9 +74,9 @@ class SglSamplingParams:
)
return {
"max_tokens_to_sample": self.max_new_tokens,
"stop_sequences": self.stop
if isinstance(self.stop, (list, tuple))
else [self.stop],
"stop_sequences": (
self.stop if isinstance(self.stop, (list, tuple)) else [self.stop]
),
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
......
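The `stop_sequences` change above is formatting only, but the rule it wraps is easy to miss: a single stop string is wrapped in a list so the request payload always carries a list of stop sequences. A tiny sketch of that normalization in isolation (the helper name is made up):

from typing import List, Tuple, Union

def normalize_stop(stop: Union[str, List[str], Tuple[str, ...]]) -> List[str]:
    # Lists and tuples pass through (as a list); a bare string becomes a
    # one-element list.
    return list(stop) if isinstance(stop, (list, tuple)) else [stop]

assert normalize_stop("###") == ["###"]
assert normalize_stop(("###", "\n\n")) == ["###", "\n\n"]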
"""Tracing a program."""
import uuid
from typing import Any, Callable, Dict, List, Optional, Union
......
"""
Backend configurations, may vary with different serving platforms.
"""
from dataclasses import dataclass
......
......@@ -366,7 +366,8 @@ def generate_chat_conv(
if content.type == "text":
real_content += content.text
elif content.type == "image_url":
real_content += "<image>"
# NOTE: Only works for llava
real_content += "<image>\n"
conv.append_image(content.image_url.url)
conv.append_message(conv.roles[0], real_content)
elif msg_role == "assistant":
......
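The hunk above changes the image placeholder from "<image>" to "<image>\n" when expanding OpenAI-style multimodal messages. A small hedged sketch of that expansion, where `flatten_content` is an illustrative helper rather than sglang's `generate_chat_conv`:

def flatten_content(parts):
    # Concatenate text parts and replace each image part with the llava-style
    # placeholder; collect the image URLs separately.
    text, image_urls = "", []
    for part in parts:
        if part["type"] == "text":
            text += part["text"]
        elif part["type"] == "image_url":
            # NOTE: the "<image>\n" placeholder only works for llava-style models.
            text += "<image>\n"
            image_urls.append(part["image_url"]["url"])
    return text, image_urls

text, urls = flatten_content([
    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    {"type": "text", "text": "Describe the image."},
])
assert text == "<image>\nDescribe the image."
assert urls == ["https://example.com/cat.png"]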
......@@ -31,6 +31,7 @@ from sglang.srt.utils import (
is_multimodal_model,
set_random_seed,
)
from vllm.logger import _default_handler as vllm_default_handler
logger = logging.getLogger("model_rpc")
......@@ -50,6 +51,9 @@ class ModelRpcServer(rpyc.Service):
self.tp_size = server_args.tp_size
self.schedule_heuristic = server_args.schedule_heuristic
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
vllm_default_handler.setLevel(
level=getattr(logging, server_args.log_level.upper())
)
# Init model and tokenizer
self.model_config = ModelConfig(
......@@ -83,9 +87,11 @@ class ModelRpcServer(rpyc.Service):
self.max_num_running_seq = self.max_total_num_token // 2
self.max_prefill_num_token = max(
self.model_config.context_len,
self.max_total_num_token // 6
if server_args.max_prefill_num_token is None
else server_args.max_prefill_num_token,
(
self.max_total_num_token // 6
if server_args.max_prefill_num_token is None
else server_args.max_prefill_num_token
),
)
self.int_token_logit_bias = torch.tensor(
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
......@@ -534,7 +540,7 @@ class ModelRpcServer(rpyc.Service):
output_skip_special_tokens.append(
req.sampling_params.skip_special_tokens
)
# For the length of input_ids, which will be accumulated during jump-forward.
# Use the original length of input_ids to calculate the token usage info.
meta_info = {
......
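The reformatted expression above encodes the prefill budget rule: use the user-supplied `max_prefill_num_token` if given, otherwise one sixth of the total token budget, and never go below the model's context length. A standalone sketch of that rule (the function name is made up):

def compute_max_prefill_num_token(context_len, max_total_num_token,
                                  max_prefill_num_token=None):
    return max(
        context_len,
        (
            max_total_num_token // 6
            if max_prefill_num_token is None
            else max_prefill_num_token
        ),
    )

assert compute_max_prefill_num_token(4096, 120_000) == 20_000
assert compute_max_prefill_num_token(4096, 12_000) == 4096        # floor at context_len
assert compute_max_prefill_num_token(4096, 120_000, 8192) == 8192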
......@@ -112,7 +112,9 @@ class InputMetadata:
(self.batch_size,), dtype=torch.int32, device="cuda"
)
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
workspace_buffer = torch.empty(
32 * 1024 * 1024, dtype=torch.int8, device="cuda"
)
if (
self.forward_mode == ForwardMode.PREFILL
or self.forward_mode == ForwardMode.EXTEND
......@@ -121,7 +123,9 @@ class InputMetadata:
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, "NHD"
)
self.prefill_wrapper.begin_forward(
self.qo_indptr,
self.kv_indptr,
......@@ -131,7 +135,9 @@ class InputMetadata:
self.model_runner.model_config.num_key_value_heads // tp_size,
)
else:
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer, "NHD"
)
self.decode_wrapper.begin_forward(
self.kv_indptr,
self.kv_indices,
......
"""Memory pool."""
import logging
import torch
......
"""Inference-only LLaVa model compatible with HuggingFace weights."""
from typing import List, Optional
import numpy as np
......@@ -269,7 +270,6 @@ class LlavaLlamaForCausalLM(nn.Module):
raise ValueError(f"Unexpected select feature: {self.select_feature}")
# load mm_projector
# TODO: support TP?
projector_weights = {
"model.mm_projector.0": "multi_modal_projector.linear_1",
"model.mm_projector.2": "multi_modal_projector.linear_2",
......
"""Inference-only Mistral model."""
from sglang.srt.models.llama2 import LlamaForCausalLM
......
......@@ -97,14 +97,16 @@ class MixtralMoE(nn.Module):
self.experts = nn.ModuleList(
[
MixtralMLP(
self.num_total_experts,
config.hidden_size,
config.intermediate_size,
linear_method=linear_method,
(
MixtralMLP(
self.num_total_experts,
config.hidden_size,
config.intermediate_size,
linear_method=linear_method,
)
if idx in self.expert_indicies
else None
)
if idx in self.expert_indicies
else None
for idx in range(self.num_total_experts)
]
)
......
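The Mixtral hunk above only re-indents an existing pattern: the `ModuleList` holds a real expert module for the indices assigned to this tensor-parallel rank and `None` for the rest. A minimal sketch of the pattern with a toy expert (names and sizes are made up):

import torch.nn as nn

class TinyExpert(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

num_total_experts = 8
expert_indicies = {0, 3, 5}  # experts owned by this rank (spelling as in the source)

experts = nn.ModuleList(
    [
        TinyExpert(hidden_size=16) if idx in expert_indicies else None
        for idx in range(num_total_experts)
    ]
)
assert sum(e is not None for e in experts) == len(expert_indicies)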
"""Inference-only Yi-VL model."""
import os
from typing import List, Optional
......
"""Sampling parameters for text generation."""
from typing import List, Optional, Union
_SAMPLING_EPS = 1e-6
......
"""SRT: SGLang Runtime"""
import asyncio
import json
import multiprocessing as mp
......@@ -493,7 +494,7 @@ def launch_server(server_args, pipe_finish_writer):
# Warmup
try:
print("Warmup...", flush=True)
# print("Warmup...", flush=True)
res = requests.post(
url + "/generate",
json={
......@@ -505,8 +506,8 @@ def launch_server(server_args, pipe_finish_writer):
},
timeout=60,
)
print(f"Warmup done. model response: {res.json()['text']}")
print("=" * 20, "Server is ready", "=" * 20, flush=True)
# print(f"Warmup done. model response: {res.json()['text']}")
# print("=" * 20, "Server is ready", "=" * 20, flush=True)
except requests.exceptions.RequestException as e:
if pipe_finish_writer is not None:
pipe_finish_writer.send(str(e))
......
......@@ -122,7 +122,7 @@ def handle_port_init(
# first check on server port
if not check_port(port):
new_port = alloc_usable_network_port(1, used_list=[port])[0]
print(f"Port {port} is not available, using {new_port} instead.")
print(f"WARNING: Port {port} is not available. Use {new_port} instead.")
port = new_port
# then we check on additional ports
......@@ -157,8 +157,6 @@ def get_int_token_logit_bias(tokenizer, vocab_size):
ss = tokenizer.decode([t_id]).strip()
if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
logit_bias[t_id] = -1e5
# else:
# print(ss, t_id)
return logit_bias
......
"""Common utilities for testing and benchmarking"""
import numpy as np
import requests
from sglang.backend.openai import OpenAI
......
......@@ -22,7 +22,7 @@ def get_available_gpu_memory(gpu_id, distributed=True):
if torch.cuda.current_device() != gpu_id:
print(
f"WARN: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
"which may cause useless memory allocation for torch CUDA context.",
)
......@@ -95,7 +95,7 @@ def http_request(url, json=None, stream=False, auth_token=None):
return requests.post(url, json=json, stream=True)
headers = {
"Content-Type": "application/json",
"Authentication": f"Bearer {auth_token}"
"Authentication": f"Bearer {auth_token}",
}
return requests.post(url, json=json, stream=True, headers=headers)
else:
......
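The trailing-comma change above is cosmetic, but `http_request` is the helper that the `RuntimeEndpoint` calls earlier in this diff go through. A hedged sketch of an equivalent authenticated POST (a standalone function, not sglang's own; only the non-streaming JSON path is shown):

import requests

def post_json(url, payload=None, auth_token=None, stream=False):
    if auth_token is None:
        return requests.post(url, json=payload, stream=stream)
    headers = {
        "Content-Type": "application/json",
        # Header key kept identical to the code above.
        "Authentication": f"Bearer {auth_token}",
    }
    return requests.post(url, json=payload, stream=stream, headers=headers)

# Example usage against a local SRT server (address and token are placeholders):
# res = post_json(
#     "http://localhost:30000/generate",
#     {"text": "Hello", "sampling_params": {"max_new_tokens": 8}},
#     auth_token="secret",
# )
# print(res.json()["text"])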
"""
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
"""
import json
import unittest
......