sglang · Commits · 501dfa6b

Commit 501dfa6b (unverified)
Remove sampling info events and overlap thread file (#11300)
Authored Oct 07, 2025 by Liangsheng Yin; committed via GitHub on Oct 07, 2025
Parent: 79d34951

Changes: 9 changed files with 13 additions and 393 deletions (+13, -393)
python/sglang/srt/disaggregation/decode.py                        +0   -13
python/sglang/srt/disaggregation/prefill.py                       +0   -15
python/sglang/srt/managers/schedule_batch.py                      +0   -1
python/sglang/srt/managers/scheduler.py                           +1   -24
python/sglang/srt/managers/scheduler_output_processor_mixin.py    +0   -3
python/sglang/srt/managers/tp_worker_overlap_thread.py            +0   -307
python/sglang/srt/model_executor/forward_batch_info.py            +0   -7
python/sglang/srt/model_executor/model_runner.py                  +5   -9
python/sglang/srt/sampling/sampling_batch_info.py                 +7   -14
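The common thread in the deletions below is the `sampling_info_done` handshake: under the overlap scheduler, the scheduler thread previously signalled a `threading.Event` on the next batch's `SamplingBatchInfo` once CPU-side sampling metadata (notably grammar vocab masks) was ready, and the model runner waited on that event before sampling. This commit drops the event entirely. A minimal, generic sketch of that kind of cross-thread handshake, using stub names rather than sglang's real classes:

```python
import threading
from dataclasses import dataclass, field


@dataclass
class SamplingInfoStub:
    # Stand-in for SamplingBatchInfo: the scheduler thread fills in the mask,
    # then signals the event so the forward thread may consume it.
    vocab_mask_ready: bool = False
    sampling_info_done: threading.Event = field(default_factory=threading.Event)


def scheduler_side(info: SamplingInfoStub) -> None:
    # Producer: finish CPU-heavy preparation, then signal readiness
    # (the role of the removed set_next_batch_sampling_info_done).
    info.vocab_mask_ready = True
    info.sampling_info_done.set()


def forward_side(info: SamplingInfoStub) -> None:
    # Consumer: block until the scheduler has published the sampling info
    # (the role of the removed wait in _preprocess_logits).
    info.sampling_info_done.wait()
    assert info.vocab_mask_ready


info = SamplingInfoStub()
consumer = threading.Thread(target=forward_side, args=(info,))
consumer.start()
scheduler_side(info)
consumer.join()
```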
python/sglang/srt/disaggregation/decode.py
@@ -783,16 +783,6 @@ class SchedulerDisaggregationDecodeMixin:
                 self.prepare_mlp_sync_batch(batch)
             result = self.run_batch(batch)
             self.result_queue.append((batch.copy(), result))

-            if (self.last_batch is None) or (not self.last_batch_in_queue):
-                # Create a dummy first batch to start the pipeline for overlap schedule.
-                # It is now used for triggering the sampling_info_done event.
-                tmp_batch = ScheduleBatch(
-                    reqs=None,
-                    forward_mode=ForwardMode.DUMMY_FIRST,
-                    next_batch_sampling_info=self.tp_worker.cur_sampling_info,
-                )
-                self.set_next_batch_sampling_info_done(tmp_batch)
-
             last_batch_in_queue = True
         elif prepare_mlp_sync_flag:
@@ -806,9 +796,6 @@ class SchedulerDisaggregationDecodeMixin:
         # Process the results of the previous batch but skip if the last batch is extend
         if self.last_batch and self.last_batch_in_queue:
             tmp_batch, tmp_result = self.result_queue.popleft()
-            tmp_batch.next_batch_sampling_info = (
-                self.tp_worker.cur_sampling_info if batch else None
-            )
             self.process_batch_result(tmp_batch, tmp_result)

         queue_size = (
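Both disaggregation mixins and the main scheduler loop below share the same overlap pattern: launch the current batch, stash `(batch.copy(), result)` on a queue, and process the previous batch's result while the current one runs on the GPU. A minimal sketch of that one-batch-deep pipeline, with hypothetical `run`/`process` callables that stand in for the real scheduler methods:

```python
from collections import deque
from typing import Any, Callable, Deque, Optional, Tuple


def overlap_loop(
    get_next_batch: Callable[[], Optional[Any]],
    run: Callable[[Any], Any],            # launches work, returns a handle/result
    process: Callable[[Any, Any], None],  # consumes a finished (batch, result) pair
) -> None:
    result_queue: Deque[Tuple[Any, Any]] = deque()
    last_batch = None
    while True:
        batch = get_next_batch()
        if batch is not None:
            # Launch the current batch and remember it for the next iteration.
            result_queue.append((batch, run(batch)))
        if last_batch is not None:
            # Process the previous batch's result while the current one runs.
            prev_batch, prev_result = result_queue.popleft()
            process(prev_batch, prev_result)
        if batch is None and last_batch is None:
            break  # idle: nothing launched and nothing pending
        last_batch = batch
```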
python/sglang/srt/disaggregation/prefill.py
@@ -338,21 +338,8 @@ class SchedulerDisaggregationPrefillMixin:
             result = self.run_batch(batch)
             self.result_queue.append((batch.copy(), result))

-            if self.last_batch is None:
-                # Create a dummy first batch to start the pipeline for overlap schedule.
-                # It is now used for triggering the sampling_info_done event.
-                tmp_batch = ScheduleBatch(
-                    reqs=None,
-                    forward_mode=ForwardMode.DUMMY_FIRST,
-                    next_batch_sampling_info=self.tp_worker.cur_sampling_info,
-                )
-                self.set_next_batch_sampling_info_done(tmp_batch)
-
             if self.last_batch:
                 tmp_batch, tmp_result = self.result_queue.popleft()
-                tmp_batch.next_batch_sampling_info = (
-                    self.tp_worker.cur_sampling_info if batch else None
-                )
                 self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)

         if len(self.disagg_prefill_inflight_queue) > 0:
@@ -491,8 +478,6 @@ class SchedulerDisaggregationPrefillMixin:
             if self.enable_overlap:
                 self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)

-        # We need to remove the sync in the following function for overlap schedule.
-        self.set_next_batch_sampling_info_done(batch)
         self.maybe_send_health_check_signal()

     def process_disagg_prefill_inflight_queue(
python/sglang/srt/managers/schedule_batch.py
@@ -891,7 +891,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Sampling info
     sampling_info: SamplingBatchInfo = None
-    next_batch_sampling_info: SamplingBatchInfo = None

     # Batched arguments to model runner
     input_ids: torch.Tensor = None  # shape: [b], int64
python/sglang/srt/managers/scheduler.py
@@ -1012,22 +1012,9 @@ class Scheduler(
                 result = self.run_batch(batch)
                 self.result_queue.append((batch.copy(), result))

-                if self.last_batch is None:
-                    # Create a dummy first batch to start the pipeline for overlap schedule.
-                    # It is now used for triggering the sampling_info_done event.
-                    tmp_batch = ScheduleBatch(
-                        reqs=None,
-                        forward_mode=ForwardMode.DUMMY_FIRST,
-                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
-                    )
-                    self.process_batch_result(tmp_batch, None)
-
                 if self.last_batch:
                     # Process the results of the last batch
                     tmp_batch, tmp_result = self.result_queue.popleft()
-                    tmp_batch.next_batch_sampling_info = (
-                        self.tp_worker.cur_sampling_info if batch else None
-                    )
                     self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
                 # When the server is idle, do self-check and re-init some states
@@ -2100,7 +2087,7 @@ class Scheduler(
             self.record_batch_in_overlap(model_worker_batch)

         # Sampling info will be modified during forward
-        model_worker_batch.sampling_info = self.tp_worker.cur_sampling_info = (
+        model_worker_batch.sampling_info = (
             model_worker_batch.sampling_info.copy_for_forward()
         )
@@ -2219,9 +2206,6 @@ class Scheduler(
         if self.enable_overlap:
             if result.copy_done is not None:
                 result.copy_done.synchronize()
-            self.set_next_batch_sampling_info_done(batch)
-        elif batch.forward_mode.is_dummy_first():
-            self.set_next_batch_sampling_info_done(batch)

         self.maybe_send_health_check_signal()
@@ -2431,13 +2415,6 @@ class Scheduler(
                 self._add_request_to_queue(req)

         self.grammar_queue = self.grammar_queue[num_ready_reqs:]

-    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
-        if batch.next_batch_sampling_info:
-            if batch.next_batch_sampling_info.grammars is not None:
-                batch.next_batch_sampling_info.update_regex_vocab_mask()
-                self.default_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
-
     def watchdog_thread(self):
         """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
         self.watchdog_last_forward_ct = 0
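The deleted `set_next_batch_sampling_info_done` ordered things carefully: build the grammar vocab mask on the device, synchronize the default stream, and only then set the CPU-side event, so a consumer thread never observed a partially written mask. A hedged sketch of that publish-after-sync ordering using plain PyTorch; the helper name and mask contents are illustrative, and the sketch falls back to CPU when CUDA is unavailable:

```python
import threading

import torch


def publish_vocab_mask(vocab_size: int, done: threading.Event) -> torch.Tensor:
    """Producer side: fill a mask on the device, then signal CPU consumers."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
    mask[::2] = True  # stand-in for update_regex_vocab_mask()
    if device == "cuda":
        # Ensure the kernels that wrote `mask` have finished before telling
        # another thread the data is ready (the role of
        # self.default_stream.synchronize() in the removed method).
        torch.cuda.current_stream().synchronize()
    done.set()
    return mask
```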
python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -173,8 +173,6 @@ class SchedulerOutputProcessorMixin:
                     )
                     logprob_pt += num_input_logprobs

-            self.set_next_batch_sampling_info_done(batch)
-
         else:  # embedding or reward model
             embeddings = result.embeddings.tolist()
@@ -295,7 +293,6 @@ class SchedulerOutputProcessorMixin:
                         self.abort_request(AbortReq(rid=req.rid))
                     req.grammar.finished = req.finished()

-        self.set_next_batch_sampling_info_done(batch)
         self.stream_output(batch.reqs, batch.return_logprob)
         self.token_to_kv_pool_allocator.free_group_end()
python/sglang/srt/managers/tp_worker_overlap_thread.py
deleted 100644 → 0 (entire file removed; contents below as of parent 79d34951)
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A tensor parallel worker."""

from __future__ import annotations

import dataclasses
import logging
import signal
import threading
from queue import Queue
from typing import TYPE_CHECKING, List, Optional, Tuple

import psutil
import torch

from sglang.srt.managers.io_struct import (
    DestroyWeightsUpdateGroupReqInput,
    GetWeightsByNameReqInput,
    InitWeightsSendGroupForRemoteInstanceReqInput,
    InitWeightsUpdateGroupReqInput,
    LoadLoRAAdapterReqInput,
    SendWeightsToRemoteInstanceReqInput,
    UnloadLoRAAdapterReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.overlap_utils import FutureMap
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import DynamicGradMode
from sglang.utils import get_exception_traceback

if TYPE_CHECKING:
    from sglang.srt.managers.cache_controller import LayerDoneCounter

logger = logging.getLogger(__name__)


class TpModelWorkerClient:
    """A tensor parallel model worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        moe_ep_rank: int,
        pp_rank: int,
        dp_rank: Optional[int],
        nccl_port: int,
    ):
        # Load the model
        self.worker = TpModelWorker(
            server_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank, nccl_port
        )
        self.max_running_requests = self.worker.max_running_requests
        self.device = self.worker.device
        self.gpu_id = gpu_id

        # Init future mappings
        self.future_map = FutureMap(self.max_running_requests, self.device)

        # Launch threads
        self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
        self.output_queue = Queue()
        self.forward_stream = torch.get_device_module(self.device).Stream()
        self.forward_thread = threading.Thread(
            target=self.forward_thread_func,
        )
        self.forward_thread.start()
        self.parent_process = psutil.Process().parent()
        self.scheduler_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.scheduler_stream.synchronize = lambda: None  # No-op for CPU

        self.hicache_layer_transfer_counter = None

    def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
        self.hicache_layer_transfer_counter = counter

    def get_worker_info(self):
        return self.worker.get_worker_info()

    def get_tokens_per_layer_info(self):
        return self.worker.get_tokens_per_layer_info()

    @property
    def sliding_window_size(self) -> Optional[int]:
        return self.worker.sliding_window_size

    @property
    def is_hybrid(self) -> bool:
        return self.worker.is_hybrid

    def get_pad_input_ids_func(self):
        return self.worker.get_pad_input_ids_func()

    def get_tp_group(self):
        return self.worker.get_tp_group()

    def get_attention_tp_group(self):
        return self.worker.get_attention_tp_group()

    def get_attention_tp_cpu_group(self):
        return self.worker.get_attention_tp_cpu_group()

    def get_memory_pool(self):
        return (
            self.worker.model_runner.req_to_token_pool,
            self.worker.model_runner.token_to_kv_pool_allocator,
        )

    def get_kv_cache(self):
        return self.worker.model_runner.token_to_kv_pool

    def forward_thread_func(self):
        try:
            with torch.get_device_module(self.device).stream(self.forward_stream):
                self.forward_thread_func_()
        except Exception:
            traceback = get_exception_traceback()
            logger.error(f"TpModelWorkerClient hit an exception: {traceback}")
            self.parent_process.send_signal(signal.SIGQUIT)

    @DynamicGradMode()
    def forward_thread_func_(self):
        batch_pt = 0
        batch_lists: List = [None] * 2

        while True:
            model_worker_batch, future_map_ct, sync_event = self.input_queue.get()
            if not model_worker_batch:
                break

            sync_event.wait()

            # Keep a reference of model_worker_batch by storing it into a list.
            # Otherwise, the tensor members of model_worker_batch will be released
            # by pytorch and cause CUDA illegal memory access errors.
            batch_lists[batch_pt % 2] = model_worker_batch
            batch_pt += 1

            # Create event
            copy_done = torch.get_device_module(self.device).Event()

            # Resolve future tokens in the input
            self.future_map.resolve_future(model_worker_batch)

            # Run forward
            forward_batch_output = self.worker.forward_batch_generation(
                model_worker_batch,
                model_worker_batch.launch_done,
            )

            logits_output, next_token_ids, can_run_cuda_graph = (
                forward_batch_output.logits_output,
                forward_batch_output.next_token_ids,
                forward_batch_output.can_run_cuda_graph,
            )

            # Update the future token ids map
            bs = len(model_worker_batch.seq_lens)
            if model_worker_batch.is_prefill_only:
                # For prefill-only requests, create dummy token IDs on CPU
                next_token_ids = torch.zeros(bs, dtype=torch.long)

            # store the future indices into future map
            self.future_map.store_to_map(future_map_ct, bs, next_token_ids)

            # Copy results to the CPU
            if model_worker_batch.return_logprob:
                if logits_output.next_token_logprobs is not None:
                    logits_output.next_token_logprobs = (
                        logits_output.next_token_logprobs.to("cpu", non_blocking=True)
                    )
                if logits_output.input_token_logprobs is not None:
                    logits_output.input_token_logprobs = (
                        logits_output.input_token_logprobs.to("cpu", non_blocking=True)
                    )
            if logits_output.hidden_states is not None:
                logits_output.hidden_states = logits_output.hidden_states.to(
                    "cpu", non_blocking=True
                )
            # Only copy to CPU if not already on CPU
            if next_token_ids.device.type != "cpu":
                next_token_ids = next_token_ids.to("cpu", non_blocking=True)
            copy_done.record()

            self.output_queue.put(
                (copy_done, logits_output, next_token_ids, can_run_cuda_graph)
            )

    def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
        """
        This function is called to resolve the last batch result and
        wait for the current batch to be launched. Used in overlap mode.
        """
        copy_done, logits_output, next_token_ids, can_run_cuda_graph = (
            self.output_queue.get()
        )

        if launch_done is not None:
            launch_done.wait()
        copy_done.synchronize()

        if logits_output.next_token_logprobs is not None:
            logits_output.next_token_logprobs = (
                logits_output.next_token_logprobs.tolist()
            )
        if logits_output.input_token_logprobs is not None:
            logits_output.input_token_logprobs = tuple(
                logits_output.input_token_logprobs.tolist()
            )
        next_token_ids = next_token_ids.tolist()
        return logits_output, next_token_ids, can_run_cuda_graph

    def forward_batch_generation(
        self, model_worker_batch: ModelWorkerBatch
    ) -> ForwardBatchOutput:
        # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
        model_worker_batch.sampling_info = self.cur_sampling_info = (
            model_worker_batch.sampling_info.copy_for_forward()
        )

        # A cuda stream sync here to avoid the cuda illegal memory access error.
        sync_event = torch.get_device_module(self.device).Event()
        sync_event.record(self.scheduler_stream)

        # Push a new batch to the queue
        bs = len(model_worker_batch.seq_lens)
        cur_future_map_ct = self.future_map.update_ct(bs)
        self.input_queue.put((model_worker_batch, cur_future_map_ct, sync_event))

        # get this forward batch's future token ids
        future_next_token_ids = self.future_map.update_next_future(
            cur_future_map_ct, bs
        )
        return ForwardBatchOutput(
            next_token_ids=future_next_token_ids,
            can_run_cuda_graph=False,
        )

    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        success, message = self.worker.update_weights_from_disk(recv_req)
        return success, message

    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        success, message = self.worker.init_weights_update_group(recv_req)
        return success, message

    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
        success, message = self.worker.destroy_weights_update_group(recv_req)
        return success, message

    def init_weights_send_group_for_remote_instance(
        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
    ):
        success, message = self.worker.init_weights_send_group_for_remote_instance(
            recv_req
        )
        return success, message

    def send_weights_to_remote_instance(
        self, recv_req: SendWeightsToRemoteInstanceReqInput
    ):
        success, message = self.worker.send_weights_to_remote_instance(recv_req)
        return success, message

    def update_weights_from_distributed(
        self, recv_req: UpdateWeightsFromDistributedReqInput
    ):
        success, message = self.worker.update_weights_from_distributed(recv_req)
        return success, message

    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        success, message = self.worker.update_weights_from_tensor(recv_req)
        return success, message

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        return self.worker.get_weights_by_name(recv_req)

    def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
        return self.worker.load_lora_adapter(recv_req)

    def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
        return self.worker.unload_lora_adapter(recv_req)

    def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
        return self.worker.can_run_lora_batch(lora_ids)

    def __delete__(self):
        self.input_queue.put((None, None))
        self.copy_queue.put((None, None, None))
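This deleted worker relied on `FutureMap` from `sglang.srt.managers.overlap_utils` (not part of this diff): `forward_batch_generation` returns placeholder token ids immediately, the forward thread later calls `store_to_map` with the real ids, and the next batch resolves placeholders via `resolve_future`. The class below is purely an illustrative stand-in for that idea (a small ring buffer addressed by negative placeholder ids) and is not the actual `FutureMap` implementation:

```python
import torch


class TinyFutureMap:
    """Illustrative placeholder-token map: negative ids index a ring buffer."""

    def __init__(self, capacity: int, device: str = "cpu"):
        self.buf = torch.zeros(capacity, dtype=torch.long, device=device)
        self.ct = 0
        self.capacity = capacity

    def alloc(self, bs: int) -> torch.Tensor:
        """Hand out placeholder ids (-(index + 1)) for a batch of size bs."""
        idx = torch.arange(self.ct, self.ct + bs) % self.capacity
        self.ct = (self.ct + bs) % self.capacity
        return -(idx + 1)

    def store(self, placeholders: torch.Tensor, real_ids: torch.Tensor) -> None:
        """Forward thread: write the real token ids behind the placeholders."""
        self.buf[(-placeholders - 1)] = real_ids

    def resolve(self, ids: torch.Tensor) -> torch.Tensor:
        """Replace any placeholder ids in `ids` with the stored real ids."""
        out = ids.clone()
        neg = out < 0
        out[neg] = self.buf[(-out[neg] - 1)]
        return out
```

For example, `fm = TinyFutureMap(8)`, `ph = fm.alloc(2)`, `fm.store(ph, torch.tensor([42, 7]))`, then `fm.resolve(ph)` yields `tensor([42, 7])`.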
python/sglang/srt/model_executor/forward_batch_info.py
@@ -75,10 +75,6 @@ class ForwardMode(IntEnum):
     # Used in speculative decoding: extend a batch in the draft model.
     DRAFT_EXTEND = auto()

-    # A dummy first batch to start the pipeline for overlap scheduler.
-    # It is now used for triggering the sampling_info_done event for the first prefill batch.
-    DUMMY_FIRST = auto()
-
     # Split Prefill for PD multiplexing
     SPLIT_PREFILL = auto()
@@ -128,9 +124,6 @@ class ForwardMode(IntEnum):
     def is_cpu_graph(self):
         return self == ForwardMode.DECODE

-    def is_dummy_first(self):
-        return self == ForwardMode.DUMMY_FIRST
-
     def is_split_prefill(self):
         return self == ForwardMode.SPLIT_PREFILL
python/sglang/srt/model_executor/model_runner.py
@@ -2057,15 +2057,11 @@ class ModelRunner:
     def _preprocess_logits(
         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
     ):
-        # Apply logit bias
-        if sampling_info.sampling_info_done:
-            # Overlap mode: the function update_regex_vocab_mask was executed
-            # in process_batch_result of the last batch.
-            if sampling_info.grammars:
-                sampling_info.sampling_info_done.wait()
-        else:
-            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
-            sampling_info.update_regex_vocab_mask()
+        # NOTE: In overlap mode, the function update_regex_vocab_mask (in sample)
+        # was executed after we processed last batch's results.
+
+        # Calculate logits bias and apply it to next_token_logits.
+        sampling_info.update_regex_vocab_mask()
         sampling_info.apply_logits_bias(logits_output.next_token_logits)

     def sample(
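With the event gone, `_preprocess_logits` now always builds the grammar vocab mask inline via `update_regex_vocab_mask()` before applying the logits bias, instead of sometimes waiting on `sampling_info_done`. As a generic illustration (not sglang's implementation) of what such a vocab mask does to the logits before sampling:

```python
import torch


def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> torch.Tensor:
    """Disallow masked-out tokens by pushing their logits to -inf.

    logits:     [batch, vocab] float tensor
    vocab_mask: [batch, vocab] bool tensor, True = token is forbidden
    """
    return logits.masked_fill(vocab_mask, float("-inf"))


# Example: forbid token 2 for the first request only.
logits = torch.randn(2, 5)
mask = torch.zeros(2, 5, dtype=torch.bool)
mask[0, 2] = True
print(apply_vocab_mask(logits, mask)[0, 2])  # -inf
```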
python/sglang/srt/sampling/sampling_batch_info.py
@@ -44,12 +44,9 @@ class SamplingBatchInfo:
     vocab_mask: Optional[torch.Tensor] = None
     apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None

-    # An event used for overlap schedule
-    sampling_info_done: Optional[threading.Event] = None
-
     # Penalizer
     penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
-    linear_penalty: torch.Tensor = None
+    acc_linear_penalties: torch.Tensor = None  # Used in the overlap mode

     # Whether any request has custom logit processor
     has_custom_logit_processor: bool = False
@@ -217,19 +214,19 @@ class SamplingBatchInfo:
     def update_penalties(self):
         if self.penalizer_orchestrator.is_required:
-            self.linear_penalty = torch.zeros(
+            self.acc_linear_penalties = torch.zeros(
                 (len(self.temperatures), self.vocab_size),
                 dtype=torch.float32,
                 device=self.temperatures.device,
             )
-            self.penalizer_orchestrator.apply(self.linear_penalty)
+            self.penalizer_orchestrator.apply(self.acc_linear_penalties)
         else:
-            self.linear_penalty = None
+            self.acc_linear_penalties = None

     def apply_logits_bias(self, logits: torch.Tensor):
-        if self.linear_penalty is not None:
+        if self.acc_linear_penalties is not None:
             # Used in the overlap mode
-            logits.add_(self.linear_penalty)
+            logits.add_(self.acc_linear_penalties)

         if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required:
             # Used in the non-overlap mode
@@ -373,11 +370,7 @@ class SamplingBatchInfo:
     def copy_for_forward(self):
         # Accumulate the penalty into a pre-allocated buffer to get rid of the dependency of `penalizer_orchestrator` later
         self.update_penalties()
-        return dataclasses.replace(
-            self,
-            sampling_info_done=threading.Event(),
-            penalizer_orchestrator=None,
-        )
+        return dataclasses.replace(self, penalizer_orchestrator=None)

     def merge_bias_tensor(
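The rename to `acc_linear_penalties` keeps the existing trick: `copy_for_forward` first bakes all active penalizers into one dense `[batch, vocab]` tensor, so the copy handed to the forward thread can drop its `penalizer_orchestrator` reference via `dataclasses.replace(self, penalizer_orchestrator=None)`. A small sketch of that accumulate-then-detach step with a hypothetical penalizer interface (illustrative names, not the penaltylib API):

```python
import dataclasses
from typing import Callable, List, Optional

import torch

# A penalizer adds its penalty into the accumulation buffer in place,
# e.g. lambda buf: buf[:, 0].add_(-1.0) to discourage token 0.
Penalizer = Callable[[torch.Tensor], None]


@dataclasses.dataclass
class MiniSamplingInfo:
    temperatures: torch.Tensor                        # [batch, 1]
    vocab_size: int
    penalizers: Optional[List[Penalizer]] = None      # stand-in for penalizer_orchestrator
    acc_linear_penalties: Optional[torch.Tensor] = None

    def update_penalties(self) -> None:
        if self.penalizers:
            self.acc_linear_penalties = torch.zeros(
                (len(self.temperatures), self.vocab_size), dtype=torch.float32
            )
            for penalizer in self.penalizers:
                penalizer(self.acc_linear_penalties)
        else:
            self.acc_linear_penalties = None

    def copy_for_forward(self) -> "MiniSamplingInfo":
        # Bake the penalties into a plain tensor, then drop the penalizer
        # objects so the forward-thread copy has no dependency on them.
        self.update_penalties()
        return dataclasses.replace(self, penalizers=None)

    def apply_logits_bias(self, logits: torch.Tensor) -> None:
        if self.acc_linear_penalties is not None:
            logits.add_(self.acc_linear_penalties)
```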