"tests/pytorch/test_unpin_tensoradapter.py" did not exist on "65b0b9e8c3161605b77841200a87d1a0ac4abefc"
Unverified Commit 55b14656 authored by Scott Lee's avatar Scott Lee Committed by GitHub
Browse files

Revert "Add metrics for speculative decoding (acceptance rate, average acceptance length)" (#11433)

parent b4408e60
......@@ -233,7 +233,6 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
completion_tokens=recv_obj.completion_tokens,
cached_tokens=recv_obj.cached_tokens,
spec_verify_ct=recv_obj.spec_verify_ct,
spec_accepted_tokens=recv_obj.spec_accepted_tokens,
input_token_logprobs_val=recv_obj.input_token_logprobs_val,
input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
output_token_logprobs_val=recv_obj.output_token_logprobs_val,
......
......@@ -816,7 +816,6 @@ class BatchTokenIDOutput(BaseBatchReq):
completion_tokens: List[int]
cached_tokens: List[int]
spec_verify_ct: List[int]
spec_accepted_tokens: List[int]
# Logprobs
input_token_logprobs_val: List[float]
......@@ -883,7 +882,6 @@ class BatchStrOutput(BaseBatchReq):
completion_tokens: List[int]
cached_tokens: List[int]
spec_verify_ct: List[int]
spec_accepted_tokens: List[int]
# Logprobs
input_token_logprobs_val: List[float]
......
......@@ -246,11 +246,6 @@ def _handle_output_by_index(output, i):
spec_verify_ct=(
[output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
),
spec_accepted_tokens=(
[output.spec_accepted_tokens[i]]
if len(output.spec_accepted_tokens) > i
else None
),
input_token_logprobs_val=(
[output.input_token_logprobs_val[i]]
if output.input_token_logprobs_val
......
......@@ -631,10 +631,6 @@ class Req:
# This is used to compute the average acceptance length per request.
self.spec_verify_ct = 0
# The number of accepted tokens in speculative decoding for this request.
# This is used to compute the acceptance rate and average acceptance length per request.
self.spec_accepted_tokens = 0
# For metrics
self.metrics_collector = metrics_collector
self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
......
......@@ -216,24 +216,14 @@ class SchedulerMetricsMixin:
if self.spec_algorithm.is_none():
spec_accept_length = 0
spec_accept_rate = 0
else:
spec_accept_length = (
self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
)
# Calculate acceptance rate: accepted tokens / total draft tokens
total_draft_tokens = self.spec_num_total_forward_ct * (
self.server_args.speculative_num_steps or 1
)
spec_accept_rate = (
self.spec_num_total_accepted_tokens / total_draft_tokens
if total_draft_tokens > 0
else 0
)
self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
self.cum_spec_accept_count += self.spec_num_total_forward_ct
self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
msg += f"accept len: {spec_accept_length:.2f}, accept rate: {spec_accept_rate:.2f}, "
msg += f"accept len: {spec_accept_length:.2f}, "
cache_hit_rate = 0.0
if self.disaggregation_mode == DisaggregationMode.DECODE:
......@@ -261,9 +251,6 @@ class SchedulerMetricsMixin:
self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
self.stats.cache_hit_rate = cache_hit_rate
# Speculative decoding
self.stats.spec_accept_rate = spec_accept_rate
self.stats.spec_accept_length = spec_accept_length
# Retract
......
......@@ -634,7 +634,6 @@ class SchedulerOutputProcessorMixin:
completion_tokens = []
cached_tokens = []
spec_verify_ct = []
spec_accepted_tokens = []
output_hidden_states = None
if return_logprob:
......@@ -726,7 +725,6 @@ class SchedulerOutputProcessorMixin:
if not self.spec_algorithm.is_none():
spec_verify_ct.append(req.spec_verify_ct)
spec_accepted_tokens.append(req.spec_accepted_tokens)
if return_logprob:
if (
......@@ -827,7 +825,6 @@ class SchedulerOutputProcessorMixin:
completion_tokens,
cached_tokens,
spec_verify_ct,
spec_accepted_tokens,
input_token_logprobs_val,
input_token_logprobs_idx,
output_token_logprobs_val,
......
......@@ -1394,36 +1394,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
if state.finished:
if self.server_args.speculative_algorithm:
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
if (
recv_obj.spec_verify_ct[i] > 0
and self.server_args.speculative_num_steps is not None
and not isinstance(recv_obj, BatchEmbeddingOutput)
and hasattr(recv_obj, "spec_accepted_tokens")
# Checks that `spec_accepted_tokens[i]` will exist.
and len(recv_obj.spec_accepted_tokens) > i
):
total_draft_tokens = (
recv_obj.spec_verify_ct[i]
* self.server_args.speculative_num_steps
)
accepted_tokens = recv_obj.spec_accepted_tokens[i]
# Calculate per-request acceptance rate and average acceptance length.
if total_draft_tokens > 0:
# Calculate acceptance rate: accepted / (steps * lookahead)
meta_info["spec_accept_rate"] = (
accepted_tokens / total_draft_tokens
)
meta_info["spec_accept_length"] = (
recv_obj.completion_tokens[i]
/ recv_obj.spec_verify_ct[i]
)
else:
meta_info["spec_accept_rate"] = 0.0
meta_info["spec_accept_length"] = 0
else:
meta_info["spec_acceptance_rate"] = 0.0
meta_info["spec_accept_length"] = 0
state.finished_time = time.time()
meta_info["e2e_latency"] = state.finished_time - state.created_time
......
......@@ -127,7 +127,6 @@ class SchedulerStats:
# Speculative decoding
spec_accept_length: float = 0.0
spec_accept_rate: float = 0.0
# Retract
num_retracted_reqs: int = 0
......@@ -221,12 +220,6 @@ class SchedulerMetricsCollector:
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
self.spec_accept_rate = Gauge(
name="sglang:spec_accept_rate",
documentation="The average acceptance rate of speculative decoding (`accepted tokens / total draft tokens` in batch).",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
# Retract
self.num_retracted_reqs = Gauge(
......@@ -527,7 +520,6 @@ class SchedulerMetricsCollector:
# Speculative decoding
self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
self._log_gauge(self.spec_accept_rate, stats.spec_accept_rate)
# PD disaggregation
self._log_gauge(
......
......@@ -378,13 +378,6 @@ class EagleVerifyInput(SpecInput):
unfinished_accept_index.append(accept_index[i])
req.spec_verify_ct += 1
# For each request, accumulate # of accepted tokens for this verify pass.
accept_length_this_pass = (accept_index != -1).sum(dim=1) - 1
for i, (req, accepted_count) in enumerate(
zip(batch.reqs, accept_length_this_pass.tolist())
):
req.spec_accepted_tokens += accepted_count
if has_finished:
accept_length = (accept_index != -1).sum(dim=1) - 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment