[Metrics][MFU] Fix UnembedMetrics FLOP overcounting for prefill (#33045) (#33045)
Fix UnembedMetrics to correctly count FLOPs for the unembedding (LM head) layer.
The bug: UnembedMetrics used total_num_tokens() which counts all tokens in the
batch for projection flops, vocab projections are run on just the last token for the
autoregressive use case.
Co-authored-by:
Omar Mohamed Khalil <omarkhalil@meta.com>
Showing
Please register or sign in to comment