sglang · commit f86c1e61

Unverified commit f86c1e61, authored Sep 29, 2024 by Lianmin Zheng, committed via GitHub on Sep 29, 2024.

Move scheduler code from tp_worker.py to scheduler.py (#1538)
parent acaffd23

Showing 8 changed files with 933 additions and 870 deletions.
Changed files:

  python/sglang/bench_latency.py                      +12   −4
  python/sglang/srt/managers/io_struct.py              +3   −8
  python/sglang/srt/managers/schedule_batch.py         +5   −4
  python/sglang/srt/managers/scheduler.py             +879  −8
  python/sglang/srt/managers/scheduler_policy.py       +2   −2
  python/sglang/srt/managers/tp_worker.py              +21  −832
  python/sglang/srt/mem_cache/memory_pool.py            +2   −2
  python/sglang/srt/model_executor/model_runner.py      +9   −10
python/sglang/bench_latency.py

@@ -167,9 +167,13 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
         assert len(input_ids[i]) > bench_args.cut_len

         tmp_input_ids = input_ids[i][: bench_args.cut_len]
-        req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
+        req = Req(
+            rid=i,
+            origin_input_text=prompts[i],
+            origin_input_ids=tmp_input_ids,
+            sampling_params=sampling_params,
+        )
         req.prefix_indices = []
-        req.sampling_params = sampling_params
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         reqs.append(req)

@@ -199,9 +203,13 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
     reqs = []
     for i in range(len(input_ids)):
-        req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
+        req = Req(
+            rid=i,
+            origin_input_text="",
+            origin_input_ids=list(input_ids[i]),
+            sampling_params=sampling_params,
+        )
         req.prefix_indices = []
-        req.sampling_params = sampling_params
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         reqs.append(req)
python/sglang/srt/managers/io_struct.py

@@ -18,7 +18,6 @@ The definition of objects transfered between different
 processes (TokenizerManager, DetokenizerManager, Controller).
 """

-import copy
 import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union

@@ -53,12 +52,12 @@ class GenerateReqInput:
     stream: bool = False
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
-    is_single: bool = True

     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

+    # Whether it is a single request or a batch request
+    is_single: bool = True

     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None

@@ -307,10 +306,6 @@ class BatchTokenIDOut:
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]

-    def __post_init__(self):
-        # deepcopy meta_info to avoid modification in place
-        self.meta_info = copy.deepcopy(self.meta_info)
-

 @dataclass
 class BatchStrOut:
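For context, the post_init check visible above enforces that a GenerateReqInput carries exactly one of text or input_ids. A minimal standalone sketch of that validation pattern (not the sglang class itself: the field set is reduced and the error message is an assumption):

    from dataclasses import dataclass
    from typing import List, Optional


    @dataclass
    class GenerateReqInputSketch:
        # Reduced field set, for illustration only.
        text: Optional[str] = None
        input_ids: Optional[List[int]] = None
        stream: bool = False
        # Whether it is a single request or a batch request
        is_single: bool = True

        def post_init(self):
            # Exactly one of `text` and `input_ids` must be provided.
            if (self.text is None and self.input_ids is None) or (
                self.text is not None and self.input_ids is not None
            ):
                raise ValueError("Either text or input_ids should be provided.")


    req = GenerateReqInputSketch(text="Hello")
    req.post_init()  # passes; GenerateReqInputSketch() or both fields set would raise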
python/sglang/srt/managers/schedule_batch.py

@@ -31,6 +31,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

@@ -143,6 +144,7 @@ class Req:
         rid: str,
         origin_input_text: str,
         origin_input_ids: Tuple[int],
+        sampling_params: SamplingParams,
         lora_path: Optional[str] = None,
     ):
         # Input and output info

@@ -152,6 +154,8 @@ class Req:
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+
+        self.sampling_params = sampling_params
         self.lora_path = lora_path

         # Memory info

@@ -160,6 +164,7 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        self.stream = False

         # For incremental decoding
         # ----- | --------- read_ids -------|

@@ -187,10 +192,6 @@ class Req:
         self.extend_input_len = 0
         self.last_node = None

-        # Sampling parameters
-        self.sampling_params = None
-        self.stream = False
-
         # Logprobs (arguments)
         self.return_logprob = False
         self.logprob_start_len = 0
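After this change, Req takes its sampling parameters at construction time instead of having them attached afterwards. A hedged usage sketch of the new calling convention, mirroring what bench_latency.py does above (the SamplingParams keyword arguments are assumptions; a default-constructed object should also work):

    from sglang.srt.managers.schedule_batch import Req
    from sglang.srt.sampling.sampling_params import SamplingParams

    # Assumed keyword arguments; SamplingParams() with defaults is also fine.
    sampling_params = SamplingParams(temperature=0.0, max_new_tokens=16)

    prompts = ["hello"]
    input_ids = [[1, 2, 3, 4]]

    i = 0
    req = Req(
        rid=i,
        origin_input_text=prompts[i],
        origin_input_ids=input_ids[i],
        sampling_params=sampling_params,  # now a constructor argument
    )
    # The benchmark script still fills in the prefix/extend bookkeeping by hand:
    req.prefix_indices = []
    req.fill_ids = req.origin_input_ids
    req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)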
python/sglang/srt/managers/scheduler.py (+879, −8)

This diff is collapsed on the original page and is not reproduced here.
python/sglang/srt/managers/policy_scheduler.py → python/sglang/srt/managers/scheduler_policy.py (renamed)

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-"""Request policy scheduler"""
+"""Request scheduler policy"""

 import os
 import random

@@ -32,7 +32,7 @@ from sglang.srt.mem_cache.radix_cache import TreeNode
 CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))


-class PolicyScheduler:
+class SchedulerPolicy:
     def __init__(self, policy: str, tree_cache: BasePrefixCache):
         if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
             # LPM and DFS-weight is meaningless when the tree cache is disabled.
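Callers that imported the old module need to follow the rename. A sketch of the before/after import and construction; the stub cache below is an assumption that the constructor only reads tree_cache.disable at construction time, as the visible lines suggest, and a real caller passes the scheduler's own prefix cache instead:

    # Before this commit the module and class were named the other way around:
    #   from sglang.srt.managers.policy_scheduler import PolicyScheduler
    #   policy = PolicyScheduler("lpm", tree_cache)

    # After this commit:
    from sglang.srt.managers.scheduler_policy import SchedulerPolicy


    class _StubCache:
        # Minimal stand-in for a BasePrefixCache, for illustration only.
        disable = False


    # "lpm" and "dfs-weight" only make sense when the prefix cache is enabled,
    # per the check shown in the diff above.
    policy = SchedulerPolicy("lpm", _StubCache())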
python/sglang/srt/managers/tp_worker.py (+21, −832)

This diff is collapsed on the original page and is not reproduced here.
python/sglang/srt/mem_cache/memory_pool.py

@@ -27,11 +27,11 @@ logger = logging.getLogger(__name__)
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""

-    def __init__(self, size: int, max_context_len: int):
+    def __init__(self, size: int, max_context_len: int, device: str):
         self.size = size
         self.free_slots = list(range(size))
         self.req_to_token = torch.empty(
-            (size, max_context_len), dtype=torch.int32, device="cuda"
+            (size, max_context_len), dtype=torch.int32, device=device
         )

     def alloc(self, need_size: int) -> List[int]:
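The only behavioral change here is that the pool's backing tensor is allocated on a caller-chosen device instead of always on CUDA. A small usage sketch based on the signature above (keyword names match the diff; "cpu" is chosen only so the sketch runs without a GPU):

    from sglang.srt.mem_cache.memory_pool import ReqToTokenPool

    # One row per request slot, one column per token position.
    pool = ReqToTokenPool(size=4, max_context_len=8, device="cpu")

    print(pool.req_to_token.shape)   # torch.Size([4, 8])
    print(pool.req_to_token.dtype)   # torch.int32
    print(pool.req_to_token.device)  # cpu
    print(pool.free_slots)           # [0, 1, 2, 3]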
python/sglang/srt/model_executor/model_runner.py

@@ -87,6 +87,7 @@ class ModelRunner:
             self.model_config.hf_config.architectures
         )

+        # Model-specific adjustment
         if (
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla

@@ -94,6 +95,13 @@ class ModelRunner:
             logger.info("MLA optimization is tunred on. Use triton backend.")
             self.server_args.attention_backend = "triton"

+        if self.is_multimodal_model:
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
+            )
+            server_args.chunked_prefill_size = None
+            server_args.mem_fraction_static *= 0.95
+
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,

@@ -104,14 +112,6 @@ class ModelRunner:
             }
         )

-        # Model-specific adjustment
-        if self.is_multimodal_model:
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
-            )
-            server_args.chunked_prefill_size = None
-            server_args.mem_fraction_static *= 0.95
-
         # Init componnets
         min_per_gpu_memory = self.init_torch_distributed()
         self.sampler = Sampler()

@@ -400,8 +400,7 @@ class ModelRunner:
         )
         self.req_to_token_pool = ReqToTokenPool(
-            max_num_reqs + 1,
-            self.model_config.context_len + 4,
+            max_num_reqs + 1, self.model_config.context_len + 4, device="cuda"
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
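The commit does not say why the multimodal adjustment moved above global_server_args_dict.update(...), but a plausible reading is that the update call snapshots values out of server_args, so adjustments made after it would never be published. A generic illustration of that ordering pitfall (plain Python, not sglang code; the field name is invented for the example):

    class _Args:
        chunked_prefill_size = 8192  # invented value, for illustration


    # Old ordering: snapshot first, adjust afterwards.
    args = _Args()
    published = {"chunked_prefill_size": args.chunked_prefill_size}
    args.chunked_prefill_size = None
    print(published["chunked_prefill_size"])  # 8192 -- the adjustment is lost

    # New ordering: adjust first, then snapshot.
    args = _Args()
    args.chunked_prefill_size = None
    published = {"chunked_prefill_size": args.chunked_prefill_size}
    print(published["chunked_prefill_size"])  # None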