Unverified Commit bc6915e3 authored by Lianmin Zheng, committed by GitHub

Improve type annotation and styles (#2926)

parent a883f079
......@@ -103,6 +103,7 @@ def tree_search(s, question, num_branches):
def main(args):
lines = read_jsonl(args.data_path)
lines = list(lines)
# Construct prompts
num_branches = 2
......
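The `list(lines)` call materializes the iterator before the prompts are built, which suggests `read_jsonl` yields rows lazily. A minimal sketch of such a reader, assuming that behavior (the real helper lives in sglang's utilities and may differ):

```python
import json
from typing import Any, Dict, Iterator

def read_jsonl(path: str) -> Iterator[Dict[str, Any]]:
    # Yield one parsed JSON object per line; lazy, so callers that
    # need random access or multiple passes must wrap it in list(...).
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                yield json.loads(line)
```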
......@@ -226,8 +226,9 @@ class Req:
else origin_input_ids # Before image padding
)
self.origin_input_ids = origin_input_ids
self.output_ids = [] # Each decode stage's output ids
self.fill_ids = None # fill_ids = origin_input_ids + output_ids
# Each decode stage's output ids
self.output_ids = []
# fill_ids = origin_input_ids + output_ids. Updated if chunked.
self.fill_ids = None
self.session_id = session_id
self.input_embeds = input_embeds
......@@ -265,6 +266,7 @@ class Req:
# Prefix info
self.prefix_indices = []
# Tokens to run prefill. input_tokens - shared_prefix_tokens.
# Updated if chunked.
self.extend_input_len = 0
self.last_node = None
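The new comments encode the bookkeeping invariants: `fill_ids` is always `origin_input_ids + output_ids`, and `extend_input_len` counts only the tokens prefill still has to compute after the shared prefix is reused from cache. A worked toy example of the arithmetic (numbers are illustrative, not from the source):

```python
origin_input_ids = [101, 102, 103, 104]  # prompt tokens
output_ids = [7, 8]                      # tokens decoded so far

# Invariant from the comment: fill_ids = origin_input_ids + output_ids
fill_ids = origin_input_ids + output_ids  # [101, 102, 103, 104, 7, 8]

# Suppose the radix cache already holds the first 3 tokens as a shared prefix.
prefix_indices = [0, 1, 2]

# Tokens prefill must actually run: input_tokens - shared_prefix_tokens.
extend_input_len = len(fill_ids) - len(prefix_indices)  # 6 - 3 = 3
```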
......@@ -280,10 +282,10 @@ class Req:
self.top_logprobs_num = top_logprobs_num
# Logprobs (return value)
self.input_token_logprobs_val = None
self.input_token_logprobs_idx = None
self.input_top_logprobs_val = None
self.input_top_logprobs_idx = None
self.input_token_logprobs_val: Optional[List[float]] = None
self.input_token_logprobs_idx: Optional[List[int]] = None
self.input_top_logprobs_val: Optional[List[float]] = None
self.input_top_logprobs_idx: Optional[List[int]] = None
if return_logprob:
self.output_token_logprobs_val = []
......
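Annotating the `None`-initialized logprob fields with `Optional[...]` lets a type checker verify every later assignment instead of inferring `None` as the field type. A minimal sketch of the pattern outside the real `Req` class:

```python
from typing import List, Optional

class LogprobHolder:
    def __init__(self) -> None:
        # Filled in after the prefill forward pass; None until then.
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None

holder = LogprobHolder()
holder.input_token_logprobs_val = [-0.12, -1.5]  # accepted by mypy/pyright
# holder.input_token_logprobs_val = "oops"       # flagged by a type checker
```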
......@@ -22,8 +22,9 @@ import time
import warnings
from collections import deque
from concurrent import futures
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union
import psutil
import setproctitle
......@@ -102,6 +103,19 @@ logger = logging.getLogger(__name__)
test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
@dataclass
class GenerationBatchResult:
logits_output: LogitsProcessorOutput
next_token_ids: List[int]
bid: int
@dataclass
class EmbeddingBatchResult:
embeddings: torch.Tensor
bid: int
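These two dataclasses replace the bare tuples that `run_batch` used to return, so downstream code reads named fields instead of positions. A quick illustration of the difference, with placeholder values standing in for the real tensors (dataclasses do not type-check at runtime, so a string stands in for `LogitsProcessorOutput` here):

```python
# Before: a bare tuple; the bid had to be fetched positionally.
ret = ("logits_output", [42, 17], 3)
bid = ret[-1]                       # fragile if fields are ever reordered

# After: named access via the GenerationBatchResult defined above.
ret = GenerationBatchResult(
    logits_output="logits_output", next_token_ids=[42, 17], bid=3
)
bid = ret.bid                       # self-documenting and reorder-safe
```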
class Scheduler:
"""A scheduler that manages a tensor parallel GPU worker."""
......@@ -411,16 +425,16 @@ class Scheduler:
self.watchdog_last_time = time.time()
while True:
current = time.time()
if self.cur_batch is not None:
if self.watchdog_last_forward_ct == self.forward_ct:
if time.time() > self.watchdog_last_time + self.watchdog_timeout:
if current > self.watchdog_last_time + self.watchdog_timeout:
logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
break
else:
self.watchdog_last_forward_ct = self.forward_ct
self.watchdog_last_time = time.time()
time.sleep(self.watchdog_timeout / 2)
self.watchdog_last_time = current
time.sleep(self.watchdog_timeout // 2)
# Wait sometimes so that the parent process can print the error.
time.sleep(5)
self.parent_process.send_signal(signal.SIGQUIT)
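The refactor caches a single `time.time()` reading in `current`, so the timeout comparison and the timestamp update see the same clock value within one iteration. A self-contained sketch of the same watchdog pattern, with illustrative names and a plain loop in place of the scheduler thread:

```python
import time

class Watchdog:
    def __init__(self, timeout: float):
        self.timeout = timeout
        self.forward_ct = 0          # incremented by the worker on progress
        self._last_ct = 0
        self._last_time = time.time()

    def run(self) -> None:
        while True:
            current = time.time()    # one clock reading per iteration
            if self._last_ct == self.forward_ct:
                # No progress since the last check: has the deadline passed?
                if current > self._last_time + self.timeout:
                    print(f"Watchdog timeout ({self.timeout=})")
                    break
            else:
                # Progress was made; reset the deadline.
                self._last_ct = self.forward_ct
                self._last_time = current
            time.sleep(self.timeout / 2)
```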
......@@ -1018,7 +1032,9 @@ class Scheduler:
batch.prepare_for_decode()
return batch
def run_batch(self, batch: ScheduleBatch):
def run_batch(
self, batch: ScheduleBatch
) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
"""Run a batch."""
self.forward_ct += 1
......@@ -1040,15 +1056,26 @@ class Scheduler:
else:
assert False, "batch.extend_num_tokens == 0, this is unexpected!"
batch.output_ids = next_token_ids
ret = logits_output, next_token_ids, model_worker_batch.bid
ret = GenerationBatchResult(
logits_output=logits_output,
next_token_ids=next_token_ids,
bid=model_worker_batch.bid,
)
else: # embedding or reward model
assert batch.extend_num_tokens != 0
model_worker_batch = batch.get_model_worker_batch()
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
ret = embeddings, model_worker_batch.bid
ret = EmbeddingBatchResult(
embeddings=embeddings, bid=model_worker_batch.bid
)
return ret
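With the `Union` return type in place, callers can narrow the result explicitly; `isinstance` checks tell a type checker which fields exist in each branch. A minimal sketch of such a dispatch using the dataclasses from the hunk above (the real `process_batch_result` branches on `batch.forward_mode` instead):

```python
from typing import Union

def process_result(
    result: Union[GenerationBatchResult, EmbeddingBatchResult],
) -> None:
    # isinstance() narrows the Union for static analysis.
    if isinstance(result, GenerationBatchResult):
        print("generated:", result.next_token_ids, "bid:", result.bid)
    else:  # EmbeddingBatchResult
        print("embedding batch, bid:", result.bid)
```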
def process_batch_result(self, batch: ScheduleBatch, result):
def process_batch_result(
self,
batch: ScheduleBatch,
result: Union[GenerationBatchResult, EmbeddingBatchResult],
):
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result)
if batch.is_empty():
......@@ -1057,17 +1084,29 @@ class Scheduler:
self.process_batch_result_prefill(batch, result)
elif batch.forward_mode.is_idle():
if self.enable_overlap:
self.tp_worker.resolve_batch_result(result[-1])
self.tp_worker.resolve_batch_result(result.bid)
elif batch.forward_mode.is_dummy_first():
batch.next_batch_sampling_info.update_regex_vocab_mask()
self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
def process_batch_result_prefill(
self,
batch: ScheduleBatch,
result: Union[GenerationBatchResult, EmbeddingBatchResult],
):
skip_stream_req = None
if self.is_generation:
logits_output, next_token_ids, bid = result
(
logits_output,
next_token_ids,
bid,
) = (
result.logits_output,
result.next_token_ids,
result.bid,
)
if self.enable_overlap:
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
......@@ -1125,7 +1164,7 @@ class Scheduler:
batch.next_batch_sampling_info.sampling_info_done.set()
else: # embedding or reward model
embeddings, bid = result
embeddings, bid = result.embeddings, result.bid
embeddings = embeddings.tolist()
# Check finish conditions
......@@ -1149,8 +1188,16 @@ class Scheduler:
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
def process_batch_result_decode(self, batch: ScheduleBatch, result):
logits_output, next_token_ids, bid = result
def process_batch_result_decode(
self,
batch: ScheduleBatch,
result: GenerationBatchResult,
):
logits_output, next_token_ids, bid = (
result.logits_output,
result.next_token_ids,
result.bid,
)
self.num_generated_tokens += len(batch.reqs)
if self.enable_overlap:
......
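In the overlap path above, the forward pass hands back placeholders and the real tensors are fetched later by batch id via `resolve_batch_result(bid)`, which lets CPU-side scheduling run ahead of the GPU. A hedged sketch of that deferred-resolution idea, not the actual `TpModelWorker` API:

```python
from typing import Dict, List, Tuple

class DeferredResults:
    """Toy stand-in: resolve a batch's outputs by its bid later."""

    def __init__(self) -> None:
        self._pending: Dict[int, Tuple[object, List[int]]] = {}

    def submit(self, bid: int, logits: object, token_ids: List[int]) -> None:
        # The real worker would enqueue an async GPU->CPU copy here.
        self._pending[bid] = (logits, token_ids)

    def resolve_batch_result(self, bid: int) -> Tuple[object, List[int]]:
        # The real code blocks until the copy is ready; here it is immediate.
        return self._pending.pop(bid)
```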
......@@ -37,6 +37,7 @@ from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBack
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.dp_attention import (
get_attention_tp_group,
get_attention_tp_size,
initialize_dp_attention,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
......@@ -532,7 +533,7 @@ class ModelRunner:
)
else:
cell_size = (
self.model_config.get_num_kv_heads(self.tp_size)
self.model_config.get_num_kv_heads(get_attention_tp_size())
* self.model_config.head_dim
* self.model_config.num_hidden_layers
* 2
......@@ -626,7 +627,7 @@ class ModelRunner:
self.token_to_kv_pool = DoubleSparseTokenToKVPool(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(self.tp_size),
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers,
device=self.device,
......@@ -637,7 +638,7 @@ class ModelRunner:
self.token_to_kv_pool = MHATokenToKVPool(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(self.tp_size),
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers,
device=self.device,
......
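Switching from `self.tp_size` to `get_attention_tp_size()` matters when data-parallel attention shards KV heads across a different group size than the dense TP degree. The `cell_size` being computed is the KV-cache footprint per token; a worked sketch of the arithmetic, with Llama-like numbers chosen purely for illustration:

```python
def kv_cell_size_bytes(num_kv_heads: int, head_dim: int, num_layers: int,
                       dtype_bytes: int = 2) -> int:
    # The trailing 2x accounts for both K and V tensors,
    # mirroring the cell_size expression in the diff.
    return num_kv_heads * head_dim * num_layers * 2 * dtype_bytes

# e.g. 8 KV heads per attention-TP rank, head_dim 128, 32 layers, fp16:
print(kv_cell_size_bytes(8, 128, 32))  # 131072 bytes (~128 KiB) per token
```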
......@@ -180,6 +180,7 @@ class CompletionRequest(BaseModel):
ignore_eos: bool = False
skip_special_tokens: bool = True
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
session_params: Optional[Dict] = None
class CompletionResponseChoice(BaseModel):
......@@ -322,6 +323,7 @@ class ChatCompletionRequest(BaseModel):
ignore_eos: bool = False
skip_special_tokens: bool = True
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
session_params: Optional[Dict] = None
class FunctionResponse(BaseModel):
......
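`session_params` joins both OpenAI-compatible request schemas as an optional free-form dict. A hedged example of a client payload carrying it; the endpoint, port, and the keys inside `session_params` are illustrative, since the accepted keys are defined by sglang's session handling rather than shown in this diff:

```python
import requests  # assumes a running sglang server on localhost:30000

payload = {
    "model": "default",
    "prompt": "Hello",
    "max_tokens": 16,
    # Optional[Dict]: omit the key entirely when no session is in use.
    "session_params": {"session_id": "abc123"},
}
resp = requests.post("http://localhost:30000/v1/completions", json=payload)
print(resp.json())
```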
......@@ -842,7 +842,6 @@ class Engine:
generator = ret.body_iterator
async def generator_wrapper():
offset = 0
while True:
......
......@@ -239,8 +239,8 @@ class ServerArgs:
# Others
if self.enable_dp_attention:
assert self.tp_size % self.dp_size == 0
self.dp_size = self.tp_size
assert self.tp_size % self.dp_size == 0
self.chunked_prefill_size = self.chunked_prefill_size // 2
self.schedule_conservativeness = self.schedule_conservativeness * 0.3
logger.warning(
......
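When `enable_dp_attention` is set, the block above derives the other knobs from `tp_size`: `dp_size` is forced equal to `tp_size` (so the relocated assert checks the result rather than user input), the chunked prefill budget is halved, and schedule conservativeness is scaled down. A worked example with illustrative starting values:

```python
tp_size = 8
chunked_prefill_size = 8192
schedule_conservativeness = 1.0

# Derivations mirroring the block above:
dp_size = tp_size                                  # 8
assert tp_size % dp_size == 0                      # trivially true after the assignment
chunked_prefill_size = chunked_prefill_size // 2   # 4096
schedule_conservativeness *= 0.3                   # 0.3
```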