Unverified Commit 55b14656, authored by Scott Lee, committed by GitHub

Revert "Add metrics for speculative decoding (acceptance rate, average acceptance length)" (#11433)

parent b4408e60
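This revert removes the per-batch and per-request speculative-decoding acceptance metrics that the original change added on top of the existing `spec_accept_length` plumbing. As a rough sketch of the arithmetic the reverted code performed (a hedged illustration using the names from the diff below, not the live implementation):

    def spec_accept_rate(accepted_tokens: int, verify_ct: int, num_steps: int) -> float:
        # Acceptance rate: accepted draft tokens / total drafted tokens,
        # where total drafted tokens = verify passes * speculative_num_steps.
        total_draft_tokens = verify_ct * num_steps
        return accepted_tokens / total_draft_tokens if total_draft_tokens > 0 else 0.0

    def spec_accept_length(completion_tokens: int, verify_ct: int) -> float:
        # Per-request average acceptance length: output tokens per verify pass.
        return completion_tokens / verify_ct if verify_ct > 0 else 0.0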
@@ -233,7 +233,6 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
             completion_tokens=recv_obj.completion_tokens,
             cached_tokens=recv_obj.cached_tokens,
             spec_verify_ct=recv_obj.spec_verify_ct,
-            spec_accepted_tokens=recv_obj.spec_accepted_tokens,
             input_token_logprobs_val=recv_obj.input_token_logprobs_val,
             input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
             output_token_logprobs_val=recv_obj.output_token_logprobs_val,
...
@@ -816,7 +816,6 @@ class BatchTokenIDOutput(BaseBatchReq):
     completion_tokens: List[int]
     cached_tokens: List[int]
     spec_verify_ct: List[int]
-    spec_accepted_tokens: List[int]
     # Logprobs
     input_token_logprobs_val: List[float]
@@ -883,7 +882,6 @@ class BatchStrOutput(BaseBatchReq):
     completion_tokens: List[int]
     cached_tokens: List[int]
     spec_verify_ct: List[int]
-    spec_accepted_tokens: List[int]
     # Logprobs
     input_token_logprobs_val: List[float]
...
@@ -246,11 +246,6 @@ def _handle_output_by_index(output, i):
         spec_verify_ct=(
             [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
         ),
-        spec_accepted_tokens=(
-            [output.spec_accepted_tokens[i]]
-            if len(output.spec_accepted_tokens) > i
-            else None
-        ),
         input_token_logprobs_val=(
             [output.input_token_logprobs_val[i]]
             if output.input_token_logprobs_val
...
@@ -631,10 +631,6 @@ class Req:
         # This is used to compute the average acceptance length per request.
         self.spec_verify_ct = 0
-        # The number of accepted tokens in speculative decoding for this request.
-        # This is used to compute the acceptance rate and average acceptance length per request.
-        self.spec_accepted_tokens = 0
         # For metrics
         self.metrics_collector = metrics_collector
         self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
...
@@ -216,24 +216,14 @@ class SchedulerMetricsMixin:
         if self.spec_algorithm.is_none():
             spec_accept_length = 0
-            spec_accept_rate = 0
         else:
             spec_accept_length = (
                 self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
             )
-            # Calculate acceptance rate: accepted tokens / total draft tokens
-            total_draft_tokens = self.spec_num_total_forward_ct * (
-                self.server_args.speculative_num_steps or 1
-            )
-            spec_accept_rate = (
-                self.spec_num_total_accepted_tokens / total_draft_tokens
-                if total_draft_tokens > 0
-                else 0
-            )
             self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
             self.cum_spec_accept_count += self.spec_num_total_forward_ct
             self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
-        msg += f"accept len: {spec_accept_length:.2f}, accept rate: {spec_accept_rate:.2f}, "
+        msg += f"accept len: {spec_accept_length:.2f}, "
         cache_hit_rate = 0.0
         if self.disaggregation_mode == DisaggregationMode.DECODE:
@@ -261,9 +251,6 @@ class SchedulerMetricsMixin:
         self.stats.num_queue_reqs = len(self.waiting_queue)
         self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
         self.stats.cache_hit_rate = cache_hit_rate
-        # Speculative decoding
-        self.stats.spec_accept_rate = spec_accept_rate
         self.stats.spec_accept_length = spec_accept_length
         # Retract
...
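The hunk above keeps the long-standing accept-length rollup and drops the rate rollup. The two quantities differ only in the denominator; a minimal worked example, assuming `speculative_num_steps = 4` draft tokens per forward pass (numbers invented for illustration):

    accepted, forwards, num_steps = 300, 100, 4
    accept_length = accepted / forwards              # 3.00 accepted tokens per verify pass
    accept_rate = accepted / (forwards * num_steps)  # 0.75 of drafted tokens accepted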
@@ -634,7 +634,6 @@ class SchedulerOutputProcessorMixin:
         completion_tokens = []
         cached_tokens = []
         spec_verify_ct = []
-        spec_accepted_tokens = []
         output_hidden_states = None
         if return_logprob:
@@ -726,7 +725,6 @@ class SchedulerOutputProcessorMixin:
             if not self.spec_algorithm.is_none():
                 spec_verify_ct.append(req.spec_verify_ct)
-                spec_accepted_tokens.append(req.spec_accepted_tokens)
             if return_logprob:
                 if (
@@ -827,7 +825,6 @@ class SchedulerOutputProcessorMixin:
             completion_tokens,
             cached_tokens,
             spec_verify_ct,
-            spec_accepted_tokens,
             input_token_logprobs_val,
             input_token_logprobs_idx,
             output_token_logprobs_val,
...
@@ -1394,36 +1394,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         if state.finished:
             if self.server_args.speculative_algorithm:
                 meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
-                if (
-                    recv_obj.spec_verify_ct[i] > 0
-                    and self.server_args.speculative_num_steps is not None
-                    and not isinstance(recv_obj, BatchEmbeddingOutput)
-                    and hasattr(recv_obj, "spec_accepted_tokens")
-                    # Checks that `spec_accepted_tokens[i]` will exist.
-                    and len(recv_obj.spec_accepted_tokens) > i
-                ):
-                    total_draft_tokens = (
-                        recv_obj.spec_verify_ct[i]
-                        * self.server_args.speculative_num_steps
-                    )
-                    accepted_tokens = recv_obj.spec_accepted_tokens[i]
-                    # Calculate per-request acceptance rate and average acceptance length.
-                    if total_draft_tokens > 0:
-                        # Calculate acceptance rate: accepted / (steps * lookahead)
-                        meta_info["spec_accept_rate"] = (
-                            accepted_tokens / total_draft_tokens
-                        )
-                        meta_info["spec_accept_length"] = (
-                            recv_obj.completion_tokens[i]
-                            / recv_obj.spec_verify_ct[i]
-                        )
-                    else:
-                        meta_info["spec_accept_rate"] = 0.0
-                        meta_info["spec_accept_length"] = 0
-                else:
-                    meta_info["spec_acceptance_rate"] = 0.0
-                    meta_info["spec_accept_length"] = 0
             state.finished_time = time.time()
             meta_info["e2e_latency"] = state.finished_time - state.created_time
...
@@ -127,7 +127,6 @@ class SchedulerStats:
     # Speculative decoding
     spec_accept_length: float = 0.0
-    spec_accept_rate: float = 0.0
     # Retract
     num_retracted_reqs: int = 0
@@ -221,12 +220,6 @@ class SchedulerMetricsCollector:
             labelnames=labels.keys(),
             multiprocess_mode="mostrecent",
         )
-        self.spec_accept_rate = Gauge(
-            name="sglang:spec_accept_rate",
-            documentation="The average acceptance rate of speculative decoding (`accepted tokens / total draft tokens` in batch).",
-            labelnames=labels.keys(),
-            multiprocess_mode="mostrecent",
-        )
         # Retract
         self.num_retracted_reqs = Gauge(
@@ -527,7 +520,6 @@ class SchedulerMetricsCollector:
         # Speculative decoding
         self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
-        self._log_gauge(self.spec_accept_rate, stats.spec_accept_rate)
         # PD disaggregation
         self._log_gauge(
...
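The removed gauge follows the same registration pattern as the surviving `spec_accept_length` gauge above. For reference, a self-contained sketch of that pattern with `prometheus_client` (the label set here is a made-up example, not taken from this diff):

    from prometheus_client import Gauge

    # "mostrecent" keeps the latest sample per label set when running in
    # prometheus_client's multiprocess mode.
    spec_accept_length = Gauge(
        name="sglang:spec_accept_length",
        documentation="The average acceptance length of speculative decoding.",
        labelnames=["model_name"],  # illustrative label
        multiprocess_mode="mostrecent",
    )
    spec_accept_length.labels(model_name="demo").set(3.0)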
@@ -378,13 +378,6 @@ class EagleVerifyInput(SpecInput):
                 unfinished_accept_index.append(accept_index[i])
             req.spec_verify_ct += 1
-        # For each request, accumulate # of accepted tokens for this verify pass.
-        accept_length_this_pass = (accept_index != -1).sum(dim=1) - 1
-        for i, (req, accepted_count) in enumerate(
-            zip(batch.reqs, accept_length_this_pass.tolist())
-        ):
-            req.spec_accepted_tokens += accepted_count
         if has_finished:
             accept_length = (accept_index != -1).sum(dim=1) - 1
...
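In the removed Eagle accumulation, row i of `accept_index` holds the kept token slots for request i, with -1 marking rejected draft slots; the `- 1` excludes the one token every verify pass emits regardless of acceptance. A standalone sketch of that counting step (tensor values invented for illustration):

    import torch

    # One row per request; -1 marks a rejected draft slot.
    accept_index = torch.tensor([[5, 6, 7, -1], [9, -1, -1, -1]])
    # Accepted draft tokens this pass, excluding the always-emitted token.
    accept_length_this_pass = (accept_index != -1).sum(dim=1) - 1
    print(accept_length_this_pass.tolist())  # [2, 0]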