Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
abd5385a
Unverified
Commit
abd5385a
authored
Jul 17, 2024
by
Liangsheng Yin
Committed by
GitHub
Jul 17, 2024
Browse files
Move `global_server_args_dict` (#642)
parent
3de2f30a
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
16 additions
and
17 deletions
+16
-17
python/sglang/srt/layers/radix_attention.py
python/sglang/srt/layers/radix_attention.py
+1
-1
python/sglang/srt/layers/token_attention.py
python/sglang/srt/layers/token_attention.py
+1
-1
python/sglang/srt/managers/controller/infer_batch.py
python/sglang/srt/managers/controller/infer_batch.py
+0
-3
python/sglang/srt/managers/controller/model_runner.py
python/sglang/srt/managers/controller/model_runner.py
+1
-12
python/sglang/srt/server.py
python/sglang/srt/server.py
+13
-0
No files found.
python/sglang/srt/layers/radix_attention.py
View file @
abd5385a
...
@@ -7,8 +7,8 @@ from torch import nn
...
@@ -7,8 +7,8 @@ from torch import nn
from
sglang.global_config
import
global_config
from
sglang.global_config
import
global_config
from
sglang.srt.layers.extend_attention
import
extend_attention_fwd
from
sglang.srt.layers.extend_attention
import
extend_attention_fwd
from
sglang.srt.layers.token_attention
import
token_attention_fwd
from
sglang.srt.layers.token_attention
import
token_attention_fwd
from
sglang.srt.managers.controller.infer_batch
import
global_server_args_dict
from
sglang.srt.managers.controller.model_runner
import
ForwardMode
,
InputMetadata
from
sglang.srt.managers.controller.model_runner
import
ForwardMode
,
InputMetadata
from
sglang.srt.server
import
global_server_args_dict
class
RadixAttention
(
nn
.
Module
):
class
RadixAttention
(
nn
.
Module
):
...
...
python/sglang/srt/layers/token_attention.py
View file @
abd5385a
...
@@ -5,7 +5,7 @@ import torch
...
@@ -5,7 +5,7 @@ import torch
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
from
sglang.srt.
managers.controller.model_runn
er
import
global_server_args_dict
from
sglang.srt.
serv
er
import
global_server_args_dict
from
sglang.srt.utils
import
wrap_kernel_launcher
from
sglang.srt.utils
import
wrap_kernel_launcher
if
global_server_args_dict
.
get
(
"attention_reduce_in_fp32"
,
False
):
if
global_server_args_dict
.
get
(
"attention_reduce_in_fp32"
,
False
):
...
...
python/sglang/srt/managers/controller/infer_batch.py
View file @
abd5385a
...
@@ -16,9 +16,6 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
...
@@ -16,9 +16,6 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
INIT_INCREMENTAL_DETOKENIZATION_OFFSET
=
5
INIT_INCREMENTAL_DETOKENIZATION_OFFSET
=
5
# Store some global server args
global_server_args_dict
=
{}
class
ForwardMode
(
IntEnum
):
class
ForwardMode
(
IntEnum
):
# Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
# Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
...
...
python/sglang/srt/managers/controller/model_runner.py
View file @
abd5385a
...
@@ -20,12 +20,7 @@ from vllm.model_executor.model_loader import get_model
...
@@ -20,12 +20,7 @@ from vllm.model_executor.model_loader import get_model
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
from
sglang.global_config
import
global_config
from
sglang.global_config
import
global_config
from
sglang.srt.managers.controller.infer_batch
import
(
from
sglang.srt.managers.controller.infer_batch
import
Batch
,
ForwardMode
,
InputMetadata
Batch
,
ForwardMode
,
InputMetadata
,
global_server_args_dict
,
)
from
sglang.srt.memory_pool
import
ReqToTokenPool
,
TokenToKVPool
from
sglang.srt.memory_pool
import
ReqToTokenPool
,
TokenToKVPool
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
...
@@ -91,12 +86,6 @@ class ModelRunner:
...
@@ -91,12 +86,6 @@ class ModelRunner:
"The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
"The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
)
)
# Set some global args
global_server_args_dict
[
"disable_flashinfer"
]
=
server_args
.
disable_flashinfer
global_server_args_dict
[
"attention_reduce_in_fp32"
]
=
(
server_args
.
attention_reduce_in_fp32
)
# Load the model and create memory pool
# Load the model and create memory pool
self
.
load_model
()
self
.
load_model
()
self
.
init_memory_pool
(
total_gpu_memory
)
self
.
init_memory_pool
(
total_gpu_memory
)
...
...
python/sglang/srt/server.py
View file @
abd5385a
...
@@ -64,6 +64,9 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
...
@@ -64,6 +64,9 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
app
=
FastAPI
()
app
=
FastAPI
()
tokenizer_manager
=
None
tokenizer_manager
=
None
# Put some args here for easy access
global_server_args_dict
=
{}
@
app
.
get
(
"/health"
)
@
app
.
get
(
"/health"
)
async
def
health
()
->
Response
:
async
def
health
()
->
Response
:
...
@@ -135,6 +138,14 @@ async def openai_v1_chat_completions(raw_request: Request):
...
@@ -135,6 +138,14 @@ async def openai_v1_chat_completions(raw_request: Request):
return
await
v1_chat_completions
(
tokenizer_manager
,
raw_request
)
return
await
v1_chat_completions
(
tokenizer_manager
,
raw_request
)
def
_set_global_server_args
(
server_args
:
ServerArgs
):
global
global_server_args_dict
global_server_args_dict
=
{
"disable_flashinfer"
:
server_args
.
disable_flashinfer
,
"attention_reduce_in_fp32"
:
server_args
.
attention_reduce_in_fp32
,
}
def
launch_server
(
server_args
:
ServerArgs
,
pipe_finish_writer
,
model_overide_args
=
None
):
def
launch_server
(
server_args
:
ServerArgs
,
pipe_finish_writer
,
model_overide_args
=
None
):
global
tokenizer_manager
global
tokenizer_manager
...
@@ -163,6 +174,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
...
@@ -163,6 +174,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
# TODO: replace this with huggingface transformers template
# TODO: replace this with huggingface transformers template
load_chat_template_for_openai_api
(
server_args
.
chat_template
)
load_chat_template_for_openai_api
(
server_args
.
chat_template
)
_set_global_server_args
(
server_args
)
# Allocate ports
# Allocate ports
assert
server_args
.
tp_size
%
server_args
.
nnodes
==
0
assert
server_args
.
tp_size
%
server_args
.
nnodes
==
0
tp_size_local
=
server_args
.
tp_size
//
server_args
.
nnodes
tp_size_local
=
server_args
.
tp_size
//
server_args
.
nnodes
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment