change / sglang · Commit cdcbde5f

Code structure refactor (#807)

Unverified commit, authored Jul 29, 2024 by Liangsheng Yin; committed by GitHub on Jul 29, 2024. Parent: 21e22b9e.

The commit flattens `sglang.srt.managers.controller` into `sglang.srt.managers`, moves the cache code into a new `sglang.srt.mem_cache` package, moves the model runner and CUDA-graph runner into `sglang.srt.model_executor`, and renames the scheduling flag from `--schedule-heuristic` to `--schedule-policy` (with `ScheduleHeuristic` becoming `PolicyScheduler`).

Changes: 41 files in total; this page shows 20 changed files, with 77 additions and 76 deletions (+77 / -76).
Files changed on this page:

- docs/en/hyperparameter_tuning.md (+3 / -3)
- python/sglang/__init__.py (+31 / -30)
- python/sglang/bench_latency.py (+2 / -2)
- python/sglang/srt/layers/logits_processor.py (+1 / -1)
- python/sglang/srt/layers/radix_attention.py (+1 / -1)
- python/sglang/srt/layers/token_attention.py (+1 / -1)
- python/sglang/srt/managers/controller_multi.py (+1 / -1)
- python/sglang/srt/managers/controller_single.py (+1 / -1)
- python/sglang/srt/managers/detokenizer_manager.py (+1 / -1)
- python/sglang/srt/managers/io_struct.py (+1 / -1)
- python/sglang/srt/managers/policy_scheduler.py (+12 / -12)
- python/sglang/srt/managers/schedule_batch.py (+2 / -2)
- python/sglang/srt/managers/tp_worker.py (+13 / -13)
- python/sglang/srt/mem_cache/flush_cache.py (+1 / -1)
- python/sglang/srt/mem_cache/memory_pool.py (moved, +0 / -0)
- python/sglang/srt/mem_cache/radix_cache.py (moved, +0 / -0)
- python/sglang/srt/model_executor/cuda_graph_runner.py (+1 / -1)
- python/sglang/srt/model_executor/model_runner.py (+3 / -3)
- python/sglang/srt/models/chatglm.py (+1 / -1)
- python/sglang/srt/models/commandr.py (+1 / -1)
docs/en/hyperparameter_tuning.md (view file @ cdcbde5f)

@@ -29,7 +29,7 @@ If OOM happens during prefill, try to decrease `--max-prefill-tokens`.
 If OOM happens during decoding, try to decrease `--max-running-requests`.
 You can also try to decrease `--mem-fraction-static`, which reduces the memory usage of the KV cache memory pool and helps both prefill and decoding.
 
-### (Minor) Tune `--schedule-heuristic`
+### (Minor) Tune `--schedule-policy`
-If you have many shared prefixes, use the default `--schedule-heuristic lpm`. `lpm` stands for longest prefix match.
+If you have many shared prefixes, use the default `--schedule-policy lpm`. `lpm` stands for longest prefix match.
 When you have no shared prefixes at all or you always send the requests with the shared prefixes together,
-you can try `--schedule-heuristic fcfs`. `fcfs` stands for first come first serve.
+you can try `--schedule-policy fcfs`. `fcfs` stands for first come first serve.
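To make the renamed flag concrete, here is a minimal launch sketch from Python. It assumes `sglang.Runtime` forwards keyword arguments to `ServerArgs`, whose field is renamed to `schedule_policy` in the `tp_worker.py` hunk below; the model path is a placeholder.

```python
import sglang as sgl

# Hedged sketch: choose the scheduling policy at launch time. The valid
# values ("lpm", "fcfs", "lof", "random", "dfs-weight") come from the
# PolicyScheduler diff in this commit; the Runtime kwarg plumbing is an
# assumption.
runtime = sgl.Runtime(
    model_path="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
    schedule_policy="lpm",  # longest prefix match, the default
)
sgl.set_default_backend(runtime)
```

The equivalent server flag is `--schedule-policy lpm`.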
python/sglang/__init__.py (view file @ cdcbde5f)

 # SGL API Components
 from sglang.api import (
     Runtime,
     assistant,
@@ -22,46 +23,46 @@ from sglang.api import (
     video,
 )
 
-# Global Configurations
-from sglang.global_config import global_config
-
-# SGL Backends
-from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.utils import LazyImport
-from sglang.version import __version__
-
-Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
-LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
-OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
-VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
-
+# SGLang DSL APIs
+# public APIs management
 __all__ = [
-    "global_config",
-    "Anthropic",
-    "LiteLLM",
-    "OpenAI",
-    "RuntimeEndpoint",
-    "VertexAI",
-    "function",
     "Runtime",
-    "set_default_backend",
+    "assistant",
+    "assistant_begin",
+    "assistant_end",
     "flush_cache",
-    "get_server_args",
+    "function",
     "gen",
     "gen_int",
     "gen_string",
+    "get_server_args",
     "image",
-    "video",
     "select",
+    "set_default_backend",
     "system",
+    "system_begin",
+    "system_end",
     "user",
-    "assistant",
     "user_begin",
     "user_end",
-    "assistant_begin",
-    "assistant_end",
-    "system_begin",
-    "system_end",
+    "video",
 ]
+
+# Global Configurations
+from sglang.global_config import global_config
+
+__all__ += ["global_config"]
+
+from sglang.version import __version__
+
+__all__ += ["__version__"]
+
+# SGL Backends
+from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.utils import LazyImport
+
+Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
+LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
+OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
+VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
+
+__all__ += ["Anthropic", "LiteLLM", "OpenAI", "VertexAI", "RuntimeEndpoint"]
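The refactored `__init__.py` leans on `LazyImport` so that optional backends (Anthropic, LiteLLM, OpenAI, VertexAI) do not need to be installed just to `import sglang`. The real implementation lives in `sglang.utils`; the following is only a minimal sketch of how such a proxy can work, not the library's actual code.

```python
import importlib


class LazyImport:
    """Sketch of a lazy-import proxy: defer importing `module_name` until
    the named attribute is first touched, so a missing optional dependency
    only fails at use time, not at `import sglang` time."""

    def __init__(self, module_name: str, attr_name: str):
        self.module_name = module_name
        self.attr_name = attr_name
        self._target = None

    def _load(self):
        if self._target is None:
            module = importlib.import_module(self.module_name)
            self._target = getattr(module, self.attr_name)
        return self._target

    def __getattr__(self, name):
        # Attribute access on the proxy resolves the real object first.
        return getattr(self._load(), name)

    def __call__(self, *args, **kwargs):
        # Calling the proxy instantiates the real backend class.
        return self._load()(*args, **kwargs)


# Mirrors the usage in the diff above:
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
```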
python/sglang/bench_latency.py (view file @ cdcbde5f)

@@ -37,9 +37,9 @@ import torch
 import torch.distributed as dist
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req
+from sglang.srt.managers.schedule_batch import Batch, ForwardMode, Req
-from sglang.srt.managers.controller.model_runner import ModelRunner
 from sglang.srt.model_config import ModelConfig
+from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import suppress_other_loggers
python/sglang/srt/layers/logits_processor.py (view file @ cdcbde5f)

@@ -25,7 +25,7 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )
 
-from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
+from sglang.srt.model_executor.model_runner import ForwardMode, InputMetadata
 
 
 @dataclasses.dataclass
python/sglang/srt/layers/radix_attention.py (view file @ cdcbde5f)

@@ -22,7 +22,7 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.controller.model_runner import (
+from sglang.srt.model_executor.model_runner import (
     ForwardMode,
     InputMetadata,
     global_server_args_dict,
python/sglang/srt/layers/token_attention.py (view file @ cdcbde5f)

@@ -20,7 +20,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.managers.controller.infer_batch import global_server_args_dict
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 
 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
python/sglang/srt/managers/controller/manager_multi.py → python/sglang/srt/managers/controller_multi.py (view file @ cdcbde5f)

@@ -27,7 +27,7 @@ from enum import Enum, auto
 import numpy as np
 import zmq
 
-from sglang.srt.managers.controller.manager_single import (
+from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.io_struct import (
python/sglang/srt/managers/controller/manager_single.py → python/sglang/srt/managers/controller_single.py (view file @ cdcbde5f)

@@ -22,7 +22,7 @@ from typing import List
 import zmq
 
-from sglang.srt.managers.controller.tp_worker import (
+from sglang.srt.managers.tp_worker import (
     ModelTpServer,
     broadcast_recv_input,
     launch_tp_servers,
python/sglang/srt/managers/detokenizer_manager.py (view file @ cdcbde5f)

@@ -25,8 +25,8 @@ import zmq
 import zmq.asyncio
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
+from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry
python/sglang/srt/managers/io_struct.py (view file @ cdcbde5f)

@@ -22,7 +22,7 @@ import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
 
-from sglang.srt.managers.controller.infer_batch import BaseFinishReason
+from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling_params import SamplingParams
python/sglang/srt/managers/controller/schedule_heuristic.py → python/sglang/srt/managers/policy_scheduler.py (view file @ cdcbde5f)

@@ -13,47 +13,47 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""Request scheduler heuristic."""
+"""Request policy scheduler"""
 
 import random
 from collections import defaultdict
 
 
-class ScheduleHeuristic:
+class PolicyScheduler:
     def __init__(
         self,
-        schedule_heuristic,
+        policy,
         max_running_seqs,
         max_prefill_num_tokens,
         max_total_num_tokens,
         tree_cache,
     ):
-        if tree_cache.disable and schedule_heuristic == "lpm":
+        if tree_cache.disable and policy == "lpm":
             # LMP is meaningless when the tree cache is disabled.
-            schedule_heuristic = "fcfs"
-        self.schedule_heuristic = schedule_heuristic
+            policy = "fcfs"
+        self.policy = policy
         self.max_running_seqs = max_running_seqs
         self.max_prefill_num_tokens = max_prefill_num_tokens
         self.max_total_num_tokens = max_total_num_tokens
         self.tree_cache = tree_cache
 
     def get_priority_queue(self, waiting_queue):
-        if self.schedule_heuristic == "lpm":
+        if self.policy == "lpm":
             # longest prefix match
             waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
             return waiting_queue
-        elif self.schedule_heuristic == "fcfs":
+        elif self.policy == "fcfs":
             # first come first serve
             return waiting_queue
-        elif self.schedule_heuristic == "lof":
+        elif self.policy == "lof":
             # longest output first
             waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
             return waiting_queue
-        elif self.schedule_heuristic == "random":
+        elif self.policy == "random":
             random.shuffle(waiting_queue)
             return waiting_queue
-        elif self.schedule_heuristic == "dfs-weight":
+        elif self.policy == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in waiting_queue:
                 last_node_to_reqs[req.last_node].append(req)
@@ -70,7 +70,7 @@ class ScheduleHeuristic:
             assert len(q) == len(waiting_queue)
             return q
         else:
-            raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")
+            raise ValueError(f"Unknown schedule_policy: {self.policy}")
 
     def calc_weight(self, cur_node, node_to_weight):
         for child in cur_node.children.values():
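A small usage sketch of the renamed class, using `SimpleNamespace` stand-ins for the tree cache and the requests (hypothetical; real requests are `Req` objects from `schedule_batch.py`). It shows the "lpm" policy ordering waiting requests by descending cached-prefix length:

```python
from types import SimpleNamespace

from sglang.srt.managers.policy_scheduler import PolicyScheduler

# Stand-ins: an enabled tree cache, and requests whose prefix_indices
# represent tokens already cached in the radix tree.
tree_cache = SimpleNamespace(disable=False)
scheduler = PolicyScheduler(
    "lpm",   # policy
    16,      # max_running_seqs
    8192,    # max_prefill_num_tokens
    65536,   # max_total_num_tokens
    tree_cache,
)

reqs = [
    SimpleNamespace(rid="a", prefix_indices=[0] * 2),
    SimpleNamespace(rid="b", prefix_indices=[0] * 100),
    SimpleNamespace(rid="c", prefix_indices=[0] * 10),
]
ordered = scheduler.get_priority_queue(reqs)
print([r.rid for r in ordered])  # ['b', 'c', 'a']: longest cached prefix first
```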
python/sglang/srt/managers/controller/infer_batch.py → python/sglang/srt/managers/schedule_batch.py (view file @ cdcbde5f)

@@ -28,8 +28,8 @@ from flashinfer.sampling import top_k_top_p_sampling_from_probs
 from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
-from sglang.srt.managers.controller.radix_cache import RadixCache
-from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.radix_cache import RadixCache
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
python/sglang/srt/managers/controller/tp_worker.py → python/sglang/srt/managers/tp_worker.py (view file @ cdcbde5f)

@@ -29,23 +29,23 @@ from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.controller.infer_batch import (
-    FINISH_ABORT,
-    BaseFinishReason,
-    Batch,
-    ForwardMode,
-    Req,
-)
-from sglang.srt.managers.controller.model_runner import ModelRunner
-from sglang.srt.managers.controller.radix_cache import RadixCache
-from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchTokenIDOut,
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
+from sglang.srt.managers.policy_scheduler import PolicyScheduler
+from sglang.srt.managers.schedule_batch import (
+    FINISH_ABORT,
+    BaseFinishReason,
+    Batch,
+    ForwardMode,
+    Req,
+)
+from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.model_config import ModelConfig
+from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_int_token_logit_bias,
@@ -74,7 +74,7 @@ class ModelTpServer:
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.dp_size = server_args.dp_size
-        self.schedule_heuristic = server_args.schedule_heuristic
+        self.schedule_policy = server_args.schedule_policy
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
 
         # Chunked prefill
@@ -150,8 +150,8 @@ class ModelTpServer:
             disable=server_args.disable_radix_cache,
         )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
-        self.scheduler = ScheduleHeuristic(
-            self.schedule_heuristic,
+        self.scheduler = PolicyScheduler(
+            self.schedule_policy,
             self.max_running_requests,
             self.max_prefill_tokens,
             self.max_total_num_tokens,
python/sglang/srt/flush_cache.py → python/sglang/srt/mem_cache/flush_cache.py (view file @ cdcbde5f)

@@ -17,7 +17,7 @@ limitations under the License.
 Flush the KV cache.
 
 Usage:
-python3 -m sglang.srt.flush_cache --url http://localhost:30000
+python3 -m sglang.srt.mem_cache.flush_cache --url http://localhost:30000
 """
 
 import argparse
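The moved module is a thin client for a running server; a sketch of the equivalent call, assuming the server exposes a `/flush_cache` endpoint at the URL that the `--url` argument points at:

```python
import requests

# Hedged equivalent of
#   python3 -m sglang.srt.mem_cache.flush_cache --url http://localhost:30000
# Ask a running server to drop its KV cache; the endpoint name is an
# assumption inferred from the module's purpose.
response = requests.get("http://localhost:30000/flush_cache")
print(response.status_code)
```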
python/sglang/srt/memory_pool.py → python/sglang/srt/mem_cache/memory_pool.py (view file @ cdcbde5f)

File moved without changes.
python/sglang/srt/managers/controller/radix_cache.py → python/sglang/srt/mem_cache/radix_cache.py (view file @ cdcbde5f)

File moved without changes.
python/sglang/srt/managers/controller/cuda_graph_runner.py → python/sglang/srt/model_executor/cuda_graph_runner.py (view file @ cdcbde5f)

@@ -29,7 +29,7 @@ from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
 )
-from sglang.srt.managers.controller.infer_batch import (
+from sglang.srt.managers.schedule_batch import (
     Batch,
     ForwardMode,
     InputMetadata,
python/sglang/srt/managers/controller/model_runner.py → python/sglang/srt/model_executor/model_runner.py (view file @ cdcbde5f)

@@ -40,13 +40,13 @@ from vllm.distributed import (
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
-from sglang.srt.managers.controller.infer_batch import (
+from sglang.srt.managers.schedule_batch import (
     Batch,
     ForwardMode,
     InputMetadata,
     global_server_args_dict,
 )
-from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -273,7 +273,7 @@ class ModelRunner:
         )
 
     def init_cuda_graphs(self):
-        from sglang.srt.managers.controller.cuda_graph_runner import CudaGraphRunner
+        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 
         if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
             self.cuda_graph_runner = None
python/sglang/srt/models/chatglm.py (view file @ cdcbde5f)

@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs import ChatGLMConfig
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.controller.model_runner import InputMetadata
+from sglang.srt.model_executor.model_runner import InputMetadata
 
 LoraConfig = None
python/sglang/srt/models/commandr.py (view file @ cdcbde5f)

@@ -64,7 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.controller.model_runner import InputMetadata
+from sglang.srt.model_executor.model_runner import InputMetadata
 
 
 @torch.compile