Unverified commit 23f05005, authored by Lianmin Zheng, committed by GitHub

Format code & move functions (#155)

parent a7334aee
@@ -193,6 +193,7 @@ def match_chat_ml(model_path: str):
     if "qwen" in model_path and "chat" in model_path:
         return get_chat_template("chatml")
 
+
 @register_chat_template_matching_function
 def match_chat_yi(model_path: str):
     model_path = model_path.lower()
...
@@ -64,13 +64,19 @@ class LogitsProcessor(nn.Module):
                 torch.arange(all_logprobs.shape[0], device="cuda"),
                 torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
             ]
-            logprobs_cumsum = torch.cumsum(prefill_logprobs, dim=0, dtype=torch.float32)
+            logprobs_cumsum = torch.cumsum(
+                prefill_logprobs, dim=0, dtype=torch.float32
+            )
 
             start = input_metadata.extend_start_loc.clone()
             end = start + input_metadata.extend_seq_lens - 2
             start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
             end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
-            sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + prefill_logprobs[start]
+            sum_logp = (
+                logprobs_cumsum[end]
+                - logprobs_cumsum[start]
+                + prefill_logprobs[start]
+            )
             normalized_logprobs = sum_logp / (
                 (input_metadata.extend_seq_lens - 1).clamp(min=1)
             )
...
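The reformatted block above is the core of the normalized prompt-logprob computation: token logprobs for all requests are packed into one flat tensor, and a cumulative sum turns the per-request summation into two lookups. A minimal standalone sketch of the same trick, with made-up tensors standing in for `input_metadata` (not the actual module):

import torch

# Token logprobs for two packed requests of lengths 3 and 4 (hypothetical values).
prefill_logprobs = torch.tensor([-0.5, -1.0, -0.2, -0.3, -0.7, -0.1, -0.9])
extend_start_loc = torch.tensor([0, 3])  # where each request starts in the flat buffer
extend_seq_lens = torch.tensor([3, 4])

# Prefix sums let "sum of logprobs over [start, end]" be computed with two lookups.
logprobs_cumsum = torch.cumsum(prefill_logprobs, dim=0, dtype=torch.float32)
start = extend_start_loc.clone()
end = start + extend_seq_lens - 2
start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)

# cumsum[end] - cumsum[start] + logprobs[start] == logprobs[start] + ... + logprobs[end]
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + prefill_logprobs[start]
normalized_logprobs = sum_logp / (extend_seq_lens - 1).clamp(min=1)
print(normalized_logprobs)  # tensor([-0.7500, -0.3667])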
@@ -13,14 +13,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 
 
 class RadixAttention(nn.Module):
-    def __init__(
-        self,
-        num_heads,
-        head_dim,
-        scaling,
-        num_kv_heads,
-        layer_id
-    ):
+    def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
...
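Most hunks in this commit are mechanical black-style reformatting like the signature collapse above: a call or signature that fits within black's 88-character line limit is joined onto one line, while a trailing ("magic") comma after the last argument keeps the arguments one per line, which is what the trailing commas added in the model files below do. An illustrative toy example (names and values are hypothetical):

# No trailing comma and the line fits in 88 characters, so black keeps it on one line.
def make_attention(num_heads, head_dim, scaling, num_kv_heads, layer_id):
    return (num_heads, head_dim, scaling, num_kv_heads, layer_id)


# The "magic" trailing comma after layer_id=0 tells black to keep one argument per line,
# even though the whole call would fit on a single line.
attention_args = dict(
    num_heads=32,
    head_dim=128,
    scaling=0.088,
    num_kv_heads=32,
    layer_id=0,
)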
@@ -100,6 +100,7 @@ class BatchStrOut:
 class FlushCacheReq:
     pass
 
+
 @dataclass
 class DetokenizeReqInput:
     input_ids: List[int]
@@ -11,8 +11,8 @@ import rpyc
 import torch
 from rpyc.utils.classic import obtain
 from rpyc.utils.server import ThreadedServer
-from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.constrained.fsm_cache import FSMCache
+from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
@@ -391,8 +391,12 @@ class ModelRpcServer(rpyc.Service):
         logprobs = None
         if batch.extend_num_tokens != 0:
             # Forward
-            logits, (prefill_logprobs, normalized_logprobs, last_logprobs) = (
-                self.model_runner.forward(batch, ForwardMode.EXTEND, batch.return_logprob)
+            logits, (
+                prefill_logprobs,
+                normalized_logprobs,
+                last_logprobs,
+            ) = self.model_runner.forward(
+                batch, ForwardMode.EXTEND, batch.return_logprob
             )
             if prefill_logprobs is not None:
                 logprobs = prefill_logprobs.cpu().tolist()
@@ -407,7 +411,9 @@ class ModelRpcServer(rpyc.Service):
         # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
         reqs = batch.reqs
         if last_logprobs is not None:
-            last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
+            last_logprobs = (
+                last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
+            )
 
         # Check finish condition
         pt = 0
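The wrapped expression above relies on PyTorch advanced indexing: pairing `torch.arange(len(reqs))` with `next_token_ids` picks, for each request, the logprob of the token that was actually sampled. A small self-contained sketch of that indexing pattern (toy shapes and values):

import torch

# One row per request, one column per vocabulary id (toy vocab of 5).
last_logprobs = torch.log_softmax(torch.randn(3, 5), dim=-1)
next_token_ids = torch.tensor([2, 0, 4])  # sampled token id for each request

# Row i, column next_token_ids[i]: the logprob of the chosen token for request i.
chosen = last_logprobs[torch.arange(len(next_token_ids)), next_token_ids]
print(chosen.tolist())  # three floats, one per request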
@@ -482,7 +488,9 @@ class ModelRpcServer(rpyc.Service):
         # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
         reqs = batch.reqs
         if last_logprobs is not None:
-            last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].tolist()
+            last_logprobs = last_logprobs[
+                torch.arange(len(reqs)), next_token_ids
+            ].tolist()
 
         # Check finish condition
         for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
@@ -620,15 +628,16 @@ class ModelRpcClient:
         self.step = async_wrap("step")
 
 
+def _init_service(port):
+    t = ThreadedServer(
+        ModelRpcServer(),
+        port=port,
+        protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
+    )
+    t.start()
+
+
 def start_model_process(port):
-    def _init_service(port):
-        t = ThreadedServer(
-            ModelRpcServer(),
-            port=port,
-            protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
-        )
-        t.start()
-
     proc = multiprocessing.Process(target=_init_service, args=(port,))
     proc.start()
     time.sleep(1)
...
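Moving `_init_service` from a nested closure to module level is presumably the "move functions" part of the commit title; besides reading better, a module-level function remains a valid `multiprocessing.Process` target under the `spawn` start method, where a closure cannot be pickled. A minimal sketch of the same pattern, with a trivial placeholder standing in for the RPyC `ThreadedServer`:

import multiprocessing
import time


def _init_service(port):
    # Placeholder for the real RPyC ThreadedServer; just pretends to serve for a moment.
    print(f"service listening on port {port}")
    time.sleep(1)


def start_model_process(port):
    # A module-level target stays picklable even under the "spawn" start method.
    proc = multiprocessing.Process(target=_init_service, args=(port,))
    proc.start()
    return proc


if __name__ == "__main__":
    p = start_model_process(18080)
    p.join()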
@@ -17,8 +17,8 @@ from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
 
 import sglang
 
-QUANTIONCONFIG_MAPPING = {'awq': AWQConfig,
-                          'gptq': GPTQConfig}
+
+QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig}
 
 logger = logging.getLogger("model_runner")
@@ -283,9 +283,13 @@ class ModelRunner:
             self.model_config.hf_config, "quantization_config", None
         )
         if hf_quant_config is not None:
-            quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_config['quant_method'])
+            quant_config_class = QUANTIONCONFIG_MAPPING.get(
+                hf_quant_config["quant_method"]
+            )
             if quant_config_class is None:
-                raise ValueError(f"Unsupported quantization method: {hf_quant_config['quant_method']}")
+                raise ValueError(
+                    f"Unsupported quantization method: {hf_quant_config['quant_method']}"
+                )
             quant_config = quant_config_class.from_config(hf_quant_config)
             logger.info(f"quant_config: {quant_config}")
             linear_method = quant_config.get_linear_method()
...
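The reformatted lookup above dispatches on the `quant_method` field of the Hugging Face `quantization_config` dict and raises for unknown methods. A hedged sketch of the same dict-dispatch pattern, with a stand-in class instead of vLLM's real `AWQConfig`/`GPTQConfig`:

# Stand-in for vLLM's AWQConfig/GPTQConfig; only the from_config classmethod matters here.
class FakeAWQConfig:
    @classmethod
    def from_config(cls, cfg):
        return cls()


QUANT_CONFIG_MAPPING = {"awq": FakeAWQConfig}

# A typical HF quantization_config entry from config.json (values illustrative).
hf_quant_config = {"quant_method": "awq", "bits": 4, "group_size": 128}

quant_config_class = QUANT_CONFIG_MAPPING.get(hf_quant_config["quant_method"])
if quant_config_class is None:
    raise ValueError(f"Unsupported quantization method: {hf_quant_config['quant_method']}")
quant_config = quant_config_class.from_config(hf_quant_config)
print(type(quant_config).__name__)  # FakeAWQConfig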
@@ -42,14 +42,14 @@ class QWenMLP(nn.Module):
             2 * [intermediate_size],
             bias=False,
             gather_output=False,
-            linear_method=linear_method
+            linear_method=linear_method,
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             input_is_parallel=True,
-            linear_method=linear_method
+            linear_method=linear_method,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -74,7 +74,7 @@ class QWenAttention(nn.Module):
         layer_id: int = 0,
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
-        linear_method: Optional[LinearMethodBase] = None
+        linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -86,18 +86,18 @@ class QWenAttention(nn.Module):
 
         # pylint: disable=invalid-name
         self.c_attn = QKVParallelLinear(
             hidden_size,
             self.head_dim,
             self.total_num_heads,
             bias=True,
-            linear_method=linear_method
+            linear_method=linear_method,
         )
         self.c_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             input_is_parallel=True,
-            linear_method=linear_method
+            linear_method=linear_method,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -143,12 +143,16 @@ class QWenBlock(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             layer_id=layer_id,
-            linear_method=linear_method
+            linear_method=linear_method,
         )
 
         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
-        self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2, linear_method=linear_method)
+        self.mlp = QWenMLP(
+            config.hidden_size,
+            config.intermediate_size // 2,
+            linear_method=linear_method,
+        )
 
     def forward(
         self,
@@ -186,7 +190,10 @@ class QWenModel(nn.Module):
             config.hidden_size,
         )
         self.h = nn.ModuleList(
-            [QWenBlock(config, i, linear_method=linear_method) for i in range(config.num_hidden_layers)]
+            [
+                QWenBlock(config, i, linear_method=linear_method)
+                for i in range(config.num_hidden_layers)
+            ]
         )
 
         self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
...
@@ -4,14 +4,17 @@ from typing import List, Optional
 import torch
 import torch.nn as nn
 
+from sglang.srt.models.llava import (
+    LlavaLlamaForCausalLM,
+    clip_vision_embed_forward,
+    monkey_path_clip_vision_embed_forward,
+)
 from transformers import CLIPVisionModel, LlavaConfig
 from vllm.model_executor.weight_utils import (
     default_weight_loader,
     hf_model_weights_iterator,
 )
 
-from sglang.srt.models.llava import LlavaLlamaForCausalLM, clip_vision_embed_forward, monkey_path_clip_vision_embed_forward
-
 
 class YiVLForCausalLM(LlavaLlamaForCausalLM):
     def __init__(self, *args, **kwargs):
@@ -19,7 +22,9 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
 
         super().__init__(self.config)
         self.multi_modal_projector = YiVLMultiModalProjector(self.config)
-        self.vision_tower_subfolder = self.config.mm_vision_tower.replace("./", "")  # Everything after "./"
+        self.vision_tower_subfolder = self.config.mm_vision_tower.replace(
+            "./", ""
+        )  # Everything after "./"
 
     def load_weights(
         self,
@@ -30,7 +35,9 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
     ):
         # We have to use the subfolder of the main model directory (e.g. 01-ai/Yi-VL-6B)
         self.vision_tower = CLIPVisionModel.from_pretrained(
-            model_name_or_path, torch_dtype=torch.float16, subfolder=self.vision_tower_subfolder
+            model_name_or_path,
+            torch_dtype=torch.float16,
+            subfolder=self.vision_tower_subfolder,
         ).cuda()
 
         self.vision_tower.eval()
@@ -80,14 +87,19 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
 
         monkey_path_clip_vision_embed_forward()
 
+
 class YiVLMultiModalProjector(nn.Module):
     def __init__(self, config: LlavaConfig):
         super().__init__()
-        self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size)
+        self.linear_1 = nn.Linear(
+            config.vision_config.hidden_size, config.text_config.hidden_size
+        )
         self.ln_1 = nn.LayerNorm(config.text_config.hidden_size)
         self.act = nn.GELU()
-        self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size)
+        self.linear_2 = nn.Linear(
+            config.text_config.hidden_size, config.text_config.hidden_size
+        )
         self.ln_2 = nn.LayerNorm(config.text_config.hidden_size)
 
     def forward(self, image_features):
@@ -98,4 +110,5 @@ class YiVLMultiModalProjector(nn.Module):
         hidden_states = self.ln_2(hidden_states)
         return hidden_states
 
-EntryClass = YiVLForCausalLM
\ No newline at end of file
+
+EntryClass = YiVLForCausalLM
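For context, the `subfolder` argument used in the reformatted `from_pretrained` call above is a standard `transformers` parameter: the Yi-VL checkpoints keep the CLIP vision tower in a subdirectory of the main model repo, and `mm_vision_tower` in the config points at it with a leading `./`. A hedged sketch of the loading step (the subfolder value below is illustrative, not the real config entry):

import torch
from transformers import CLIPVisionModel

mm_vision_tower = "./vision_tower"  # illustrative; the real value comes from the HF config
subfolder = mm_vision_tower.replace("./", "")  # everything after "./"

vision_tower = CLIPVisionModel.from_pretrained(
    "01-ai/Yi-VL-6B",  # main model repo, as in the comment above
    torch_dtype=torch.float16,
    subfolder=subfolder,
)
vision_tower.eval()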
@@ -63,6 +63,7 @@ chat_template_name = None
 # FIXME: Remove this once we drop support for pydantic 1.x
 IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
 
+
 def jsonify_pydantic_model(obj: BaseModel):
     if IS_PYDANTIC_1:
         return obj.json(ensure_ascii=False)
@@ -165,7 +166,7 @@ async def v1_completions(raw_request: Request):
                 prompt_tokens = content["meta_info"]["prompt_tokens"]
                 completion_tokens = content["meta_info"]["completion_tokens"]
 
-                if not stream_buffer: # The first chunk
+                if not stream_buffer:  # The first chunk
                     if request.echo:
                         # Prepend prompt in response text.
                         text = request.prompt + text
@@ -219,7 +220,9 @@ async def v1_completions(raw_request: Request):
         token_logprob_pos = prompt_tokens
 
     logprobs = (
-        await make_openai_style_logprobs(ret["meta_info"]["token_logprob"][token_logprob_pos:])
+        await make_openai_style_logprobs(
+            ret["meta_info"]["token_logprob"][token_logprob_pos:]
+        )
         if request.logprobs is not None
         else None
     )
...
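The wrapped `make_openai_style_logprobs` call above slices the request's per-token logprob list at `token_logprob_pos`, which the surrounding code sets to `prompt_tokens` (and presumably to 0 when the request echoes the prompt, though that branch is outside this hunk). A tiny sketch of that slicing with a made-up logprob list; the middle tuple field is assumed to be the token id:

# (token_text, token_id, logprob) per token; the first prompt_tokens entries are the prompt.
token_logprob = [("Hello", 1, -0.1), (" world", 2, -0.4), ("!", 3, -0.2), (" Hi", 4, -0.9)]
prompt_tokens = 3

echo = False
token_logprob_pos = 0 if echo else prompt_tokens
print(token_logprob[token_logprob_pos:])  # only completion-token logprobs when echo is off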
@@ -114,7 +114,7 @@ class ServerArgs:
             "--max-prefill-num-token",
             type=int,
             default=ServerArgs.max_prefill_num_token,
-            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length."
+            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
         )
         parser.add_argument(
             "--tp-size",
...
@@ -259,4 +259,4 @@ def load_image(image_file):
     else:
         image = Image.open(BytesIO(base64.b64decode(image_file)))
 
-    return image
\ No newline at end of file
+    return image
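The hunk above only adds the missing trailing newline, but for context `load_image` accepts either a file path/URL or a raw base64 string, and the branch shown decodes the base64 payload into a PIL image. A self-contained sketch of that branch:

import base64
from io import BytesIO

from PIL import Image

# Build a tiny in-memory PNG and base64-encode it, standing in for a request payload.
buf = BytesIO()
Image.new("RGB", (4, 4), color=(255, 0, 0)).save(buf, format="PNG")
image_file = base64.b64encode(buf.getvalue()).decode()

# The branch from load_image: decode the base64 string back into a PIL image.
image = Image.open(BytesIO(base64.b64decode(image_file)))
print(image.size)  # (4, 4)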
@@ -12,6 +12,7 @@ import argparse
 
 import requests
 
+
 def test_decode(url, return_logprob):
     response = requests.post(
         url + "/generate",
@@ -27,6 +28,7 @@ def test_decode(url, return_logprob):
     )
     print(response.json())
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
...
@@ -12,6 +12,7 @@ import json
 
 import requests
 
+
 def test_decode_stream(url, return_logprob):
     response = requests.post(
         url + "/generate",
@@ -39,7 +40,7 @@ def test_decode_stream(url, return_logprob):
             assert data["meta_info"]["prompt_logprob"] is not None
             assert data["meta_info"]["token_logprob"] is not None
             assert data["meta_info"]["normalized_prompt_logprob"] is not None
-            if prev == 0: # Skip prompt logprobs
+            if prev == 0:  # Skip prompt logprobs
                 prev = data["meta_info"]["prompt_tokens"]
             for token_txt, _, logprob in data["meta_info"]["token_logprob"][prev:]:
                 print(f"{token_txt}\t{logprob}", flush=True)
@@ -50,6 +51,7 @@ def test_decode_stream(url, return_logprob):
             prev = len(output)
     print("")
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
...
@@ -64,9 +64,8 @@ def test_completion_stream(args, echo, logprobs):
                 first = False
             if logprobs:
                 print(
-                    f"{r.choices[0].text:12s}\t"
-                    f"{r.choices[0].logprobs.token_logprobs}",
-                    flush=True
+                    f"{r.choices[0].text:12s}\t" f"{r.choices[0].logprobs.token_logprobs}",
+                    flush=True,
                 )
             else:
                 print(r.choices[0].text, end="", flush=True)
...