change / sglang · Commits

Commit 59cbf476 (unverified)
Unify the memory pool api and tp worker API (#1724)

Authored Oct 19, 2024 by Lianmin Zheng; committed by GitHub on Oct 19, 2024
Parent: 95946271

Showing 8 changed files with 87 additions and 25 deletions (+87 / -25)
python/sglang/srt/managers/schedule_batch.py            +26 / -10
python/sglang/srt/managers/scheduler.py                 +24 / -8
python/sglang/srt/managers/tp_worker.py                 +6  / -2
python/sglang/srt/mem_cache/memory_pool.py              +7  / -1
python/sglang/srt/mem_cache/radix_cache.py              +4  / -3
python/sglang/srt/model_executor/forward_batch_info.py  +2  / -0
python/sglang/srt/model_executor/model_runner.py        +7  / -0
python/sglang/srt/sampling/sampling_batch_info.py       +11 / -1
python/sglang/srt/managers/schedule_batch.py

@@ -23,6 +23,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 - ScheduleBatch is managed by `scheduler.py::Scheduler`.
   It contains high-level scheduling data. Most of the data is on the CPU.
 - ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
+  It will be transformed from CPU scheduler to GPU model runner.
 - ForwardBatch is managed by `model_runner.py::ModelRunner`.
   It contains low-level tensor data. Most of the data consists of GPU tensors.
 """

@@ -522,12 +524,12 @@ class ScheduleBatch:
             assert seq_len - pre_len == req.extend_input_len

             if pre_len > 0:
-                self.req_to_token_pool.req_to_token[req.req_pool_idx, :pre_len] = (
-                    req.prefix_indices
-                )
-            self.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
-                out_cache_loc[pt : pt + req.extend_input_len]
-            )
+                self.req_to_token_pool.write(
+                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
+                )
+            self.req_to_token_pool.write(
+                (req.req_pool_idx, slice(pre_len, seq_len)),
+                out_cache_loc[pt : pt + req.extend_input_len],
+            )

             # Compute the relative logprob_start_len in an extend batch

@@ -765,9 +767,8 @@ class ScheduleBatch:
         # Alloc mem
         bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)

-        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
-            self.out_cache_loc
-        )
+        self.req_to_token_pool.write(
+            (self.req_pool_indices, self.seq_lens), self.out_cache_loc
+        )
         self.seq_lens.add_(1)

@@ -848,7 +849,6 @@ class ScheduleBatch:
         extend_logprob_start_lens = self.extend_logprob_start_lens
         image_inputs = [r.image_inputs for r in self.reqs]
-        lora_paths = [req.lora_path for req in self.reqs]

         if self.has_regex:
             self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
             self.sampling_info.regex_fsm_states = [

@@ -869,13 +869,14 @@ class ScheduleBatch:
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens,
             out_cache_loc=self.out_cache_loc,
+            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
             extend_logprob_start_lens=extend_logprob_start_lens,
             image_inputs=image_inputs,
-            lora_paths=lora_paths,
+            lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
             mrope_positions_delta=mrope_positions_delta,
         )

@@ -911,6 +912,9 @@ class ModelWorkerBatch:
     # The indices of output tokens in the token_to_kv_pool
     out_cache_loc: torch.Tensor

+    # The memory pool operation records
+    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
+
     # For logprob
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]

@@ -940,6 +944,7 @@ class ModelWorkerBatch:
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens.clone(),
             out_cache_loc=self.out_cache_loc,
+            req_to_token_pool_records=self.req_to_token_pool_records,
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
             extend_seq_lens=self.extend_seq_lens,

@@ -950,3 +955,14 @@ class ModelWorkerBatch:
             sampling_info=self.sampling_info.copy(),
             mrope_positions_delta=self.mrope_positions_delta,
         )
+
+    def to(self, device: str):
+        self.input_ids = self.input_ids.to(device, non_blocking=True)
+        self.req_pool_indices = self.req_pool_indices.to(device, non_blocking=True)
+        self.seq_lens = self.seq_lens.to(device, non_blocking=True)
+        self.out_cache_loc = self.out_cache_loc.to(device, non_blocking=True)
+        self.req_to_token_pool_records = [
+            (x, y.to(device, non_blocking=True))
+            for x, y in self.req_to_token_pool_records
+        ]
+        self.sampling_info.to(device)
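
Note: the new ModelWorkerBatch.to() above is what lets the scheduler keep ScheduleBatch on the CPU while the worker receives a batch whose tensor fields are moved to the GPU in one place. A minimal, self-contained sketch of the same pattern follows; the TinyBatch class and its fields are illustrative stand-ins, not part of this commit.

import torch

# Move every tensor field to the target device, mirroring ModelWorkerBatch.to().
class TinyBatch:
    def __init__(self):
        self.input_ids = torch.arange(8, dtype=torch.int64)
        self.seq_lens = torch.tensor([3, 5], dtype=torch.int64)

    def to(self, device: str):
        # non_blocking=True only overlaps the copy with host work when the
        # source tensors live in pinned memory; otherwise it behaves like a
        # normal synchronous copy.
        self.input_ids = self.input_ids.to(device, non_blocking=True)
        self.seq_lens = self.seq_lens.to(device, non_blocking=True)

batch = TinyBatch()
batch.to("cuda" if torch.cuda.is_available() else "cpu")
print(batch.input_ids.device)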
python/sglang/srt/managers/scheduler.py

@@ -51,6 +51,7 @@ from sglang.srt.managers.schedule_batch import (
     ImageInputs,
     Req,
     ScheduleBatch,
+    global_server_args_dict,
 )
 from sglang.srt.managers.schedule_policy import (
     AddReqResult,

@@ -144,25 +145,27 @@ class Scheduler:
         )

         # Launch a tensor parallel worker
-        self.tp_worker = TpModelWorker(
+        if self.server_args.enable_overlap_schedule:
+            TpWorkerClass = TpModelWorker
+        else:
+            TpWorkerClass = TpModelWorker
+        self.tp_worker = TpWorkerClass(
             server_args=server_args,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             dp_rank=dp_rank,
             nccl_port=port_args.nccl_port,
         )

+        # Init states for overlap schedule
         if self.server_args.enable_overlap_schedule:
+            self.forward_batch_generation = (
+                self.tp_worker.forward_batch_generation_non_blocking
+            )
             self.resolve_next_token_ids = (
                 lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
             )
-            self.forward_batch_generation = (
-                self.tp_worker.forward_batch_generation_non_blocking
-            )
         else:
+            self.forward_batch_generation = self.tp_worker.forward_batch_generation
             self.resolve_next_token_ids = lambda bid, x: x.tolist()
-            self.forward_batch_generation = self.tp_worker.forward_batch_generation

         # Get token and memory info from the model worker
         (

@@ -172,9 +175,14 @@ class Scheduler:
             self.max_req_input_len,
             self.random_seed,
             self.device,
-        ) = self.tp_worker.get_token_and_memory_info()
+            worker_global_server_args_dict,
+            _,
+            _,
+            _,
+        ) = self.tp_worker.get_worker_info()
         self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
         self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
+        global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)

         # Print debug info

@@ -266,6 +274,7 @@ class Scheduler:
     @torch.inference_mode()
     def event_loop_normal(self):
+        """A normal blocking scheduler loop."""
         self.last_batch = None

         while True:

@@ -296,6 +305,7 @@ class Scheduler:
     @torch.inference_mode()
     def event_loop_overlap(self):
+        """A scheduler loop that overlaps the CPU processing and GPU computation."""
         result_queue = deque()

         self.last_batch = None

@@ -572,6 +582,7 @@ class Scheduler:
             else set([])
         )

+        # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
                 self.lora_paths

@@ -673,6 +684,7 @@ class Scheduler:
         return new_batch

     def update_running_batch(self):
+        """Update the current running decoding batch."""
         global test_retract
         batch = self.running_batch

@@ -712,6 +724,7 @@ class Scheduler:
         batch.prepare_for_decode()

     def run_batch(self, batch: ScheduleBatch):
+        """Run a batch."""
         if self.is_generation:
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                 model_worker_batch = batch.get_model_worker_batch()

@@ -933,6 +946,7 @@ class Scheduler:
         return num_input_logprobs

     def stream_output(self, reqs: List[Req]):
+        """Stream the output to detokenizer."""
         output_rids = []
         output_meta_info = []
         output_finished_reason: List[BaseFinishReason] = []

@@ -1030,6 +1044,7 @@ class Scheduler:
         )

     def flush_cache(self):
+        """Flush the memory pool and cache."""
         if len(self.waiting_queue) == 0 and (
             self.running_batch is None or len(self.running_batch.reqs) == 0
         ):

@@ -1070,6 +1085,7 @@ class Scheduler:
             break

     def update_weights(self, recv_req: UpdateWeightReqInput):
+        """In-place update of the weights."""
         success, message = self.tp_worker.update_weights(recv_req)
         if success:
             flash_cache_success = self.flush_cache()
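
Note: the constructor change above binds forward_batch_generation and resolve_next_token_ids once, so event_loop_normal and event_loop_overlap can call the same two names regardless of mode. A small sketch of that dispatch pattern follows; the DemoScheduler class and its placeholder future-resolution helper are invented for the example.

import torch

class DemoScheduler:
    def __init__(self, enable_overlap_schedule: bool):
        if enable_overlap_schedule:
            # Overlap mode: the tensor only holds placeholder "future" ids,
            # so ask the worker for the real ones by batch id.
            self.resolve_next_token_ids = lambda bid, x: self._resolve_future(bid)
        else:
            # Normal mode: the tensor already holds the real token ids.
            self.resolve_next_token_ids = lambda bid, x: x.tolist()

    def _resolve_future(self, bid: int):
        return [0, 0, 0]  # stand-in for the worker's resolved token ids

scheduler = DemoScheduler(enable_overlap_schedule=False)
print(scheduler.resolve_next_token_ids(0, torch.tensor([1, 2, 3])))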
python/sglang/srt/managers/tp_worker.py

@@ -27,7 +27,7 @@ import torch
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import UpdateWeightReqInput
-from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs

@@ -111,7 +111,7 @@ class TpModelWorker:
         if server_args.enable_overlap_schedule:
             self.init_overlap_status()

-    def get_token_and_memory_info(self):
+    def get_worker_info(self):
         return (
             self.max_total_num_tokens,
             self.max_prefill_tokens,

@@ -119,6 +119,10 @@ class TpModelWorker:
             self.max_req_input_len,
             self.random_seed,
             self.device,
+            global_server_args_dict,
+            self.model_runner.req_to_token_pool.size,
+            self.model_runner.req_to_token_pool.max_context_len,
+            self.model_runner.token_to_kv_pool.size,
         )

     def get_pad_input_ids_func(self):
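
Note: get_token_and_memory_info() is renamed to get_worker_info() and now also returns the worker's global_server_args_dict plus the memory pool sizes; the scheduler hunk earlier unpacks the new fields and discards the ones it does not need with "_". A toy sketch of that consumption pattern follows; the stub values, and the exact number of fields, are illustrative only (the real tuple has additional leading entries hidden in the collapsed part of the diff).

def get_worker_info_stub():
    # Reduced stand-in for TpModelWorker.get_worker_info().
    return (
        2048,                             # max_req_input_len
        42,                               # random_seed
        "cuda",                           # device
        {"disable_radix_cache": False},   # worker's global_server_args_dict
        64,                               # req_to_token_pool.size
        2048,                             # req_to_token_pool.max_context_len
        65536,                            # token_to_kv_pool.size
    )

(
    max_req_input_len,
    random_seed,
    device,
    worker_global_server_args_dict,
    _,  # pool sizes are not needed by this consumer
    _,
    _,
) = get_worker_info_stub()
print(device, worker_global_server_args_dict)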
python/sglang/srt/mem_cache/memory_pool.py

@@ -56,6 +56,12 @@ class ReqToTokenPool:
     def clear(self):
         self.free_slots = list(range(self.size))

+    def write(self, indices, values):
+        self.req_to_token[indices] = values
+
+    def get_write_records(self):
+        return None
+

 class BaseTokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""

@@ -68,12 +74,12 @@ class BaseTokenToKVPool:
     ):
         self.size = size
         self.dtype = dtype
+        self.device = device
         if dtype == torch.float8_e5m2:
             # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
             self.store_dtype = torch.uint8
         else:
             self.store_dtype = dtype
-        self.device = device

         self.free_slots = None
         self.is_not_in_free_group = True
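
Note: the two methods added to ReqToTokenPool above are the heart of the unification: callers now go through write(indices, values) instead of assigning into req_to_token directly, and get_write_records() is a hook (a stub returning None in the base pool) for a pool variant that wants to record writes for later replay. A standalone sketch of how the index forms used at the call sites map onto the single method; ToyReqToTokenPool is illustrative.

import torch

class ToyReqToTokenPool:
    def __init__(self, size: int, max_context_len: int):
        self.req_to_token = torch.zeros((size, max_context_len), dtype=torch.int32)

    def write(self, indices, values):
        # `indices` is forwarded untouched to tensor indexing, so any form that
        # torch advanced indexing accepts works here.
        self.req_to_token[indices] = values

    def get_write_records(self):
        return None  # the base pool keeps no records

pool = ToyReqToTokenPool(size=4, max_context_len=8)
# Prefill-style call site: one row, a slice of positions.
pool.write((0, slice(0, 3)), torch.tensor([10, 11, 12], dtype=torch.int32))
# Decode-style call site: a tensor of rows paired with a tensor of positions.
pool.write(
    (torch.tensor([0, 1]), torch.tensor([3, 0])),
    torch.tensor([13, 14], dtype=torch.int32),
)
print(pool.req_to_token[:2, :4])

Because the index is passed straight through, the prefill sites can use (row, slice(a, b)) and the decode site can use a pair of index tensors without any special-casing in the pool.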
python/sglang/srt/mem_cache/radix_cache.py

@@ -145,9 +145,10 @@ class RadixCache(BasePrefixCache):
             # The prefix indices could be updated, reuse it
             new_indices, new_last_node = self.match_prefix(token_ids)
             assert len(new_indices) == len(token_ids)
-            self.req_to_token_pool.req_to_token[
-                req.req_pool_idx, len(req.prefix_indices) : len(new_indices)
-            ] = new_indices[len(req.prefix_indices) :]
+            self.req_to_token_pool.write(
+                (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
+                new_indices[len(req.prefix_indices) :],
+            )

         self.dec_lock_ref(req.last_node)
         self.inc_lock_ref(new_last_node)
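
Note: the only subtlety in the radix-cache change is that the literal a:b subscript becomes an explicit slice(a, b) object, since the slice now has to travel through a function call. A quick check (toy tensors only) that the two forms address the same elements:

import torch

new_indices = torch.arange(5)
prefix_len = 2

direct = torch.zeros(1, 8, dtype=torch.int64)
direct[0, prefix_len : len(new_indices)] = new_indices[prefix_len:]

via_slice = torch.zeros(1, 8, dtype=torch.int64)
via_slice[(0, slice(prefix_len, len(new_indices)))] = new_indices[prefix_len:]

assert torch.equal(direct, via_slice)
print(direct)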
python/sglang/srt/model_executor/forward_batch_info.py

@@ -25,6 +25,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 - ScheduleBatch is managed by `scheduler.py::Scheduler`.
   It contains high-level scheduling data. Most of the data is on the CPU.
 - ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
+  It will be transformed from CPU scheduler to GPU model runner.
 - ForwardBatch is managed by `model_runner.py::ModelRunner`.
   It contains low-level tensor data. Most of the data consists of GPU tensors.
 """
python/sglang/srt/model_executor/model_runner.py

@@ -131,6 +131,13 @@ class ModelRunner:
         ]:
             server_args.disable_cuda_graph = True

+        if self.server_args.enable_overlap_schedule:
+            logger.warning(
+                "Overlap scheduler is enabled. This is an experimental feature. "
+                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
+                "and embedding APIs are not supported and will lead to wrong results."
+            )
+
         # Global vars
         if server_args.show_time_cost:
             enable_show_time_cost()
python/sglang/srt/sampling/sampling_batch_info.py

@@ -78,7 +78,7 @@ class SamplingBatchInfo:
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
             is_all_greedy=top_ks.max().item() <= 1,
             vocab_size=vocab_size,
-            device=batch.input_ids.device,
+            device=device,
         )
         # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.

@@ -224,3 +224,13 @@ class SamplingBatchInfo:
             vocab_size=self.vocab_size,
             device=self.device,
         )
+
+    def to(self, device: str):
+        for item in [
+            "temperatures",
+            "top_ps",
+            "top_ks",
+            "min_ps",
+        ]:
+            value = getattr(self, item)
+            setattr(self, item, value.to(device, non_blocking=True))
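
Note: SamplingBatchInfo.to() mirrors ModelWorkerBatch.to(), but loops over attribute names with getattr/setattr so the list of movable tensors stays in one place. A self-contained sketch of the same pattern; TinySamplingInfo and its field values are illustrative stand-ins.

import torch

class TinySamplingInfo:
    def __init__(self, n: int):
        self.temperatures = torch.full((n, 1), 0.7)
        self.top_ps = torch.full((n,), 0.9)
        self.top_ks = torch.full((n,), 40, dtype=torch.int32)
        self.min_ps = torch.zeros(n)

    def to(self, device: str):
        # Move each named tensor attribute in one loop.
        for item in ["temperatures", "top_ps", "top_ks", "min_ps"]:
            value = getattr(self, item)
            setattr(self, item, value.to(device, non_blocking=True))

info = TinySamplingInfo(2)
info.to("cuda" if torch.cuda.is_available() else "cpu")
print(info.top_ks.device)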