Unverified Commit 23f05005 authored by Lianmin Zheng, committed by GitHub

Format code & move functions (#155)

parent a7334aee
......@@ -193,6 +193,7 @@ def match_chat_ml(model_path: str):
if "qwen" in model_path and "chat" in model_path:
return get_chat_template("chatml")
@register_chat_template_matching_function
def match_chat_yi(model_path: str):
model_path = model_path.lower()
......
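A minimal, self-contained sketch of the matching-function registry pattern shown above. The registry list, the helper that walks it, and the returned template name are illustrative stand-ins, not sglang's actual internals.

chat_template_matching_functions = []

def register_chat_template_matching_function(func):
    # Collect every decorated matcher so they can be tried in order later.
    chat_template_matching_functions.append(func)
    return func

@register_chat_template_matching_function
def match_chat_yi(model_path: str):
    # Return a template name when the path looks like a Yi chat model.
    model_path = model_path.lower()
    if "yi" in model_path and "chat" in model_path:
        return "yi"

def match_chat_template(model_path: str):
    # First matcher that returns a non-None value wins.
    for matcher in chat_template_matching_functions:
        template = matcher(model_path)
        if template is not None:
            return template
    return None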
......@@ -64,13 +64,19 @@ class LogitsProcessor(nn.Module):
torch.arange(all_logprobs.shape[0], device="cuda"),
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
]
logprobs_cumsum = torch.cumsum(prefill_logprobs, dim=0, dtype=torch.float32)
logprobs_cumsum = torch.cumsum(
prefill_logprobs, dim=0, dtype=torch.float32
)
start = input_metadata.extend_start_loc.clone()
end = start + input_metadata.extend_seq_lens - 2
start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + prefill_logprobs[start]
sum_logp = (
logprobs_cumsum[end]
- logprobs_cumsum[start]
+ prefill_logprobs[start]
)
normalized_logprobs = sum_logp / (
(input_metadata.extend_seq_lens - 1).clamp(min=1)
)
......
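A minimal sketch of the prefix-sum trick used in the hunk above: each request's token logprobs are summed out of one packed tensor without a Python loop, then normalized by the prompt length minus one. Tensor names and values are illustrative (and on CPU here), not the actual input_metadata fields.

import torch

prefill_logprobs = torch.tensor([-0.5, -1.0, -0.2, -0.3, -0.7, -0.1])
extend_start_loc = torch.tensor([0, 3])   # where each request starts in the packed tensor
extend_seq_lens = torch.tensor([3, 3])    # tokens per request

logprobs_cumsum = torch.cumsum(prefill_logprobs, dim=0, dtype=torch.float32)
start = extend_start_loc.clone()
end = start + extend_seq_lens - 2
start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)

# cumsum[end] - cumsum[start] + logprobs[start] == sum(logprobs[start : end + 1])
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + prefill_logprobs[start]
normalized_logprobs = sum_logp / (extend_seq_lens - 1).clamp(min=1)
print(normalized_logprobs)  # tensor([-0.7500, -0.5000])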
......@@ -13,14 +13,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
class RadixAttention(nn.Module):
def __init__(
self,
num_heads,
head_dim,
scaling,
num_kv_heads,
layer_id
):
def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id):
super().__init__()
self.tp_q_head_num = num_heads
self.tp_k_head_num = num_kv_heads
......
......@@ -100,6 +100,7 @@ class BatchStrOut:
class FlushCacheReq:
pass
@dataclass
class DetokenizeReqInput:
input_ids: List[int]
......@@ -11,8 +11,8 @@ import rpyc
import torch
from rpyc.utils.classic import obtain
from rpyc.utils.server import ThreadedServer
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
BatchTokenIDOut,
......@@ -391,8 +391,12 @@ class ModelRpcServer(rpyc.Service):
logprobs = None
if batch.extend_num_tokens != 0:
# Forward
logits, (prefill_logprobs, normalized_logprobs, last_logprobs) = (
self.model_runner.forward(batch, ForwardMode.EXTEND, batch.return_logprob)
logits, (
prefill_logprobs,
normalized_logprobs,
last_logprobs,
) = self.model_runner.forward(
batch, ForwardMode.EXTEND, batch.return_logprob
)
if prefill_logprobs is not None:
logprobs = prefill_logprobs.cpu().tolist()
......@@ -407,7 +411,9 @@ class ModelRpcServer(rpyc.Service):
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
reqs = batch.reqs
if last_logprobs is not None:
last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
last_logprobs = (
last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
)
# Check finish condition
pt = 0
......@@ -482,7 +488,9 @@ class ModelRpcServer(rpyc.Service):
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
reqs = batch.reqs
if last_logprobs is not None:
last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].tolist()
last_logprobs = last_logprobs[
torch.arange(len(reqs)), next_token_ids
].tolist()
# Check finish condition
for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
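A minimal sketch of the batched transfer in the two hunks above: advanced indexing selects one logprob per request on the device, so a single .tolist() (or .cpu().tolist()) moves only those scalars to the host instead of the whole vocab-sized tensor. Shapes and values are illustrative.

import torch

num_reqs, vocab_size = 4, 32000
last_logprobs = torch.randn(num_reqs, vocab_size)   # lives on the GPU in the real code
next_token_ids = torch.tensor([11, 7, 29999, 0])

# One gather, one transfer: a Python list with one float per request.
selected = last_logprobs[torch.arange(num_reqs), next_token_ids].tolist()
print(len(selected))  # 4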
......@@ -620,15 +628,16 @@ class ModelRpcClient:
self.step = async_wrap("step")
def start_model_process(port):
def _init_service(port):
t = ThreadedServer(
ModelRpcServer(),
port=port,
protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
)
t.start()
def _init_service(port):
t = ThreadedServer(
ModelRpcServer(),
port=port,
protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
)
t.start()
def start_model_process(port):
proc = multiprocessing.Process(target=_init_service, args=(port,))
proc.start()
time.sleep(1)
......
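A minimal sketch of why _init_service is moved out of start_model_process above: under the "spawn" start method, multiprocessing pickles the target function, and nested (closure) functions cannot be pickled, so the target must live at module level. The port value and sleep are illustrative, and the rpyc server body is omitted.

import multiprocessing
import time

def _init_service(port):
    # Module-level functions can be pickled and handed to the child process.
    print(f"serving on port {port}")
    time.sleep(0.1)

def start_model_process(port):
    proc = multiprocessing.Process(target=_init_service, args=(port,))
    proc.start()
    return proc

if __name__ == "__main__":
    p = start_model_process(5000)
    p.join()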
......@@ -17,8 +17,8 @@ from vllm.model_executor.model_loader import _set_default_torch_dtype
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
import sglang
QUANTIONCONFIG_MAPPING = {'awq': AWQConfig,
'gptq': GPTQConfig}
QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig}
logger = logging.getLogger("model_runner")
......@@ -283,9 +283,13 @@ class ModelRunner:
self.model_config.hf_config, "quantization_config", None
)
if hf_quant_config is not None:
quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_config['quant_method'])
quant_config_class = QUANTIONCONFIG_MAPPING.get(
hf_quant_config["quant_method"]
)
if quant_config_class is None:
raise ValueError(f"Unsupported quantization method: {hf_quant_config['quant_method']}")
raise ValueError(
f"Unsupported quantization method: {hf_quant_config['quant_method']}"
)
quant_config = quant_config_class.from_config(hf_quant_config)
logger.info(f"quant_config: {quant_config}")
linear_method = quant_config.get_linear_method()
......
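A minimal, self-contained sketch of the dict-dispatch pattern above: map the quant_method string from the HF quantization config to a config class and fail loudly on anything unsupported. The class here is a stand-in, not vLLM's AWQConfig or GPTQConfig.

class FakeAWQConfig:
    @classmethod
    def from_config(cls, hf_quant_config):
        return cls()

QUANT_CONFIG_MAPPING = {"awq": FakeAWQConfig}

def build_quant_config(hf_quant_config):
    quant_config_class = QUANT_CONFIG_MAPPING.get(hf_quant_config["quant_method"])
    if quant_config_class is None:
        raise ValueError(
            f"Unsupported quantization method: {hf_quant_config['quant_method']}"
        )
    return quant_config_class.from_config(hf_quant_config)

print(build_quant_config({"quant_method": "awq"}))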
......@@ -42,14 +42,14 @@ class QWenMLP(nn.Module):
2 * [intermediate_size],
bias=False,
gather_output=False,
linear_method=linear_method
linear_method=linear_method,
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
linear_method=linear_method
linear_method=linear_method,
)
if hidden_act != "silu":
raise ValueError(
......@@ -74,7 +74,7 @@ class QWenAttention(nn.Module):
layer_id: int = 0,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
linear_method: Optional[LinearMethodBase] = None
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.hidden_size = hidden_size
......@@ -86,18 +86,18 @@ class QWenAttention(nn.Module):
# pylint: disable=invalid-name
self.c_attn = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
bias=True,
linear_method=linear_method
hidden_size,
self.head_dim,
self.total_num_heads,
bias=True,
linear_method=linear_method,
)
self.c_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
linear_method=linear_method
linear_method=linear_method,
)
self.rotary_emb = get_rope(
self.head_dim,
......@@ -143,12 +143,16 @@ class QWenBlock(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
layer_id=layer_id,
linear_method=linear_method
linear_method=linear_method,
)
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2, linear_method=linear_method)
self.mlp = QWenMLP(
config.hidden_size,
config.intermediate_size // 2,
linear_method=linear_method,
)
def forward(
self,
......@@ -186,7 +190,10 @@ class QWenModel(nn.Module):
config.hidden_size,
)
self.h = nn.ModuleList(
[QWenBlock(config, i, linear_method=linear_method) for i in range(config.num_hidden_layers)]
[
QWenBlock(config, i, linear_method=linear_method)
for i in range(config.num_hidden_layers)
]
)
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
......
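A minimal single-GPU analog of the fused gate/up projection implied by "2 * [intermediate_size]" in the QWenMLP hunk above: one matmul produces both halves, which are split and combined as silu(gate) * up. This sketches the common SwiGLU pattern, not vLLM's tensor-parallel implementation; sizes are illustrative.

import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedSwiGLUMLP(nn.Module):
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        # Single projection whose output is [gate | up], matching 2 * [intermediate_size].
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)

mlp = FusedSwiGLUMLP(hidden_size=32, intermediate_size=64)
print(mlp(torch.randn(2, 32)).shape)  # torch.Size([2, 32])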
......@@ -4,14 +4,17 @@ from typing import List, Optional
import torch
import torch.nn as nn
from sglang.srt.models.llava import (
LlavaLlamaForCausalLM,
clip_vision_embed_forward,
monkey_path_clip_vision_embed_forward,
)
from transformers import CLIPVisionModel, LlavaConfig
from vllm.model_executor.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.models.llava import LlavaLlamaForCausalLM, clip_vision_embed_forward, monkey_path_clip_vision_embed_forward
class YiVLForCausalLM(LlavaLlamaForCausalLM):
def __init__(self, *args, **kwargs):
......@@ -19,7 +22,9 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
super().__init__(self.config)
self.multi_modal_projector = YiVLMultiModalProjector(self.config)
self.vision_tower_subfolder = self.config.mm_vision_tower.replace("./", "") # Everything after "./"
self.vision_tower_subfolder = self.config.mm_vision_tower.replace(
"./", ""
) # Everything after "./"
def load_weights(
self,
......@@ -30,7 +35,9 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
):
# We have to use the subfolder of the main model directory (e.g. 01-ai/Yi-VL-6B)
self.vision_tower = CLIPVisionModel.from_pretrained(
model_name_or_path, torch_dtype=torch.float16, subfolder=self.vision_tower_subfolder
model_name_or_path,
torch_dtype=torch.float16,
subfolder=self.vision_tower_subfolder,
).cuda()
self.vision_tower.eval()
......@@ -80,14 +87,19 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
monkey_path_clip_vision_embed_forward()
class YiVLMultiModalProjector(nn.Module):
def __init__(self, config: LlavaConfig):
super().__init__()
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size)
self.linear_1 = nn.Linear(
config.vision_config.hidden_size, config.text_config.hidden_size
)
self.ln_1 = nn.LayerNorm(config.text_config.hidden_size)
self.act = nn.GELU()
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size)
self.linear_2 = nn.Linear(
config.text_config.hidden_size, config.text_config.hidden_size
)
self.ln_2 = nn.LayerNorm(config.text_config.hidden_size)
def forward(self, image_features):
......@@ -98,4 +110,5 @@ class YiVLMultiModalProjector(nn.Module):
hidden_states = self.ln_2(hidden_states)
return hidden_states
EntryClass = YiVLForCausalLM
\ No newline at end of file
EntryClass = YiVLForCausalLM
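A minimal shape-check sketch of the projector structure shown above (Linear -> LayerNorm -> GELU -> Linear -> LayerNorm), using plain sizes instead of a LlavaConfig. The hidden sizes and patch count are illustrative.

import torch
import torch.nn as nn

vision_hidden, text_hidden = 1024, 4096
projector = nn.Sequential(
    nn.Linear(vision_hidden, text_hidden),
    nn.LayerNorm(text_hidden),
    nn.GELU(),
    nn.Linear(text_hidden, text_hidden),
    nn.LayerNorm(text_hidden),
)
image_features = torch.randn(1, 576, vision_hidden)   # (batch, patches, dim)
print(projector(image_features).shape)                # torch.Size([1, 576, 4096])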
......@@ -63,6 +63,7 @@ chat_template_name = None
# FIXME: Remove this once we drop support for pydantic 1.x
IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
def jsonify_pydantic_model(obj: BaseModel):
if IS_PYDANTIC_1:
return obj.json(ensure_ascii=False)
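A minimal sketch of the version guard above: pydantic 1.x serializes with .json(), while 2.x uses .model_dump_json(). The 2.x branch is an assumption about the truncated remainder of the function, not the exact source.

import pydantic
from pydantic import BaseModel

IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1

def jsonify_pydantic_model(obj: BaseModel):
    if IS_PYDANTIC_1:
        return obj.json(ensure_ascii=False)
    return obj.model_dump_json()

class Demo(BaseModel):
    text: str

print(jsonify_pydantic_model(Demo(text="héllo")))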
......@@ -165,7 +166,7 @@ async def v1_completions(raw_request: Request):
prompt_tokens = content["meta_info"]["prompt_tokens"]
completion_tokens = content["meta_info"]["completion_tokens"]
if not stream_buffer: # The first chunk
if not stream_buffer: # The first chunk
if request.echo:
# Prepend prompt in response text.
text = request.prompt + text
......@@ -219,7 +220,9 @@ async def v1_completions(raw_request: Request):
token_logprob_pos = prompt_tokens
logprobs = (
await make_openai_style_logprobs(ret["meta_info"]["token_logprob"][token_logprob_pos:])
await make_openai_style_logprobs(
ret["meta_info"]["token_logprob"][token_logprob_pos:]
)
if request.logprobs is not None
else None
)
......
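A minimal sketch of the offset slicing above: when echo is requested, the returned token_logprob list covers prompt tokens followed by completion tokens, so the completion-only logprobs start at index prompt_tokens. The data here is illustrative, not the actual meta_info format.

prompt_tokens = 3
token_logprob = [None, -0.2, -0.1, -0.5, -0.9]   # prompt logprobs first, then completion
completion_logprobs = token_logprob[prompt_tokens:]
print(completion_logprobs)  # [-0.5, -0.9]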
......@@ -114,7 +114,7 @@ class ServerArgs:
"--max-prefill-num-token",
type=int,
default=ServerArgs.max_prefill_num_token,
help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length."
help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
)
parser.add_argument(
"--tp-size",
......
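A minimal sketch of the pattern used in the ServerArgs hunk above: argparse defaults are read from the dataclass's class attributes, so the dataclass stays the single source of truth for defaults. The class name, field names, and default values below are illustrative.

import argparse
from dataclasses import dataclass

@dataclass
class DemoServerArgs:
    max_prefill_num_token: int = 8192
    tp_size: int = 1

parser = argparse.ArgumentParser()
parser.add_argument(
    "--max-prefill-num-token",
    type=int,
    default=DemoServerArgs.max_prefill_num_token,
    help="The maximum number of tokens in a prefill batch.",
)
parser.add_argument("--tp-size", type=int, default=DemoServerArgs.tp_size)
args = parser.parse_args([])
print(args.max_prefill_num_token, args.tp_size)  # 8192 1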
......@@ -259,4 +259,4 @@ def load_image(image_file):
else:
image = Image.open(BytesIO(base64.b64decode(image_file)))
return image
\ No newline at end of file
return image
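A minimal sketch of the base64 branch shown above: when the input is not an existing file path, treat it as base64-encoded image bytes and decode it into a PIL image. The file-path fallback is an assumption about the truncated parts of load_image (which also handles URLs).

import base64
import os
from io import BytesIO

from PIL import Image

def load_image_sketch(image_file: str) -> Image.Image:
    if os.path.exists(image_file):
        return Image.open(image_file)
    return Image.open(BytesIO(base64.b64decode(image_file)))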
......@@ -12,6 +12,7 @@ import argparse
import requests
def test_decode(url, return_logprob):
response = requests.post(
url + "/generate",
......@@ -27,6 +28,7 @@ def test_decode(url, return_logprob):
)
print(response.json())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
......
......@@ -12,6 +12,7 @@ import json
import requests
def test_decode_stream(url, return_logprob):
response = requests.post(
url + "/generate",
......@@ -39,7 +40,7 @@ def test_decode_stream(url, return_logprob):
assert data["meta_info"]["prompt_logprob"] is not None
assert data["meta_info"]["token_logprob"] is not None
assert data["meta_info"]["normalized_prompt_logprob"] is not None
if prev == 0: # Skip prompt logprobs
if prev == 0: # Skip prompt logprobs
prev = data["meta_info"]["prompt_tokens"]
for token_txt, _, logprob in data["meta_info"]["token_logprob"][prev:]:
print(f"{token_txt}\t{logprob}", flush=True)
......@@ -50,6 +51,7 @@ def test_decode_stream(url, return_logprob):
prev = len(output)
print("")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
......
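A minimal, framework-free sketch of the incremental-print pattern in the streaming test above: keep a prev offset so each streamed event only prints the text (or logprobs) added since the last event. The event list is illustrative; the real test reads server-sent chunks.

events = ["Hello", "Hello,", "Hello, wor", "Hello, world!"]

prev = 0
for output in events:
    print(output[prev:], end="", flush=True)
    prev = len(output)
print("")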
......@@ -64,9 +64,8 @@ def test_completion_stream(args, echo, logprobs):
first = False
if logprobs:
print(
f"{r.choices[0].text:12s}\t"
f"{r.choices[0].logprobs.token_logprobs}",
flush=True
f"{r.choices[0].text:12s}\t" f"{r.choices[0].logprobs.token_logprobs}",
flush=True,
)
else:
print(r.choices[0].text, end="", flush=True)
......