Unverified commit bf8f7a94 authored by Scott Lee, committed by GitHub

Add per-request retraction count (#11177)

parent 81a632ac
@@ -275,6 +275,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
                 output_hidden_states=recv_obj.output_hidden_states,
                 placeholder_tokens_idx=None,
                 placeholder_tokens_val=None,
+                retraction_counts=recv_obj.retraction_counts,
                 token_steps=recv_obj.token_steps,
             )
......
@@ -860,6 +860,9 @@ class BatchTokenIDOutput(BaseBatchReq):
     placeholder_tokens_idx: List[Optional[List[int]]]
     placeholder_tokens_val: List[Optional[List[int]]]
 
+    # Number of times each request was retracted.
+    retraction_counts: List[int]
+
     # The trainer step id. Used to know which step's weights are used for sampling.
     token_steps: List[List[int]] = None
@@ -936,6 +939,9 @@ class BatchStrOutput(BaseBatchReq):
     placeholder_tokens_idx: List[Optional[List[int]]]
     placeholder_tokens_val: List[Optional[List[int]]]
 
+    # Number of times each request was retracted.
+    retraction_counts: List[int]
+
     # The trainer step id. Used to know which step's weights are used for sampling.
     token_steps: List[List[int]] = None
@@ -978,6 +984,9 @@ class BatchEmbeddingOutput(BaseBatchReq):
     placeholder_tokens_idx: List[Optional[List[int]]]
     placeholder_tokens_val: List[Optional[List[int]]]
 
+    # Number of times each request was retracted.
+    retraction_counts: List[int]
+
 
 @dataclass
 class ClearHiCacheReqInput(BaseReq):
......
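
Note: these output structs carry one entry per request, index-aligned across all list fields, and `retraction_counts` follows the same convention. A minimal, self-contained sketch of that contract (most fields elided; the class name `BatchStrOutputSketch` is invented here for illustration):

```python
# Minimal sketch of the parallel-list contract used by the batch output structs.
# Field names mirror the diff; the surrounding fields are elided for brevity.
from dataclasses import dataclass, field
from typing import List


@dataclass
class BatchStrOutputSketch:
    # One entry per request in the batch, index-aligned across all fields.
    rids: List[str] = field(default_factory=list)
    prompt_tokens: List[int] = field(default_factory=list)
    # Number of times each request was retracted (index-aligned with rids).
    retraction_counts: List[int] = field(default_factory=list)


out = BatchStrOutputSketch(
    rids=["req-0", "req-1"],
    prompt_tokens=[12, 48],
    retraction_counts=[0, 2],  # req-1 was retracted twice
)
assert len(out.retraction_counts) == len(out.rids)
```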
@@ -334,6 +334,11 @@ def _handle_output_by_index(output, i):
             ),
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
+            retraction_counts=(
+                [output.retraction_counts[i]]
+                if len(output.retraction_counts) > i
+                else None
+            ),
             token_steps=([output.token_steps[i]] if output.token_steps else None),
         )
     elif isinstance(output, BatchMultimodalOutput):
......
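
The guard above narrows each batch-level list down to a one-element list for request `i`, falling back to `None` when the sender did not populate the field. A hedged sketch of just that slicing logic (the helper name is hypothetical):

```python
# Sketch of the per-index slicing pattern used in _handle_output_by_index:
# when fanning a batch output back out to a single HTTP worker, each list
# field is narrowed to a one-element list for request i, with a length guard
# so shorter-than-expected payloads do not raise IndexError.
from typing import List, Optional


def slice_retraction_counts(retraction_counts: List[int], i: int) -> Optional[List[int]]:
    # Same guard as the diff: fall back to None if the list is too short.
    return [retraction_counts[i]] if len(retraction_counts) > i else None


print(slice_retraction_counts([0, 3, 1], 1))  # [3]
print(slice_retraction_counts([], 1))         # None
```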
@@ -623,6 +623,9 @@ class Req:
         # This is used to compute the acceptance rate and average acceptance length per request.
         self.spec_accepted_tokens = 0
 
+        # The number of times this request has been retracted / preempted.
+        self.retraction_count = 0
+
         # For metrics
         self.metrics_collector = metrics_collector
         self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
@@ -883,6 +886,10 @@ class Req:
             return
 
     def reset_for_retract(self):
+        # Increment retraction count before resetting other state. We should not reset this
+        # since we are tracking the total number of retractions for each request.
+        self.retraction_count += 1
+
         self.prefix_indices = torch.empty((0,), dtype=torch.int64)
         self.last_node = None
         self.swa_uuid_for_lock = None
......
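
The key invariant here is that `reset_for_retract` clears per-attempt scheduling state but treats `retraction_count` as cumulative over the request's lifetime. A toy model of that lifecycle, assuming a much-simplified `Req`:

```python
# Toy model of the counter's lifecycle. Each call to reset_for_retract()
# clears per-attempt state but deliberately preserves (and increments)
# retraction_count, so the final value is the total number of retractions.
class ReqSketch:
    def __init__(self) -> None:
        self.retraction_count = 0
        self.prefix_indices: list[int] = []

    def reset_for_retract(self) -> None:
        self.retraction_count += 1  # never reset; cumulative per request
        self.prefix_indices = []    # per-attempt state is cleared


req = ReqSketch()
for _ in range(3):  # request is preempted three times under memory pressure
    req.reset_for_retract()
assert req.retraction_count == 3
```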
@@ -730,6 +730,7 @@ class SchedulerOutputProcessorMixin:
             cached_tokens = []
             spec_verify_ct = []
             spec_accepted_tokens = []
+            retraction_counts = []
             output_hidden_states = None
 
             if return_logprob:
@@ -831,6 +832,8 @@ class SchedulerOutputProcessorMixin:
                 completion_tokens.append(len(output_ids_))
                 cached_tokens.append(req.cached_tokens)
 
+                retraction_counts.append(req.retraction_count)
+
                 if not self.spec_algorithm.is_none():
                     spec_verify_ct.append(req.spec_verify_ct)
                     spec_accepted_tokens.append(req.spec_accepted_tokens)
@@ -953,6 +956,7 @@ class SchedulerOutputProcessorMixin:
                     http_worker_ipcs=http_worker_ipcs,
                     placeholder_tokens_idx=None,
                     placeholder_tokens_val=None,
+                    retraction_counts=retraction_counts,
                 )
             )
@@ -964,6 +968,7 @@ class SchedulerOutputProcessorMixin:
         embeddings = []
         prompt_tokens = []
         cached_tokens = []
+        retraction_counts = []
         for req in reqs:
             if req.finished():
                 rids.append(req.rid)
@@ -972,6 +977,7 @@ class SchedulerOutputProcessorMixin:
                 embeddings.append(req.embedding)
                 prompt_tokens.append(len(req.origin_input_ids))
                 cached_tokens.append(req.cached_tokens)
+                retraction_counts.append(req.retraction_count)
 
         self.send_to_detokenizer.send_output(
             BatchEmbeddingOutput(
                 finished_reasons,
@@ -982,5 +988,6 @@ class SchedulerOutputProcessorMixin:
                 http_worker_ipcs=http_worker_ipcs,
                 placeholder_tokens_idx=None,
                 placeholder_tokens_val=None,
+                retraction_counts=retraction_counts,
             )
         )
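
All of the per-request lists above are appended to in the same loop over finished requests, so they stay index-aligned when packed into `BatchTokenIDOutput` / `BatchEmbeddingOutput`. A minimal sketch of that lockstep pattern (the `FinishedReq` dataclass is invented for illustration):

```python
# Sketch of how the scheduler assembles per-request lists in lockstep:
# every list appended to in the same loop stays index-aligned when the
# batch output is constructed.
from dataclasses import dataclass


@dataclass
class FinishedReq:
    rid: str
    cached_tokens: int
    retraction_count: int


reqs = [FinishedReq("a", 10, 0), FinishedReq("b", 7, 2)]
rids, cached_tokens, retraction_counts = [], [], []
for req in reqs:
    rids.append(req.rid)
    cached_tokens.append(req.cached_tokens)
    retraction_counts.append(req.retraction_count)  # new field travels with the rest

assert retraction_counts == [0, 2]
```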
@@ -1390,6 +1390,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 "finish_reason": recv_obj.finished_reasons[i],
                 "prompt_tokens": recv_obj.prompt_tokens[i],
                 "weight_version": self.server_args.weight_version,
+                "total_retractions": recv_obj.retraction_counts[i],
             }
 
             if getattr(state.obj, "return_logprob", False):
@@ -1682,6 +1683,14 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 or state.obj.sampling_params.get("ebnf", None)
                 or state.obj.sampling_params.get("structural_tag", None)
             )
+
+            retraction_count = (
+                recv_obj.retraction_counts[i]
+                if getattr(recv_obj, "retraction_counts", None)
+                and i < len(recv_obj.retraction_counts)
+                else 0
+            )
+
             self.metrics_collector.observe_one_finished_request(
                 labels,
                 recv_obj.prompt_tokens[i],
@@ -1689,6 +1698,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 recv_obj.cached_tokens[i],
                 state.finished_time - state.created_time,
                 has_grammar,
+                retraction_count,
             )
 
     def dump_requests(self, state: ReqState, out_dict: dict):
......
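
The `getattr` plus bounds check makes the metrics path tolerant of payloads from senders that predate the new field. A standalone sketch of that defensive read (the helper name is hypothetical):

```python
# Sketch of the defensive read used before recording metrics: getattr()
# covers payloads without the attribute at all, and the bounds check covers
# shorter-than-expected lists, so the metric falls back to 0 instead of raising.
from types import SimpleNamespace


def read_retraction_count(recv_obj: object, i: int) -> int:
    counts = getattr(recv_obj, "retraction_counts", None)
    return counts[i] if counts and i < len(counts) else 0


new_payload = SimpleNamespace(retraction_counts=[0, 4])
old_payload = SimpleNamespace()  # no retraction_counts attribute at all
print(read_retraction_count(new_payload, 1))  # 4
print(read_retraction_count(old_payload, 1))  # 0
```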
@@ -811,6 +811,34 @@ class TokenizerMetricsCollector:
             buckets=bucket_e2e_request_latency,
         )
 
+        # Retraction count histogram
+        self.num_retractions = Histogram(
+            name="sglang:num_retractions",
+            documentation="Histogram of retraction counts per request.",
+            labelnames=labels.keys(),
+            buckets=[
+                0,
+                1,
+                2,
+                3,
+                4,
+                5,
+                6,
+                7,
+                8,
+                9,
+                10,
+                15,
+                20,
+                25,
+                30,
+                40,
+                50,
+                75,
+                100,
+            ],
+        )
+
     def observe_one_finished_request(
         self,
         labels: Dict[str, str],
@@ -819,6 +847,7 @@ class TokenizerMetricsCollector:
         cached_tokens: int,
         e2e_latency: float,
         has_grammar: bool,
+        retraction_count: int,
     ):
         self.prompt_tokens_total.labels(**labels).inc(prompt_tokens)
         self.generation_tokens_total.labels(**labels).inc(generation_tokens)
@@ -833,6 +862,7 @@ class TokenizerMetricsCollector:
         self.generation_tokens_histogram.labels(**labels).observe(
             float(generation_tokens)
         )
+        self.num_retractions.labels(**labels).observe(retraction_count)
 
     def observe_time_to_first_token(self, labels: Dict[str, str], value: float):
        self.histogram_time_to_first_token.labels(**labels).observe(value)
......
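
Prometheus histogram buckets are cumulative upper bounds, so a request retracted twice increments every `le >= 2` bucket, and a never-retracted request lands in `le="0"`. A standalone sketch using `prometheus_client` directly (the label set here is illustrative; the real collector derives it from `labels.keys()`):

```python
# Sketch of the new metric with prometheus_client. The client appends a +Inf
# bucket automatically, and _sum/_count series come for free, so the average
# retractions per request can be derived as _sum / _count.
from prometheus_client import Histogram, generate_latest

num_retractions = Histogram(
    name="sglang:num_retractions",
    documentation="Histogram of retraction counts per request.",
    labelnames=["model_name"],  # illustrative label set
    buckets=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25, 30, 40, 50, 75, 100],
)

num_retractions.labels(model_name="demo").observe(0)  # typical request
num_retractions.labels(model_name="demo").observe(2)  # retracted twice
print(generate_latest().decode())  # exposition-format dump, incl. _sum and _count
```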