sglang commit f86c1e61 (unverified)

Move scheduler code from tp_worker.py to scheduler.py (#1538)

Authored Sep 29, 2024 by Lianmin Zheng; committed via GitHub on Sep 29, 2024
Parent: acaffd23
Showing 8 changed files with 933 additions and 870 deletions.
python/sglang/bench_latency.py                     +12    -4
python/sglang/srt/managers/io_struct.py             +3    -8
python/sglang/srt/managers/schedule_batch.py        +5    -4
python/sglang/srt/managers/scheduler.py           +879    -8
python/sglang/srt/managers/scheduler_policy.py      +2    -2
python/sglang/srt/managers/tp_worker.py            +21  -832
python/sglang/srt/mem_cache/memory_pool.py          +2    -2
python/sglang/srt/model_executor/model_runner.py    +9   -10
python/sglang/bench_latency.py

@@ -167,9 +167,13 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
         assert len(input_ids[i]) > bench_args.cut_len
 
         tmp_input_ids = input_ids[i][: bench_args.cut_len]
-        req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
+        req = Req(
+            rid=i,
+            origin_input_text=prompts[i],
+            origin_input_ids=tmp_input_ids,
+            sampling_params=sampling_params,
+        )
         req.prefix_indices = []
-        req.sampling_params = sampling_params
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         reqs.append(req)
@@ -199,9 +203,13 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
 
     reqs = []
     for i in range(len(input_ids)):
-        req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
+        req = Req(
+            rid=i,
+            origin_input_text="",
+            origin_input_ids=list(input_ids[i]),
+            sampling_params=sampling_params,
+        )
         req.prefix_indices = []
-        req.sampling_params = sampling_params
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         reqs.append(req)
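
Both hunks above pass an existing sampling_params object into the Req constructor instead of attaching it afterwards. The creation of sampling_params is outside the hunks shown, so the following is only a minimal sketch of how it might be built beforehand; the keyword arguments are assumptions, not taken from this diff.

    from sglang.srt.sampling.sampling_params import SamplingParams

    # Assumed keyword arguments; illustrative values only.
    sampling_params = SamplingParams(temperature=0, max_new_tokens=16)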
python/sglang/srt/managers/io_struct.py

@@ -18,7 +18,6 @@ The definition of objects transfered between different
 processes (TokenizerManager, DetokenizerManager, Controller).
 """
 
-import copy
 import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
@@ -53,12 +52,12 @@ class GenerateReqInput:
     stream: bool = False
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
-
-    is_single: bool = True
-
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
+    # Whether it is a single request or a batch request
+    is_single: bool = True
+
     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
@@ -307,10 +306,6 @@ class BatchTokenIDOut:
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]
 
-    def __post_init__(self):
-        # deepcopy meta_info to avoid modification in place
-        self.meta_info = copy.deepcopy(self.meta_info)
-
 
 @dataclass
 class BatchStrOut:
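
The removed __post_init__ meant that every BatchTokenIDOut used to deep-copy its meta_info list on construction; after this change the dataclass holds the list it is given, so producers must not mutate dicts they have already handed off. A generic sketch of the behavioral difference (illustrative toy classes, not the sglang ones):

    import copy
    from dataclasses import dataclass
    from typing import Dict, List

    @dataclass
    class WithCopy:
        meta_info: List[Dict]

        def __post_init__(self):
            # Old behavior: isolate the batch from later in-place edits.
            self.meta_info = copy.deepcopy(self.meta_info)

    @dataclass
    class WithoutCopy:
        meta_info: List[Dict]  # New behavior: shares the caller's dicts.

    shared = [{"finish_reason": None}]
    a = WithCopy(shared)
    b = WithoutCopy(shared)
    shared[0]["finish_reason"] = "stop"
    print(a.meta_info[0]["finish_reason"])  # None  (deep copy unaffected)
    print(b.meta_info[0]["finish_reason"])  # stop  (sees the caller's mutation)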
python/sglang/srt/managers/schedule_batch.py

@@ -31,6 +31,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -143,6 +144,7 @@ class Req:
         rid: str,
         origin_input_text: str,
         origin_input_ids: Tuple[int],
+        sampling_params: SamplingParams,
         lora_path: Optional[str] = None,
     ):
         # Input and output info
@@ -152,6 +154,8 @@ class Req:
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+
+        self.sampling_params = sampling_params
         self.lora_path = lora_path
 
         # Memory info
@@ -160,6 +164,7 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        self.stream = False
 
         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -187,10 +192,6 @@ class Req:
         self.extend_input_len = 0
         self.last_node = None
 
-        # Sampling parameters
-        self.sampling_params = None
-        self.stream = False
-
         # Logprobs (arguments)
         self.return_logprob = False
         self.logprob_start_len = 0
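
With this change, sampling_params becomes a required constructor argument of Req rather than an attribute patched in later. A minimal sketch of constructing a request under the new signature; the example values and the no-argument SamplingParams() call are placeholders, not taken from this diff.

    from sglang.srt.managers.schedule_batch import Req
    from sglang.srt.sampling.sampling_params import SamplingParams

    req = Req(
        rid="request-0",                    # unique request id
        origin_input_text="Hello, world",   # original prompt text
        origin_input_ids=[15496, 11, 995],  # tokenized prompt (example ids)
        sampling_params=SamplingParams(),   # now mandatory at construction time
        lora_path=None,                     # optional, unchanged
    )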
python/sglang/srt/managers/scheduler.py

This diff is collapsed.
python/sglang/srt/managers/policy_scheduler.py → python/sglang/srt/managers/scheduler_policy.py

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""Request policy scheduler"""
+"""Request scheduler policy"""
 
 import os
 import random
@@ -32,7 +32,7 @@ from sglang.srt.mem_cache.radix_cache import TreeNode
 CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))
 
 
-class PolicyScheduler:
+class SchedulerPolicy:
     def __init__(self, policy: str, tree_cache: BasePrefixCache):
         if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
             # LPM and DFS-weight is meaningless when the tree cache is disabled.
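
Because both the module and the class were renamed, call sites that previously imported PolicyScheduler from policy_scheduler need the new path and name. A hedged sketch of the updated import; only the constructor signature and the "lpm" policy name come from this diff, and tree_cache stands in for a BasePrefixCache instance created elsewhere.

    # Old module path and class name (before this commit):
    #     from sglang.srt.managers.policy_scheduler import PolicyScheduler

    # New module path and class name:
    from sglang.srt.managers.scheduler_policy import SchedulerPolicy

    # Constructor signature is unchanged apart from the name:
    #     SchedulerPolicy(policy: str, tree_cache: BasePrefixCache)
    # e.g. SchedulerPolicy("lpm", tree_cache)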
python/sglang/srt/managers/tp_worker.py

This diff is collapsed.
python/sglang/srt/mem_cache/memory_pool.py

@@ -27,11 +27,11 @@ logger = logging.getLogger(__name__)
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
-    def __init__(self, size: int, max_context_len: int):
+    def __init__(self, size: int, max_context_len: int, device: str):
         self.size = size
         self.free_slots = list(range(size))
         self.req_to_token = torch.empty(
-            (size, max_context_len), dtype=torch.int32, device="cuda"
+            (size, max_context_len), dtype=torch.int32, device=device
        )
 
     def alloc(self, need_size: int) -> List[int]:
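
The pool's backing tensor is now allocated on a caller-supplied device instead of a hard-coded "cuda". A minimal sketch of constructing the pool under the new signature; the sizes are arbitrary placeholders, and the CPU variant is only an assumption about what the extra parameter enables.

    from sglang.srt.mem_cache.memory_pool import ReqToTokenPool

    # Placeholder sizes; model_runner.py passes max_num_reqs + 1 and context_len + 4.
    pool = ReqToTokenPool(size=256, max_context_len=4096 + 4, device="cuda")

    # The same pool could now target another device, e.g. for CPU-only testing (assumption):
    cpu_pool = ReqToTokenPool(size=8, max_context_len=128, device="cpu")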
python/sglang/srt/model_executor/model_runner.py

@@ -87,6 +87,7 @@ class ModelRunner:
             self.model_config.hf_config.architectures
         )
 
+        # Model-specific adjustment
         if (
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
@@ -94,6 +95,13 @@ class ModelRunner:
             logger.info("MLA optimization is tunred on. Use triton backend.")
             self.server_args.attention_backend = "triton"
 
+        if self.is_multimodal_model:
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
+            )
+            server_args.chunked_prefill_size = None
+            server_args.mem_fraction_static *= 0.95
+
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
@@ -104,14 +112,6 @@ class ModelRunner:
             }
         )
 
-        # Model-specific adjustment
-        if self.is_multimodal_model:
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
-            )
-            server_args.chunked_prefill_size = None
-            server_args.mem_fraction_static *= 0.95
-
         # Init componnets
         min_per_gpu_memory = self.init_torch_distributed()
         self.sampler = Sampler()
@@ -400,8 +400,7 @@ class ModelRunner:
         )
         self.req_to_token_pool = ReqToTokenPool(
-            max_num_reqs + 1,
-            self.model_config.context_len + 4,
+            max_num_reqs + 1, self.model_config.context_len + 4, device="cuda"
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
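
The net effect of these hunks is that both model-specific adjustments (the MLA backend switch and the multimodal chunked-prefill/memory tweak) now run before global_server_args_dict.update(...), and the request-to-token pool is created with an explicit device. A condensed, hedged sketch of the adjusted flow; the helper below is illustrative only and not part of the sglang API.

    def adjust_server_args_for_model(server_args, model_config, is_multimodal_model):
        """Illustrative only: mirrors the ordering of the adjustments in this diff."""
        # MLA models are steered to the triton attention backend.
        if model_config.attention_arch == "MLA" and not server_args.disable_mla:
            server_args.attention_backend = "triton"

        # Multimodal models disable chunked prefill and shrink the static
        # memory fraction to leave extra headroom.
        if is_multimodal_model:
            server_args.chunked_prefill_size = None
            server_args.mem_fraction_static *= 0.95

        # In the new ordering, the shared snapshot (global_server_args_dict.update)
        # is taken only after both adjustments.
        return {"attention_backend": server_args.attention_backend}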