"vscode:/vscode.git/clone" did not exist on "e61b68e01c8b10e9bcd3235baa8d163a1e4f3379"
Unverified commit c51020cf, authored by Lianmin Zheng, committed by GitHub

Fix the chat template for llava-v1.6-34b & format code (#177)

parent 50afed4e
"""Public API""" """Public API"""
import re import re
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
......
@@ -19,7 +19,9 @@ class RuntimeEndpoint(BaseBackend):
        self.base_url = base_url
        self.auth_token = auth_token
-        res = http_request(self.base_url + "/get_model_info", auth_token=self.auth_token)
+        res = http_request(
+            self.base_url + "/get_model_info", auth_token=self.auth_token
+        )
        assert res.status_code == 200
        self.model_info = res.json()
@@ -37,7 +39,7 @@ class RuntimeEndpoint(BaseBackend):
        res = http_request(
            self.base_url + "/generate",
            json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
-            auth_token=self.auth_token
+            auth_token=self.auth_token,
        )
        assert res.status_code == 200
@@ -45,14 +47,16 @@ class RuntimeEndpoint(BaseBackend):
        res = http_request(
            self.base_url + "/generate",
            json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
-            auth_token=self.auth_token
+            auth_token=self.auth_token,
        )
        assert res.status_code == 200

    def fill_image(self, s: StreamExecutor):
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, data)
-        res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
+        res = http_request(
+            self.base_url + "/generate", json=data, auth_token=self.auth_token
+        )
        assert res.status_code == 200

    def generate(
@@ -82,7 +86,9 @@ class RuntimeEndpoint(BaseBackend):
        self._add_images(s, data)
-        res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
+        res = http_request(
+            self.base_url + "/generate", json=data, auth_token=self.auth_token
+        )
        obj = res.json()
        comp = obj["text"]
        return comp, obj["meta_info"]
@@ -115,7 +121,12 @@ class RuntimeEndpoint(BaseBackend):
        data["stream"] = True
        self._add_images(s, data)
-        response = http_request(self.base_url + "/generate", json=data, stream=True, auth_token=self.auth_token)
+        response = http_request(
+            self.base_url + "/generate",
+            json=data,
+            stream=True,
+            auth_token=self.auth_token,
+        )
        pos = 0
        incomplete_text = ""
@@ -145,7 +156,9 @@ class RuntimeEndpoint(BaseBackend):
        # Cache common prefix
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, data)
-        res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
+        res = http_request(
+            self.base_url + "/generate", json=data, auth_token=self.auth_token
+        )
        assert res.status_code == 200
        prompt_len = res.json()["meta_info"]["prompt_tokens"]
@@ -157,7 +170,9 @@ class RuntimeEndpoint(BaseBackend):
            "logprob_start_len": max(prompt_len - 2, 0),
        }
        self._add_images(s, data)
-        res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
+        res = http_request(
+            self.base_url + "/generate", json=data, auth_token=self.auth_token
+        )
        assert res.status_code == 200
        obj = res.json()
        normalized_prompt_logprob = [
@@ -172,7 +187,7 @@ class RuntimeEndpoint(BaseBackend):
        res = http_request(
            self.base_url + "/concate_and_append_request",
            json={"src_rids": src_rids, "dst_rid": dst_rid},
-            auth_token=self.auth_token
+            auth_token=self.auth_token,
        )
        assert res.status_code == 200
......
@@ -116,6 +116,21 @@ register_chat_template(
)

+register_chat_template(
+    ChatTemplate(
+        name="chatml-llava",
+        default_system_prompt="Answer the questions.",
+        role_prefix_and_suffix={
+            "system": ("<|im_start|>system\n", "\n<|im_end|>\n"),
+            "user": ("<|im_start|>user\n", "\n<|im_end|>\n"),
+            "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
+        },
+        style=ChatTemplateStyle.PLAIN,
+        stop_str=("<|im_end|>",),
+        image_token=" <image>\n",
+    )
+)

register_chat_template(
    ChatTemplate(
        name="vicuna_v1.1",
@@ -168,7 +183,7 @@ register_chat_template(
def match_vicuna(model_path: str):
    if "vicuna" in model_path.lower():
        return get_chat_template("vicuna_v1.1")
-    if "llava" in model_path.lower():
+    if "llava-v1.5" in model_path.lower():
        return get_chat_template("vicuna_v1.1")
@@ -192,6 +207,8 @@ def match_chat_ml(model_path: str):
        return get_chat_template("chatml")
    if "qwen" in model_path and "chat" in model_path:
        return get_chat_template("chatml")
+    if "llava-v1.6-34b" in model_path:
+        return get_chat_template("chatml-llava")

@register_chat_template_matching_function
......
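Note: for reference, a rough sketch of the prompt layout implied by the chatml-llava prefixes, suffixes, and image_token added above. The render_chatml helper and the example messages are illustrative only, not sglang's actual rendering code.

# Illustrative only: string layout implied by the chatml-llava template above.
ROLE_AFFIXES = {
    "system": ("<|im_start|>system\n", "\n<|im_end|>\n"),
    "user": ("<|im_start|>user\n", "\n<|im_end|>\n"),
    "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
}

def render_chatml(messages):
    # messages: list of (role, text) pairs
    parts = []
    for role, text in messages:
        prefix, suffix = ROLE_AFFIXES[role]
        parts.append(prefix + text + suffix)
    return "".join(parts)

prompt = render_chatml([
    ("system", "Answer the questions."),
    ("user", " <image>\nWhat is shown in this image?"),
])
# The rendered prompt reads:
# <|im_start|>system
# Answer the questions.
# <|im_end|>
# <|im_start|>user
#  <image>
# What is shown in this image?
# <|im_end|>
# The " <image>\n" placeholder matches image_token above and is replaced with
# image features by the llava runtime.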
@@ -74,9 +74,9 @@ class SglSamplingParams:
        )
        return {
            "max_tokens_to_sample": self.max_new_tokens,
-            "stop_sequences": self.stop
-            if isinstance(self.stop, (list, tuple))
-            else [self.stop],
+            "stop_sequences": (
+                self.stop if isinstance(self.stop, (list, tuple)) else [self.stop]
+            ),
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
......
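The reformatting above does not change behavior: a single stop string is still wrapped into a list. A tiny sketch of the same normalization, with a hypothetical helper name:

def normalize_stop(stop):
    # Mirrors the conditional above: "stop_sequences" is always list/tuple-valued.
    return stop if isinstance(stop, (list, tuple)) else [stop]

assert normalize_stop("</s>") == ["</s>"]
assert normalize_stop(["</s>", "\n\n"]) == ["</s>", "\n\n"]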
"""Tracing a program.""" """Tracing a program."""
import uuid import uuid
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
......
""" """
Backend configurations, may vary with different serving platforms. Backend configurations, may vary with different serving platforms.
""" """
from dataclasses import dataclass from dataclasses import dataclass
......
@@ -366,7 +366,8 @@ def generate_chat_conv(
                if content.type == "text":
                    real_content += content.text
                elif content.type == "image_url":
-                    real_content += "<image>"
+                    # NOTE: Only works for llava
+                    real_content += "<image>\n"
                    conv.append_image(content.image_url.url)
            conv.append_message(conv.roles[0], real_content)
        elif msg_role == "assistant":
......
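A simplified sketch of what this branch does for an OpenAI-style multimodal message (the message data below is hypothetical): each image_url part contributes an "<image>\n" placeholder to the text, and its URL is tracked separately for the llava runtime.

message_content = [
    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    {"type": "text", "text": "What animal is this?"},
]

real_content, image_urls = "", []
for part in message_content:
    if part["type"] == "text":
        real_content += part["text"]
    elif part["type"] == "image_url":
        real_content += "<image>\n"  # NOTE: only works for llava
        image_urls.append(part["image_url"]["url"])

# real_content == "<image>\nWhat animal is this?"
# image_urls == ["https://example.com/cat.png"]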
@@ -31,6 +31,7 @@ from sglang.srt.utils import (
    is_multimodal_model,
    set_random_seed,
)
+from vllm.logger import _default_handler as vllm_default_handler

logger = logging.getLogger("model_rpc")
@@ -50,6 +51,9 @@ class ModelRpcServer(rpyc.Service):
        self.tp_size = server_args.tp_size
        self.schedule_heuristic = server_args.schedule_heuristic
        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
+        vllm_default_handler.setLevel(
+            level=getattr(logging, server_args.log_level.upper())
+        )

        # Init model and tokenizer
        self.model_config = ModelConfig(
@@ -83,9 +87,11 @@ class ModelRpcServer(rpyc.Service):
        self.max_num_running_seq = self.max_total_num_token // 2
        self.max_prefill_num_token = max(
            self.model_config.context_len,
-            self.max_total_num_token // 6
-            if server_args.max_prefill_num_token is None
-            else server_args.max_prefill_num_token,
+            (
+                self.max_total_num_token // 6
+                if server_args.max_prefill_num_token is None
+                else server_args.max_prefill_num_token
+            ),
        )
        self.int_token_logit_bias = torch.tensor(
            get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
......
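Worked example of the max_prefill_num_token expression above, with assumed numbers: the configured value wins if it is set, otherwise one sixth of the token pool is used, floored at the model's context length.

context_len = 4096            # assumed model context length
max_total_num_token = 60000   # assumed KV-cache token pool size
configured = None             # assumed: --max-prefill-num-token not set

max_prefill_num_token = max(
    context_len,
    (max_total_num_token // 6 if configured is None else configured),
)
assert max_prefill_num_token == 10000  # 60000 // 6, which exceeds 4096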
@@ -112,7 +112,9 @@ class InputMetadata:
            (self.batch_size,), dtype=torch.int32, device="cuda"
        )
-        workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
+        workspace_buffer = torch.empty(
+            32 * 1024 * 1024, dtype=torch.int8, device="cuda"
+        )
        if (
            self.forward_mode == ForwardMode.PREFILL
            or self.forward_mode == ForwardMode.EXTEND
@@ -121,7 +123,9 @@ class InputMetadata:
                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
            )
            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
+            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
+                workspace_buffer, "NHD"
+            )
            self.prefill_wrapper.begin_forward(
                self.qo_indptr,
                self.kv_indptr,
@@ -131,7 +135,9 @@ class InputMetadata:
                self.model_runner.model_config.num_key_value_heads // tp_size,
            )
        else:
-            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
+            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                workspace_buffer, "NHD"
+            )
            self.decode_wrapper.begin_forward(
                self.kv_indptr,
                self.kv_indices,
......
"""Memory pool.""" """Memory pool."""
import logging import logging
import torch import torch
......
"""Inference-only LLaVa model compatible with HuggingFace weights.""" """Inference-only LLaVa model compatible with HuggingFace weights."""
from typing import List, Optional from typing import List, Optional
import numpy as np import numpy as np
...@@ -269,7 +270,6 @@ class LlavaLlamaForCausalLM(nn.Module): ...@@ -269,7 +270,6 @@ class LlavaLlamaForCausalLM(nn.Module):
raise ValueError(f"Unexpected select feature: {self.select_feature}") raise ValueError(f"Unexpected select feature: {self.select_feature}")
# load mm_projector # load mm_projector
# TODO: support TP?
projector_weights = { projector_weights = {
"model.mm_projector.0": "multi_modal_projector.linear_1", "model.mm_projector.0": "multi_modal_projector.linear_1",
"model.mm_projector.2": "multi_modal_projector.linear_2", "model.mm_projector.2": "multi_modal_projector.linear_2",
......
"""Inference-only Mistral model.""" """Inference-only Mistral model."""
from sglang.srt.models.llama2 import LlamaForCausalLM from sglang.srt.models.llama2 import LlamaForCausalLM
......
@@ -97,6 +97,7 @@ class MixtralMoE(nn.Module):
        self.experts = nn.ModuleList(
            [
+                (
                MixtralMLP(
                    self.num_total_experts,
                    config.hidden_size,
@@ -105,6 +106,7 @@ class MixtralMoE(nn.Module):
                )
                if idx in self.expert_indicies
                else None
+                )
                for idx in range(self.num_total_experts)
            ]
        )
......
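A toy sketch of the comprehension above (assumed values, strings standing in for nn.Module instances): experts not owned by this tensor-parallel rank are kept as None placeholders in the list.

num_total_experts = 8
expert_indicies = [0, 1, 2, 3]  # assumed: this rank owns the first four experts

experts = [
    (f"MixtralMLP#{idx}" if idx in expert_indicies else None)
    for idx in range(num_total_experts)
]
# ['MixtralMLP#0', 'MixtralMLP#1', 'MixtralMLP#2', 'MixtralMLP#3',
#  None, None, None, None]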
"""Inference-only Yi-VL model.""" """Inference-only Yi-VL model."""
import os import os
from typing import List, Optional from typing import List, Optional
......
"""Sampling parameters for text generation.""" """Sampling parameters for text generation."""
from typing import List, Optional, Union from typing import List, Optional, Union
_SAMPLING_EPS = 1e-6 _SAMPLING_EPS = 1e-6
......
"""SRT: SGLang Runtime""" """SRT: SGLang Runtime"""
import asyncio import asyncio
import json import json
import multiprocessing as mp import multiprocessing as mp
...@@ -493,7 +494,7 @@ def launch_server(server_args, pipe_finish_writer): ...@@ -493,7 +494,7 @@ def launch_server(server_args, pipe_finish_writer):
# Warmup # Warmup
try: try:
print("Warmup...", flush=True) # print("Warmup...", flush=True)
res = requests.post( res = requests.post(
url + "/generate", url + "/generate",
json={ json={
...@@ -505,8 +506,8 @@ def launch_server(server_args, pipe_finish_writer): ...@@ -505,8 +506,8 @@ def launch_server(server_args, pipe_finish_writer):
}, },
timeout=60, timeout=60,
) )
print(f"Warmup done. model response: {res.json()['text']}") # print(f"Warmup done. model response: {res.json()['text']}")
print("=" * 20, "Server is ready", "=" * 20, flush=True) # print("=" * 20, "Server is ready", "=" * 20, flush=True)
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
if pipe_finish_writer is not None: if pipe_finish_writer is not None:
pipe_finish_writer.send(str(e)) pipe_finish_writer.send(str(e))
......
@@ -122,7 +122,7 @@ def handle_port_init(
    # first check on server port
    if not check_port(port):
        new_port = alloc_usable_network_port(1, used_list=[port])[0]
-        print(f"Port {port} is not available, using {new_port} instead.")
+        print(f"WARNING: Port {port} is not available. Use {new_port} instead.")
        port = new_port

    # then we check on additional ports
@@ -157,8 +157,6 @@ def get_int_token_logit_bias(tokenizer, vocab_size):
        ss = tokenizer.decode([t_id]).strip()
        if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
            logit_bias[t_id] = -1e5
-        # else:
-        #     print(ss, t_id)

    return logit_bias
......
"""Common utilities for testing and benchmarking""" """Common utilities for testing and benchmarking"""
import numpy as np import numpy as np
import requests import requests
from sglang.backend.openai import OpenAI from sglang.backend.openai import OpenAI
......
@@ -22,7 +22,7 @@ def get_available_gpu_memory(gpu_id, distributed=True):
    if torch.cuda.current_device() != gpu_id:
        print(
-            f"WARN: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
+            f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
            "which may cause useless memory allocation for torch CUDA context.",
        )
@@ -95,7 +95,7 @@ def http_request(url, json=None, stream=False, auth_token=None):
            return requests.post(url, json=json, stream=True)
        headers = {
            "Content-Type": "application/json",
-            "Authentication": f"Bearer {auth_token}"
+            "Authentication": f"Bearer {auth_token}",
        }
        return requests.post(url, json=json, stream=True, headers=headers)
    else:
......
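For context, a minimal usage sketch of the authenticated POST that http_request builds above; the endpoint URL and token are placeholders. Note that sglang sends the token in an "Authentication: Bearer ..." header, as shown in the diff.

import requests

auth_token = "my-token"  # placeholder
headers = {
    "Content-Type": "application/json",
    "Authentication": f"Bearer {auth_token}",
}
res = requests.post(
    "http://localhost:30000/generate",  # assumed local sglang server
    json={"text": "Hello", "sampling_params": {"max_new_tokens": 8}},
    headers=headers,
)
print(res.json()["text"])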
""" """
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
""" """
import json import json
import unittest import unittest
......