Unverified commit bc6915e3, authored by Lianmin Zheng, committed by GitHub

Improve type annotation and styles (#2926)

parent a883f079
@@ -103,6 +103,7 @@ def tree_search(s, question, num_branches):
 def main(args):
     lines = read_jsonl(args.data_path)
+    lines = list(lines)

     # Construct prompts
     num_branches = 2
......
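Note on the hunk above: read_jsonl presumably yields rows lazily, so without list() the data could only be walked once. A minimal sketch of the failure mode this one-liner avoids (the helper body and file name here are illustrative, not from the diff):

import json

def read_jsonl(path):
    # Yields one parsed JSON object per line, lazily.
    with open(path) as f:
        for line in f:
            yield json.loads(line)

lines = read_jsonl("questions.jsonl")  # hypothetical data file
first = list(lines)
second = list(lines)  # [] -- the generator is already exhausted

lines = list(read_jsonl("questions.jsonl"))
print(len(lines))  # materialized: len(), slicing, and re-iteration all work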
@@ -226,8 +226,9 @@ class Req:
             else origin_input_ids  # Before image padding
         )
         self.origin_input_ids = origin_input_ids
-        self.output_ids = []  # Each decode stage's output ids
-        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+        # Each decode stage's output ids
+        self.output_ids = []
+        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
+        self.fill_ids = None
         self.session_id = session_id
         self.input_embeds = input_embeds
@@ -265,6 +266,7 @@ class Req:
         # Prefix info
         self.prefix_indices = []
         # Tokens to run prefill. input_tokens - shared_prefix_tokens.
+        # Updated if chunked.
         self.extend_input_len = 0
         self.last_node = None
@@ -280,10 +282,10 @@ class Req:
         self.top_logprobs_num = top_logprobs_num

         # Logprobs (return value)
-        self.input_token_logprobs_val = None
-        self.input_token_logprobs_idx = None
-        self.input_top_logprobs_val = None
-        self.input_top_logprobs_idx = None
+        self.input_token_logprobs_val: Optional[List[float]] = None
+        self.input_token_logprobs_idx: Optional[List[int]] = None
+        self.input_top_logprobs_val: Optional[List[float]] = None
+        self.input_top_logprobs_idx: Optional[List[int]] = None

         if return_logprob:
             self.output_token_logprobs_val = []
......
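The annotations added above make the logprob fields' contract explicit: parallel lists of float values and int token indices, or None before prefill populates them. A reduced sketch of how a checker such as mypy uses the Optional annotation (only the field names come from the diff):

from typing import List, Optional

class Req:  # reduced sketch of the annotated fields above
    def __init__(self) -> None:
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None

def total_input_logprob(req: Req) -> float:
    # mypy flags sum(req.input_token_logprobs_val) without this narrowing,
    # since the field is None until prefill populates it.
    if req.input_token_logprobs_val is None:
        return 0.0
    return sum(req.input_token_logprobs_val)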
@@ -22,8 +22,9 @@ import time
 import warnings
 from collections import deque
 from concurrent import futures
+from dataclasses import dataclass
 from types import SimpleNamespace
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

 import psutil
 import setproctitle
@@ -102,6 +103,19 @@ logger = logging.getLogger(__name__)

 test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")

+
+@dataclass
+class GenerationBatchResult:
+    logits_output: LogitsProcessorOutput
+    next_token_ids: List[int]
+    bid: int
+
+
+@dataclass
+class EmbeddingBatchResult:
+    embeddings: torch.Tensor
+    bid: int
+
 class Scheduler:
     """A scheduler that manages a tensor parallel GPU worker."""
@@ -411,16 +425,16 @@ class Scheduler:
         self.watchdog_last_time = time.time()

         while True:
+            current = time.time()
             if self.cur_batch is not None:
                 if self.watchdog_last_forward_ct == self.forward_ct:
-                    if time.time() > self.watchdog_last_time + self.watchdog_timeout:
+                    if current > self.watchdog_last_time + self.watchdog_timeout:
                         logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                         break
                 else:
                     self.watchdog_last_forward_ct = self.forward_ct
-                    self.watchdog_last_time = time.time()
-            time.sleep(self.watchdog_timeout / 2)
+                    self.watchdog_last_time = current
+            time.sleep(self.watchdog_timeout // 2)

         # Wait sometimes so that the parent process can print the error.
         time.sleep(5)
         self.parent_process.send_signal(signal.SIGQUIT)
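Two fixes in the watchdog hunk: the clock is read once per iteration, so the timeout comparison and the stored timestamp can no longer drift apart between two time.time() calls, and the sleep interval switches from / to //, flooring to a whole number. A self-contained sketch of the single-read pattern (the class and attribute names below are illustrative, not the Scheduler itself):

import time

class Watchdog:  # illustrative stand-in for the scheduler's watchdog state
    def __init__(self, timeout: float) -> None:
        self.timeout = timeout
        self.last_forward_ct = 0
        self.last_time = time.time()

    def is_hung(self, forward_ct: int) -> bool:
        current = time.time()          # single clock read per check
        if forward_ct == self.last_forward_ct:
            # No forward progress since the last check.
            return current > self.last_time + self.timeout
        self.last_forward_ct = forward_ct
        self.last_time = current       # same reading, no skew
        return False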
@@ -1018,7 +1032,9 @@ class Scheduler:
             batch.prepare_for_decode()
             return batch

-    def run_batch(self, batch: ScheduleBatch):
+    def run_batch(
+        self, batch: ScheduleBatch
+    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
         """Run a batch."""
         self.forward_ct += 1
@@ -1040,15 +1056,26 @@ class Scheduler:
             else:
                 assert False, "batch.extend_num_tokens == 0, this is unexpected!"

             batch.output_ids = next_token_ids
-            ret = logits_output, next_token_ids, model_worker_batch.bid
+            ret = GenerationBatchResult(
+                logits_output=logits_output,
+                next_token_ids=next_token_ids,
+                bid=model_worker_batch.bid,
+            )
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            ret = embeddings, model_worker_batch.bid
+            ret = EmbeddingBatchResult(
+                embeddings=embeddings, bid=model_worker_batch.bid
+            )
         return ret

-    def process_batch_result(self, batch: ScheduleBatch, result):
+    def process_batch_result(
+        self,
+        batch: ScheduleBatch,
+        result: Union[GenerationBatchResult, EmbeddingBatchResult],
+    ):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result)
             if batch.is_empty():
@@ -1057,17 +1084,29 @@ class Scheduler:
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
-                self.tp_worker.resolve_batch_result(result[-1])
+                self.tp_worker.resolve_batch_result(result.bid)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
             self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

-    def process_batch_result_prefill(self, batch: ScheduleBatch, result):
+    def process_batch_result_prefill(
+        self,
+        batch: ScheduleBatch,
+        result: Union[GenerationBatchResult, EmbeddingBatchResult],
+    ):
         skip_stream_req = None

         if self.is_generation:
-            logits_output, next_token_ids, bid = result
+            (
+                logits_output,
+                next_token_ids,
+                bid,
+            ) = (
+                result.logits_output,
+                result.next_token_ids,
+                result.bid,
+            )

             if self.enable_overlap:
                 logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
@@ -1125,7 +1164,7 @@ class Scheduler:
                 batch.next_batch_sampling_info.sampling_info_done.set()

         else:  # embedding or reward model
-            embeddings, bid = result
+            embeddings, bid = result.embeddings, result.bid
             embeddings = embeddings.tolist()

             # Check finish conditions
@@ -1149,8 +1188,16 @@ class Scheduler:
         self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

-    def process_batch_result_decode(self, batch: ScheduleBatch, result):
-        logits_output, next_token_ids, bid = result
+    def process_batch_result_decode(
+        self,
+        batch: ScheduleBatch,
+        result: GenerationBatchResult,
+    ):
+        logits_output, next_token_ids, bid = (
+            result.logits_output,
+            result.next_token_ids,
+            result.bid,
+        )
         self.num_generated_tokens += len(batch.reqs)

         if self.enable_overlap:
......
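The idle path above now reads result.bid instead of result[-1]. Since both result dataclasses expose a bid field, the attribute access works for either variant, which the positional index only did by accident of field order. A reduced sketch:

from dataclasses import dataclass
from typing import Union

@dataclass
class GenerationBatchResult:  # reduced to the shared field
    bid: int

@dataclass
class EmbeddingBatchResult:
    bid: int

def resolve_bid(result: Union[GenerationBatchResult, EmbeddingBatchResult]) -> int:
    # Both variants carry .bid, so the caller need not know which kind of
    # batch produced the result; the old tuples forced result[-1] here.
    return result.bid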
@@ -37,6 +37,7 @@ from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBack
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
+    get_attention_tp_size,
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -532,7 +533,7 @@ class ModelRunner:
             )
         else:
             cell_size = (
-                self.model_config.get_num_kv_heads(self.tp_size)
+                self.model_config.get_num_kv_heads(get_attention_tp_size())
                 * self.model_config.head_dim
                 * self.model_config.num_hidden_layers
                 * 2
@@ -626,7 +627,7 @@ class ModelRunner:
             self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                 self.max_total_num_tokens,
                 dtype=self.kv_cache_dtype,
-                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
@@ -637,7 +638,7 @@ class ModelRunner:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
                 dtype=self.kv_cache_dtype,
-                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
......
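With DP attention, the attention layers run in a tensor-parallel group that can be smaller than tp_size, so the per-GPU KV head count must be derived from get_attention_tp_size() rather than the global TP degree. A sketch of the cell-size arithmetic from the first ModelRunner hunk, with made-up numbers:

def kv_cache_cell_size(num_kv_heads: int, head_dim: int, layer_num: int,
                       dtype_bytes: int) -> int:
    # Bytes of KV cache per token: the trailing 2 covers the K and V tensors.
    return num_kv_heads * head_dim * layer_num * 2 * dtype_bytes

# Illustrative only: 8 total KV heads sharded over an attention-TP group of 4
# leaves 2 heads per GPU; fp16 keys/values take 2 bytes each.
total_kv_heads, attn_tp_size = 8, 4
print(kv_cache_cell_size(total_kv_heads // attn_tp_size, 128, 32, 2))
# -> 32768 bytes of KV cache per token on each GPU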
@@ -180,6 +180,7 @@ class CompletionRequest(BaseModel):
     ignore_eos: bool = False
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    session_params: Optional[Dict] = None


 class CompletionResponseChoice(BaseModel):
@@ -322,6 +323,7 @@ class ChatCompletionRequest(BaseModel):
     ignore_eos: bool = False
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    session_params: Optional[Dict] = None


 class FunctionResponse(BaseModel):
......
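Both request models gain an opaque session_params dict for session-scoped settings. The diff does not show which keys are accepted, so the payload below is purely illustrative, as is the local server URL:

import requests

# Hypothetical client call; the keys inside session_params are not
# validated by anything shown in this diff.
resp = requests.post(
    "http://localhost:30000/v1/chat/completions",  # assumed local server
    json={
        "model": "default",
        "messages": [{"role": "user", "content": "Hello"}],
        "session_params": {"session_id": "abc123"},  # illustrative keys
    },
)
print(resp.json())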
@@ -842,7 +842,6 @@ class Engine:
         generator = ret.body_iterator

         async def generator_wrapper():
-            offset = 0
             while True:
......
@@ -239,8 +239,8 @@ class ServerArgs:
         # Others
         if self.enable_dp_attention:
-            assert self.tp_size % self.dp_size == 0
             self.dp_size = self.tp_size
+            assert self.tp_size % self.dp_size == 0
             self.chunked_prefill_size = self.chunked_prefill_size // 2
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
             logger.warning(
......
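Reordered above: dp_size is forced to tp_size before the divisibility assert runs (which makes the assert always hold in this branch), and the prefill and scheduling knobs are dialed down. A sketch of the net effect, with assumed starting values:

# Assumed starting values, for illustration only.
tp_size = 8
dp_size = 2
chunked_prefill_size = 8192
schedule_conservativeness = 1.0

enable_dp_attention = True
if enable_dp_attention:
    dp_size = tp_size                   # overwritten first ...
    assert tp_size % dp_size == 0       # ... so this check now always passes
    chunked_prefill_size //= 2          # 8192 -> 4096
    schedule_conservativeness *= 0.3    # 1.0 -> 0.3

print(dp_size, chunked_prefill_size, schedule_conservativeness)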