Unverified commit bf8f7a94, authored by Scott Lee and committed by GitHub

Add per-request retraction count (#11177)

parent 81a632ac
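
For context, a minimal client-side sketch of how the new count can be read back. This assumes a locally running server at the default address and the standard /generate request shape; only the "total_retractions" key in meta_info comes from this change, everything else below is illustrative.

import requests

# Minimal sketch, assuming a server at the default local address; the endpoint
# and sampling params are the usual /generate API, only "total_retractions" is new here.
resp = requests.post(
    "http://localhost:30000/generate",
    json={"text": "Hello", "sampling_params": {"max_new_tokens": 16}},
)
meta_info = resp.json()["meta_info"]
print("Retractions for this request:", meta_info["total_retractions"])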
@@ -275,6 +275,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
                 output_hidden_states=recv_obj.output_hidden_states,
                 placeholder_tokens_idx=None,
                 placeholder_tokens_val=None,
+                retraction_counts=recv_obj.retraction_counts,
                 token_steps=recv_obj.token_steps,
             )
......
@@ -860,6 +860,9 @@ class BatchTokenIDOutput(BaseBatchReq):
     placeholder_tokens_idx: List[Optional[List[int]]]
     placeholder_tokens_val: List[Optional[List[int]]]
+    # Number of times each request was retracted.
+    retraction_counts: List[int]
     # The trainer step id. Used to know which step's weights are used for sampling.
     token_steps: List[List[int]] = None
@@ -936,6 +939,9 @@ class BatchStrOutput(BaseBatchReq):
     placeholder_tokens_idx: List[Optional[List[int]]]
     placeholder_tokens_val: List[Optional[List[int]]]
+    # Number of times each request was retracted.
+    retraction_counts: List[int]
     # The trainer step id. Used to know which step's weights are used for sampling.
     token_steps: List[List[int]] = None
@@ -978,6 +984,9 @@ class BatchEmbeddingOutput(BaseBatchReq):
     placeholder_tokens_idx: List[Optional[List[int]]]
     placeholder_tokens_val: List[Optional[List[int]]]
+    # Number of times each request was retracted.
+    retraction_counts: List[int]
 @dataclass
 class ClearHiCacheReqInput(BaseReq):
......
@@ -334,6 +334,11 @@ def _handle_output_by_index(output, i):
             ),
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
+            retraction_counts=(
+                [output.retraction_counts[i]]
+                if len(output.retraction_counts) > i
+                else None
+            ),
             token_steps=([output.token_steps[i]] if output.token_steps else None),
         )
     elif isinstance(output, BatchMultimodalOutput):
......
@@ -623,6 +623,9 @@ class Req:
         # This is used to compute the acceptance rate and average acceptance length per request.
         self.spec_accepted_tokens = 0
+        # The number of times this request has been retracted / preempted.
+        self.retraction_count = 0
         # For metrics
         self.metrics_collector = metrics_collector
         self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
@@ -883,6 +886,10 @@ class Req:
             return
     def reset_for_retract(self):
+        # Increment retraction count before resetting other state. We should not reset this
+        # since we are tracking the total number of retractions for each request.
+        self.retraction_count += 1
         self.prefix_indices = torch.empty((0,), dtype=torch.int64)
         self.last_node = None
         self.swa_uuid_for_lock = None
......
@@ -730,6 +730,7 @@ class SchedulerOutputProcessorMixin:
         cached_tokens = []
         spec_verify_ct = []
         spec_accepted_tokens = []
+        retraction_counts = []
         output_hidden_states = None
         if return_logprob:
@@ -831,6 +832,8 @@ class SchedulerOutputProcessorMixin:
                 completion_tokens.append(len(output_ids_))
                 cached_tokens.append(req.cached_tokens)
+                retraction_counts.append(req.retraction_count)
                 if not self.spec_algorithm.is_none():
                     spec_verify_ct.append(req.spec_verify_ct)
                     spec_accepted_tokens.append(req.spec_accepted_tokens)
@@ -953,6 +956,7 @@ class SchedulerOutputProcessorMixin:
                     http_worker_ipcs=http_worker_ipcs,
                     placeholder_tokens_idx=None,
                     placeholder_tokens_val=None,
+                    retraction_counts=retraction_counts,
                 )
             )
@@ -964,6 +968,7 @@ class SchedulerOutputProcessorMixin:
         embeddings = []
         prompt_tokens = []
         cached_tokens = []
+        retraction_counts = []
         for req in reqs:
             if req.finished():
                 rids.append(req.rid)
@@ -972,6 +977,7 @@ class SchedulerOutputProcessorMixin:
                 embeddings.append(req.embedding)
                 prompt_tokens.append(len(req.origin_input_ids))
                 cached_tokens.append(req.cached_tokens)
+                retraction_counts.append(req.retraction_count)
         self.send_to_detokenizer.send_output(
             BatchEmbeddingOutput(
                 finished_reasons,
@@ -982,5 +988,6 @@ class SchedulerOutputProcessorMixin:
                 http_worker_ipcs=http_worker_ipcs,
                 placeholder_tokens_idx=None,
                 placeholder_tokens_val=None,
+                retraction_counts=retraction_counts,
             )
         )
@@ -1390,6 +1390,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             "finish_reason": recv_obj.finished_reasons[i],
             "prompt_tokens": recv_obj.prompt_tokens[i],
             "weight_version": self.server_args.weight_version,
+            "total_retractions": recv_obj.retraction_counts[i],
         }
         if getattr(state.obj, "return_logprob", False):
@@ -1682,6 +1683,14 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 or state.obj.sampling_params.get("ebnf", None)
                 or state.obj.sampling_params.get("structural_tag", None)
             )
+            retraction_count = (
+                recv_obj.retraction_counts[i]
+                if getattr(recv_obj, "retraction_counts", None)
+                and i < len(recv_obj.retraction_counts)
+                else 0
+            )
             self.metrics_collector.observe_one_finished_request(
                 labels,
                 recv_obj.prompt_tokens[i],
@@ -1689,6 +1698,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 recv_obj.cached_tokens[i],
                 state.finished_time - state.created_time,
                 has_grammar,
+                retraction_count,
             )
     def dump_requests(self, state: ReqState, out_dict: dict):
......
@@ -811,6 +811,34 @@ class TokenizerMetricsCollector:
             buckets=bucket_e2e_request_latency,
         )
+        # Retraction count histogram
+        self.num_retractions = Histogram(
+            name="sglang:num_retractions",
+            documentation="Histogram of retraction counts per request.",
+            labelnames=labels.keys(),
+            buckets=[
+                0,
+                1,
+                2,
+                3,
+                4,
+                5,
+                6,
+                7,
+                8,
+                9,
+                10,
+                15,
+                20,
+                25,
+                30,
+                40,
+                50,
+                75,
+                100,
+            ],
+        )
     def observe_one_finished_request(
         self,
         labels: Dict[str, str],
@@ -819,6 +847,7 @@ class TokenizerMetricsCollector:
         cached_tokens: int,
         e2e_latency: float,
         has_grammar: bool,
+        retraction_count: int,
     ):
         self.prompt_tokens_total.labels(**labels).inc(prompt_tokens)
         self.generation_tokens_total.labels(**labels).inc(generation_tokens)
@@ -833,6 +862,7 @@ class TokenizerMetricsCollector:
         self.generation_tokens_histogram.labels(**labels).observe(
             float(generation_tokens)
         )
+        self.num_retractions.labels(**labels).observe(retraction_count)
     def observe_time_to_first_token(self, labels: Dict[str, str], value: float):
         self.histogram_time_to_first_token.labels(**labels).observe(value)
......
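
For readers looking only at the Prometheus side, a standalone sketch of the histogram added above, using the same name, documentation string, and buckets as the diff. The label name "model_name" and the observed value are illustrative; the real collector passes whatever labels TokenizerMetricsCollector is configured with.

from prometheus_client import Histogram

# Standalone sketch of the metric added in this change; label names are illustrative.
num_retractions = Histogram(
    name="sglang:num_retractions",
    documentation="Histogram of retraction counts per request.",
    labelnames=["model_name"],
    buckets=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25, 30, 40, 50, 75, 100],
)

# When a request finishes, the collector observes that request's total retraction count.
num_retractions.labels(model_name="demo-model").observe(3)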