omkhalil authored
Fix UnembedMetrics to correctly count FLOPs for the unembedding (LM head) layer.

The bug: UnembedMetrics used total_num_tokens(), which counts every token in the batch, to compute the projection FLOPs. In the autoregressive use case, the vocabulary projection runs only on the last token.

Co-authored-by: Omar Mohamed Khalil <omarkhalil@meta.com>
5ec44056
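
The size of the overcount described in this commit can be illustrated with a minimal sketch. This is not the actual UnembedMetrics code: the `Batch` class, `num_sequences()`, `unembed_flops()`, and the example sizes are hypothetical, and the sketch assumes the common convention of 2 FLOPs per multiply-accumulate and one projected (last) token per sequence. Only `total_num_tokens()` is a name taken from the commit message.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Batch:
    """Hypothetical batch description: one token count per sequence."""
    seq_lens: List[int]

    def total_num_tokens(self) -> int:
        # Every token in the batch (what the buggy count used).
        return sum(self.seq_lens)

    def num_sequences(self) -> int:
        # Assumption: one last token per sequence reaches the LM head.
        return len(self.seq_lens)


def unembed_flops(num_projected_tokens: int, hidden_size: int, vocab_size: int) -> int:
    """FLOPs for a [n, hidden] x [hidden, vocab] matmul.

    Assumes 2 FLOPs (multiply + add) per multiply-accumulate.
    """
    return 2 * num_projected_tokens * hidden_size * vocab_size


# Illustrative sizes only; not taken from the repository.
batch = Batch(seq_lens=[128, 64, 8])
hidden_size, vocab_size = 1024, 32000

overcounted = unembed_flops(batch.total_num_tokens(), hidden_size, vocab_size)
corrected = unembed_flops(batch.num_sequences(), hidden_size, vocab_size)

print(f"overcounted: {overcounted:,} FLOPs")
print(f"corrected:   {corrected:,} FLOPs")
```

Under these assumptions the overcount factor equals the ratio of total tokens to sequences in the batch, so it grows with prompt length.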