Unverified commit e35a93fa, authored by Lianmin Zheng, committed by GitHub

Move output processing logic from scheduler.py into a separate file (#4354)

parent 2c3656f2
 import logging
-from typing import List, Optional
+from typing import List
 import torch
 import torch.distributed as dist
......
@@ -441,28 +441,6 @@ class Req:
         all_ids = self.origin_input_ids_unpadded + self.output_ids
         return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

-    def get_next_inc_detokenization(self):
-        if self.tokenizer is None:
-            return False, ""
-        read_ids, read_offset = self.init_incremental_detokenize()
-        surr_ids = read_ids[:read_offset]
-
-        surr_text = self.tokenizer.decode(
-            surr_ids,
-            skip_special_tokens=self.sampling_params.skip_special_tokens,
-            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
-        )
-        new_text = self.tokenizer.decode(
-            read_ids,
-            skip_special_tokens=self.sampling_params.skip_special_tokens,
-            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
-        )
-
-        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
-            return True, new_text[len(surr_text) :]
-
-        return False, ""
-
     def check_finished(self):
         if self.finished():
             return
......
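The removed `get_next_inc_detokenization` implements incremental detokenization; for readers who don't know the trick, here is a minimal standalone sketch of the same idea (function and variable names are illustrative, not the project's API): decode the held-back prefix and the full token window, and emit the difference only once it no longer ends in the replacement character that marks an incomplete UTF-8 sequence.

```python
# Illustrative sketch of incremental detokenization, not sglang's API.
# read_ids are the token ids decoded so far; the first read_offset of them
# (the "surrogate" prefix) have already been surfaced as text.
def next_text_delta(tokenizer, read_ids, read_offset):
    surr_text = tokenizer.decode(read_ids[:read_offset])
    new_text = tokenizer.decode(read_ids)
    # Emit only when the delta is non-empty and does not end in "�",
    # the character tokenizers produce for an incomplete multi-byte sequence.
    if len(new_text) > len(surr_text) and not new_text.endswith("�"):
        return True, new_text[len(surr_text):]
    return False, ""
```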
Two file diffs are collapsed (presumably the scheduler.py changes and the new output-processing file introduced by this commit).
@@ -82,7 +82,6 @@ from sglang.srt.utils import (
 logger = logging.getLogger(__name__)
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
 UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
@@ -119,6 +118,7 @@ class ModelRunner:
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
+        self.page_size = server_args.page_size
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
@@ -161,6 +161,11 @@ class ModelRunner:
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()

+        # If it is a draft model tp_group can be different.
+        self.initialize(min_per_gpu_memory)
+
+    def initialize(self, min_per_gpu_memory: float):
+        server_args = self.server_args
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=self.server_args.enable_memory_saver
         )
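Per the added comment, the two-phase split presumably exists so a speculative draft model can reach this code with a different tp_group after distributed setup; a rough sketch of the pattern, with hypothetical names:

```python
# Hypothetical sketch: __init__ stops after distributed setup and hands off
# to initialize(); a draft-model runner could then override or re-enter
# initialize() once its own tensor-parallel group is in place.
class RunnerSketch:
    def __init__(self, server_args):
        self.server_args = server_args
        min_per_gpu_memory = self.init_torch_distributed()
        self.initialize(min_per_gpu_memory)

    def initialize(self, min_per_gpu_memory: float) -> None:
        # ... the rest of the original __init__ body (memory saver, model
        # loading, KV-cache sizing, attention backend, ...)
        pass

    def init_torch_distributed(self) -> float:
        return 0.0  # placeholder: would return the minimum free memory across ranks
```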
@@ -300,15 +305,16 @@ class ModelRunner:
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
+        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()

         # Check memory for tensor parallelism
-        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if self.tp_size > 1:
             if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
-                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
+                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
+                    f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                 )

         logger.info(
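For context, `min_per_gpu_memory` is computed with `distributed=True`, so (assuming it reflects the minimum free memory across all TP ranks) the check compares the worst rank against 90% of this rank's free memory and now reports the actual numbers. A standalone rendering of that predicate:

```python
# Standalone rendering of the balance check (assumed semantics: the first
# argument is the minimum free GPU memory across all TP ranks, in GB).
def check_memory_balance(min_per_gpu_memory: float, local_gpu_memory: float) -> None:
    if min_per_gpu_memory < local_gpu_memory * 0.9:
        raise ValueError(
            "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
            f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
        )

check_memory_balance(70.0, 75.0)  # fine: 70 >= 75 * 0.9 = 67.5
# check_memory_balance(50.0, 75.0) would raise: another rank has far less free memory.
```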
@@ -698,6 +704,12 @@ class ModelRunner:
         )
         self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

+        self.max_total_num_tokens = (
+            self.max_total_num_tokens
+            // self.server_args.page_size
+            * self.server_args.page_size
+        )
+
         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
                 "Not enough memory. Please try to increase --mem-fraction-static."
@@ -783,7 +795,6 @@ class ModelRunner:
             # Init streams
             if self.server_args.speculative_algorithm == "EAGLE":
                 self.plan_stream_for_flashinfer = torch.cuda.Stream()
             self.attn_backend = FlashInferAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
             assert self.sliding_window_size is None, (
......
@@ -20,14 +20,13 @@ import random
 import tempfile
 from typing import List, Optional

-import torch
-
 from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_hpu_memory_capacity,
     get_nvgpu_memory_capacity,
+    is_cuda,
     is_flashinfer_available,
     is_hip,
     is_port_available,
@@ -71,6 +70,7 @@ class ServerArgs:
     schedule_policy: str = "fcfs"
     schedule_conservativeness: float = 1.0
     cpu_offload_gb: int = 0
+    page_size: int = 1

     # Other runtime options
     tp_size: int = 1
@@ -190,10 +190,10 @@
         if self.random_seed is None:
             self.random_seed = random.randint(0, 1 << 30)

-        if is_hip():
-            gpu_mem = get_amdgpu_memory_capacity()
-        elif torch.cuda.is_available():
+        if is_cuda():
             gpu_mem = get_nvgpu_memory_capacity()
+        elif is_hip():
+            gpu_mem = get_amdgpu_memory_capacity()
         elif self.device == "hpu":
             gpu_mem = get_hpu_memory_capacity()
         else:
@@ -258,7 +258,7 @@
                 f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )

-        # Others
+        # Data parallelism attention
         if self.enable_dp_attention:
             self.dp_size = self.tp_size
             assert self.tp_size % self.dp_size == 0
@@ -507,6 +507,12 @@
             default=ServerArgs.cpu_offload_gb,
             help="How many GBs of RAM to reserve for CPU offloading.",
         )
+        parser.add_argument(
+            "--page-size",
+            type=int,
+            default=ServerArgs.page_size,
+            help="The number of tokens in a page.",
+        )

         # Other runtime options
         parser.add_argument(
......
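A self-contained check of how the new flag parses (plain argparse mirroring the snippet above, not the project's full parser):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--page-size",
    type=int,
    default=1,  # mirrors the new ServerArgs.page_size default
    help="The number of tokens in a page.",
)

args = parser.parse_args(["--page-size", "16"])
assert args.page_size == 16  # argparse maps --page-size to args.page_size
```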