Unverified Commit 5ec44056 authored by omkhalil's avatar omkhalil Committed by GitHub
Browse files

[Metrics][MFU] Fix UnembedMetrics FLOP overcounting for prefill (#33045) (#33045)



Fix UnembedMetrics to correctly count FLOPs for the unembedding (LM head) layer.

The bug: UnembedMetrics used total_num_tokens() which counts all tokens in the
batch for projection flops, vocab projections are run on just the last token for the
autoregressive use case.
Co-authored-by: default avatarOmar Mohamed Khalil <omarkhalil@meta.com>
parent 492a7983
...@@ -110,6 +110,14 @@ class ExecutionContext: ...@@ -110,6 +110,14 @@ class ExecutionContext:
"""Total sum of (num_tokens * context_len) across all requests.""" """Total sum of (num_tokens * context_len) across all requests."""
return self.prefill_token_context_product + self.decode_token_context_product return self.prefill_token_context_product + self.decode_token_context_product
def num_logits_tokens(self) -> int:
"""Number of tokens that require logits computation (unembedding).
For prefill, only the last token per request needs logits.
For decode, all tokens need logits.
"""
return self.num_prefill_requests + self.decode_num_tokens
@classmethod @classmethod
def from_single_request( def from_single_request(
cls, num_tokens: int, context_len: int, is_prefill: bool cls, num_tokens: int, context_len: int, is_prefill: bool
...@@ -906,7 +914,7 @@ class UnembedMetrics(ComponentMetrics): ...@@ -906,7 +914,7 @@ class UnembedMetrics(ComponentMetrics):
) -> dict[str, int]: ) -> dict[str, int]:
"""Calculate flops breakdown for unembedding layer.""" """Calculate flops breakdown for unembedding layer."""
D, V = self.hidden_size, self.vocab_size D, V = self.hidden_size, self.vocab_size
T = ctx.total_num_tokens() T = ctx.num_logits_tokens()
if per_gpu: if per_gpu:
V //= self.tp_size V //= self.tp_size
...@@ -920,7 +928,7 @@ class UnembedMetrics(ComponentMetrics): ...@@ -920,7 +928,7 @@ class UnembedMetrics(ComponentMetrics):
) -> dict[str, int]: ) -> dict[str, int]:
"""Calculate read memory traffic for unembedding layer.""" """Calculate read memory traffic for unembedding layer."""
D, V = self.hidden_size, self.vocab_size D, V = self.hidden_size, self.vocab_size
T = ctx.total_num_tokens() T = ctx.num_logits_tokens()
if per_gpu: if per_gpu:
V //= self.tp_size V //= self.tp_size
...@@ -935,7 +943,7 @@ class UnembedMetrics(ComponentMetrics): ...@@ -935,7 +943,7 @@ class UnembedMetrics(ComponentMetrics):
) -> dict[str, int]: ) -> dict[str, int]:
"""Calculate write memory traffic for unembedding layer.""" """Calculate write memory traffic for unembedding layer."""
V = self.vocab_size V = self.vocab_size
T = ctx.total_num_tokens() T = ctx.num_logits_tokens()
if per_gpu: if per_gpu:
V //= self.tp_size V //= self.tp_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment