Unverified commit abd5385a, authored by Liangsheng Yin and committed by GitHub

Move `global_server_args_dict` (#642)

parent 3de2f30a
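In short: `global_server_args_dict` moves out of `managers/controller/infer_batch.py` (where `ModelRunner` used to fill it in) and into `sglang/srt/server.py`, where the new `_set_global_server_args` helper populates it inside `launch_server`; the attention layers now import it from `sglang.srt.server`. A minimal, self-contained sketch of the resulting pattern is below; the `ServerArgs` stand-in is simplified for illustration and is not the real `sglang.srt.server_args.ServerArgs`.

```python
# Minimal sketch of the pattern after this commit: one module-level dict,
# populated once at launch time, read elsewhere with .get(). The ServerArgs
# here is a simplified stand-in, not the real sglang class.
from dataclasses import dataclass


@dataclass
class ServerArgs:
    disable_flashinfer: bool = False
    attention_reduce_in_fp32: bool = False


# Lives in the server module so lower-level layers can import it directly.
global_server_args_dict = {}


def _set_global_server_args(server_args: ServerArgs) -> None:
    """Fill the dict right after the server args are parsed (see launch_server)."""
    global global_server_args_dict
    global_server_args_dict = {
        "disable_flashinfer": server_args.disable_flashinfer,
        "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
    }


if __name__ == "__main__":
    _set_global_server_args(ServerArgs(attention_reduce_in_fp32=True))
    # Consumers read with .get(), so a default applies if launch hasn't run yet.
    print(global_server_args_dict.get("attention_reduce_in_fp32", False))  # True
```

Presumably the motivation is that layer code now needs only a one-line import from `sglang.srt.server` instead of having `ModelRunner` thread the flags through `infer_batch`.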
@@ -7,8 +7,8 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.controller.infer_batch import global_server_args_dict
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
+from sglang.srt.server import global_server_args_dict

 class RadixAttention(nn.Module):
...
@@ -5,7 +5,7 @@ import torch
 import triton
 import triton.language as tl
-from sglang.srt.managers.controller.model_runner import global_server_args_dict
+from sglang.srt.server import global_server_args_dict
 from sglang.srt.utils import wrap_kernel_launcher

 if global_server_args_dict.get("attention_reduce_in_fp32", False):
...
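Note that the `attention_reduce_in_fp32` check above runs at module level, so it sees whatever the dict contains when the kernel module is first imported. A tiny illustrative sketch of that import-time behavior follows; the names are hypothetical and only the `.get()` pattern mirrors the hunk.

```python
# Illustrative only: the flag is consulted when the module-level code runs,
# so the dict must be populated before the kernel module is first imported.
global_server_args_dict = {"attention_reduce_in_fp32": True}  # filled at launch

# Mirrors the module-level check in the hunk above; a later change to the
# dict would not re-run this branch once the "module" has been loaded.
if global_server_args_dict.get("attention_reduce_in_fp32", False):
    REDUCE_DTYPE = "float32"  # stand-in for a torch/triton dtype choice
else:
    REDUCE_DTYPE = "float16"

print(REDUCE_DTYPE)  # "float32", because the flag was set before this ran
```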
@@ -16,9 +16,6 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

-# Store some global server args
-global_server_args_dict = {}

 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
...
@@ -20,12 +20,7 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

 from sglang.global_config import global_config
-from sglang.srt.managers.controller.infer_batch import (
-    Batch,
-    ForwardMode,
-    InputMetadata,
-    global_server_args_dict,
-)
+from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -91,12 +86,6 @@ class ModelRunner:
                 "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
             )

-        # Set some global args
-        global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
-        global_server_args_dict["attention_reduce_in_fp32"] = (
-            server_args.attention_reduce_in_fp32
-        )

         # Load the model and create memory pool
         self.load_model()
         self.init_memory_pool(total_gpu_memory)
...
@@ -64,6 +64,9 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 app = FastAPI()
 tokenizer_manager = None

+# Put some args for easily access
+global_server_args_dict = {}

 @app.get("/health")
 async def health() -> Response:
@@ -135,6 +138,14 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)

+def _set_global_server_args(server_args: ServerArgs):
+    global global_server_args_dict
+    global_server_args_dict = {
+        "disable_flashinfer": server_args.disable_flashinfer,
+        "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+    }

 def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
     global tokenizer_manager
@@ -163,6 +174,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
     # TODO: replace this with huggingface transformers template
     load_chat_template_for_openai_api(server_args.chat_template)

+    _set_global_server_args(server_args)

     # Allocate ports
     assert server_args.tp_size % server_args.nnodes == 0
     tp_size_local = server_args.tp_size // server_args.nnodes
...