Unverified Commit b5cbe8ee authored by Rajveer Bachkaniwala's avatar Rajveer Bachkaniwala Committed by GitHub
Browse files

[Bugfix] Last token measurement fix (#11376)


Signed-off-by: default avatarrajveerb <46040700+rajveerb@users.noreply.github.com>
Co-authored-by: default avatarRoger Wang <136131678+ywang96@users.noreply.github.com>
parent df04dffa
......@@ -1124,6 +1124,8 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
if not seq_group.is_prefill():
seq_group.set_last_token_time(now)
request_output = RequestOutputFactory.create(
seq_group,
self.seq_id_to_seq_group,
......@@ -1166,6 +1168,8 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
if not seq_group.is_prefill():
seq_group.set_last_token_time(now)
request_output = RequestOutputFactory.create(
seq_group,
self.seq_id_to_seq_group,
......@@ -1686,7 +1690,7 @@ class LLMEngine:
# If the seq_group just finished the prefill state
# get TTFT.
if not seq_group.is_prefill():
latency = seq_group.get_last_latency(now)
latency = seq_group.get_last_token_latency()
time_to_first_tokens_iter.append(latency)
# One generation token per finished prefill.
......@@ -1694,7 +1698,7 @@ class LLMEngine:
seq_group.num_seqs())
else:
# TPOTs.
latency = seq_group.get_last_latency(now)
latency = seq_group.get_last_token_latency()
time_per_output_tokens_iter.append(latency)
if seq_group.state.current_step == 0:
# For async_output_proc, the do_log_stats()
......
......@@ -667,6 +667,7 @@ class SequenceGroup:
first_scheduled_time=None,
first_token_time=None,
time_in_queue=None)
self.last_token_latency = 0.0
self.lora_request = lora_request
self.prompt_logprobs: Optional[PromptLogprobs] = None
self.state = SequenceGroupState()
......@@ -762,18 +763,21 @@ class SequenceGroup:
assert num_lookahead_slots + 1 == num_scheduler_steps or is_prefill
self.init_multi_step(num_steps=num_lookahead_slots + 1)
def get_last_latency(self, now: float) -> float:
def set_last_token_time(self, now: float) -> None:
"""Sets the last token time for Request level timings."""
# If still in prefill phase, raise Error.
if self.is_prefill():
raise ValueError(
"seq_group.get_last_latency() should not be called "
"if the seq_group is in prefill phase.")
# Otherwise return token latency.
latency = now - self.metrics.last_token_time
# If still in prefill phase, assertion fails.
assert not self.is_prefill(), (
"seq_group.set_last_token_time() should not be called "
"if the seq_group is in prefill phase.")
self.last_token_latency = now - self.metrics.last_token_time
self.metrics.last_token_time = now
return latency
def get_last_token_latency(self) -> float:
"""Returns the latency of the last token."""
assert not self.is_prefill(), (
"seq_group.get_last_token_latency() should not be called "
"if the seq_group is in prefill phase.")
return self.last_token_latency
def maybe_set_first_token_time(self, time: float) -> None:
"""Sets the first token time for Request level timings."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment