Unverified Commit a0b782f9 authored by SungMinCho's avatar SungMinCho Committed by GitHub
Browse files

[Metrics] Model FLOPs Utilization estimation (#30738)


Signed-off-by: default avatarSungMinCho <tjdals4565@gmail.com>
Signed-off-by: default avatarMark McLoughlin <markmc@redhat.com>
Co-authored-by: default avatarMark McLoughlin <markmc@redhat.com>
parent ed2897f3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for the analytic estimators in metrics/flops.py.
"""
import types
from types import SimpleNamespace
from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config
from transformers.models.llama4.configuration_llama4 import (
Llama4Config,
Llama4TextConfig,
)
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig
from vllm.config.model import ModelConfig, get_hf_text_config
from vllm.v1.metrics.perf import (
AttentionMetrics,
BaseConfigParser,
ExecutionContext,
FfnMetrics,
ModelMetrics,
ParsedArgs,
UnembedMetrics,
)
class MockModelConfig:
"""Mock ModelConfig that implements the getter methods used by parsers."""
def __init__(self, hf_config, dtype):
self.hf_config = hf_config
self.hf_text_config = get_hf_text_config(hf_config)
self.dtype = dtype
self.is_attention_free = False
def __getattr__(self, name):
# 1. Check if ModelConfig actually has this attribute
if not hasattr(ModelConfig, name):
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}' "
f"and neither does 'ModelConfig'."
)
# 2. Fetch the attribute from the ModelConfig CLASS
attr = getattr(ModelConfig, name)
# 3. Case A: It is a @property
if isinstance(attr, property):
# Manually invoke the property's getter, passing 'self' (this mock instance)
return attr.__get__(self, self.__class__)
# 4. Case B: It is a standard method (function)
if isinstance(attr, types.FunctionType):
# Bind the function to 'self' so it acts like a method of
# this instance. This creates a bound method where 'self' is
# automatically passed as the first arg.
return types.MethodType(attr, self)
# 5. Case C: It is a class attribute / static variable
return attr
def create_mock_vllm_config(
hf_config,
model_dtype="bfloat16",
cache_dtype="auto",
quant_config=None,
data_parallel_size=1,
tensor_parallel_size=1,
pipeline_parallel_size=1,
enable_expert_parallel=False,
) -> SimpleNamespace:
vllm_config = SimpleNamespace()
vllm_config.model_config = MockModelConfig(hf_config, model_dtype)
vllm_config.cache_config = SimpleNamespace()
vllm_config.cache_config.cache_dtype = cache_dtype
vllm_config.quant_config = quant_config
vllm_config.parallel_config = SimpleNamespace()
vllm_config.parallel_config.data_parallel_size = data_parallel_size
vllm_config.parallel_config.tensor_parallel_size = tensor_parallel_size
vllm_config.parallel_config.pipeline_parallel_size = pipeline_parallel_size
vllm_config.parallel_config.enable_expert_parallel = enable_expert_parallel
return vllm_config
#### Parser Tests ####
def test_base_config_parser():
"""Test BaseConfigParser extracts base model attributes correctly."""
hf_config = Qwen3Config(
vocab_size=50000,
hidden_size=2048,
num_attention_heads=16,
num_hidden_layers=24,
)
vllm_config = create_mock_vllm_config(hf_config, model_dtype="float16")
parser = BaseConfigParser()
args = ParsedArgs()
result = parser.parse(args, vllm_config)
assert result.vocab_size == 50000
assert result.hidden_size == 2048
assert result.num_attention_heads == 16
assert result.num_hidden_layers == 24
assert result.weight_byte_size == 2 # float16 is 2 bytes
assert result.activation_byte_size == 2 # default activation size
def test_base_attention_config_parser_with_gqa():
"""Test BaseAttentionConfigParser with grouped query attention."""
hf_config = Qwen3Config(
hidden_size=4096,
num_attention_heads=32,
num_key_value_heads=8, # GQA with 4:1 ratio
head_dim=128,
)
vllm_config = create_mock_vllm_config(hf_config)
parser_chain = AttentionMetrics.get_parser()
result = parser_chain.parse(vllm_config)
assert result.num_key_value_heads == 8
assert result.head_dim == 128
def test_base_attention_config_parser_without_gqa():
"""
Test BaseAttentionConfigParser defaults to MHA when num_key_value_heads not
specified.
"""
hf_config = Qwen3Config(
hidden_size=4096,
num_attention_heads=32,
# No num_key_value_heads specified
)
vllm_config = create_mock_vllm_config(hf_config)
parser_chain = AttentionMetrics.get_parser()
result = parser_chain.parse(vllm_config)
# Should default to MHA (num_key_value_heads = num_attention_heads)
assert result.num_key_value_heads == 32
def test_base_ffn_config_parser_dense():
"""Test BaseFfnConfigParser for dense FFN."""
hf_config = Qwen3Config(
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
)
vllm_config = create_mock_vllm_config(hf_config)
parser_chain = FfnMetrics.get_parser()
result = parser_chain.parse(vllm_config)
assert result.intermediate_size == 11008
assert result.num_experts == 0
assert result.num_experts_per_tok == 0
assert result.num_moe_layers == 0 # No MoE
def test_base_ffn_config_parser_moe():
"""Test BaseFfnConfigParser for MoE FFN."""
hf_config = Qwen3MoeConfig(
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_experts=64,
num_experts_per_tok=8,
moe_intermediate_size=14336,
n_shared_experts=2,
)
vllm_config = create_mock_vllm_config(hf_config)
parser_chain = FfnMetrics.get_parser()
result = parser_chain.parse(vllm_config)
assert result.num_experts == 64
assert result.num_experts_per_tok == 8
assert result.moe_intermediate_size == 14336
assert result.num_shared_experts == 2
assert result.num_moe_layers == 32 # All layers are MoE by default
def test_interleave_moe_layer_step_parser():
"""Test InterleaveMoeLayerStepParser correctly computes MoE layer count."""
hf_config = Llama4Config(
text_config=Llama4TextConfig(
num_hidden_layers=32,
num_local_experts=64,
interleave_moe_layer_step=4, # Every 4th layer is MoE
),
)
vllm_config = create_mock_vllm_config(hf_config)
parser_chain = FfnMetrics.get_parser()
result = parser_chain.parse(vllm_config)
assert result.num_moe_layers == 8
def test_moe_layer_freq_parser():
"""Test MoeLayerFreqParser correctly computes MoE layer count."""
hf_config = DeepseekV3Config(
num_hidden_layers=30,
n_routed_experts=64,
moe_layer_freq=3, # Every 3rd layer after first_k_dense_replace
first_k_dense_replace=6, # First 6 layers are dense
)
vllm_config = create_mock_vllm_config(hf_config)
parser_chain = FfnMetrics.get_parser()
result = parser_chain.parse(vllm_config)
# Layers >= 6 and divisible by 3: 6, 9, 12, 15, 18, 21, 24, 27
expected_moe_layers = len(
[layer for layer in range(30) if layer >= 6 and layer % 3 == 0]
)
assert expected_moe_layers == 8
assert result.num_moe_layers == expected_moe_layers
#### ComponentMetrics Tests ####
def test_attention_metrics_scaling():
"""Test that attention metrics scale proportionally with model dimensions."""
base_hf_config = Qwen3Config(
hidden_size=2048,
num_attention_heads=16,
num_key_value_heads=16,
num_hidden_layers=12,
head_dim=128,
)
base_vllm_config = create_mock_vllm_config(base_hf_config)
base_metrics = AttentionMetrics.from_vllm_config(base_vllm_config)
# Test scaling with number of layers
double_layers_hf_config = Qwen3Config(
hidden_size=2048,
num_attention_heads=16,
num_key_value_heads=16,
num_hidden_layers=24, # Double the layers
head_dim=128,
)
double_layers_vllm_config = create_mock_vllm_config(double_layers_hf_config)
double_layers_metrics = AttentionMetrics.from_vllm_config(double_layers_vllm_config)
ctx = ExecutionContext.from_single_request(
num_tokens=100, context_len=512, is_prefill=True
)
# FLOPS should double when layers double
base_flops = base_metrics.get_num_flops(ctx)
double_flops = double_layers_metrics.get_num_flops(ctx)
assert double_flops == 2 * base_flops
# Read/write bytes should also scale proportionally
base_read = base_metrics.get_read_bytes(ctx)
double_read = double_layers_metrics.get_read_bytes(ctx)
assert double_read == 2 * base_read
base_write = base_metrics.get_write_bytes(ctx)
double_write = double_layers_metrics.get_write_bytes(ctx)
assert double_write == 2 * base_write
def test_attention_metrics_grouped_query():
"""Test attention metrics handle grouped query attention correctly."""
mha_hf_config = Qwen3Config(
hidden_size=4096,
num_attention_heads=32,
num_key_value_heads=32, # MHA
num_hidden_layers=1,
)
mha_config = create_mock_vllm_config(mha_hf_config)
gqa_hf_config = Qwen3Config(
hidden_size=4096,
num_attention_heads=32,
num_key_value_heads=8, # GQA with 4:1 ratio
num_hidden_layers=1,
)
gqa_config = create_mock_vllm_config(gqa_hf_config)
mha_metrics = AttentionMetrics.from_vllm_config(mha_config)
gqa_metrics = AttentionMetrics.from_vllm_config(gqa_config)
ctx = ExecutionContext.from_single_request(
num_tokens=1, context_len=1024, is_prefill=False
)
# GQA should have less KV cache reads since fewer KV heads
mha_read = mha_metrics.get_read_bytes(ctx)
gqa_read = gqa_metrics.get_read_bytes(ctx)
assert gqa_read < mha_read
def test_ffn_metrics_scaling():
"""Test FFN metrics scale proportionally with model dimensions."""
base_hf_config = Qwen3Config(
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=12,
)
base_vllm_config = create_mock_vllm_config(base_hf_config)
base_metrics = FfnMetrics.from_vllm_config(base_vllm_config)
# Test scaling with intermediate size
larger_ffn_hf_config = Qwen3Config(
hidden_size=2048,
intermediate_size=16384, # Double intermediate size
num_hidden_layers=12,
)
larger_ffn_vllm_config = create_mock_vllm_config(larger_ffn_hf_config)
larger_ffn_metrics = FfnMetrics.from_vllm_config(larger_ffn_vllm_config)
ctx = ExecutionContext.from_single_request(
num_tokens=100, context_len=512, is_prefill=True
)
# FLOPS should double when intermediate size doubles
base_flops = base_metrics.get_num_flops(ctx)
larger_flops = larger_ffn_metrics.get_num_flops(ctx)
assert larger_flops == base_flops * 2
def test_moe_metrics_vs_dense():
"""Test MoE metrics versus dense metrics."""
dense_hf_config = Qwen3Config(
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=12,
)
dense_config = create_mock_vllm_config(dense_hf_config)
moe_hf_config = Qwen3MoeConfig(
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=12,
num_experts=64,
num_experts_per_tok=2, # 2 routed expert
moe_intermediate_size=8192,
n_shared_experts=0,
)
moe_config = create_mock_vllm_config(moe_hf_config)
dense_metrics = FfnMetrics.from_vllm_config(dense_config)
moe_metrics = FfnMetrics.from_vllm_config(moe_config)
ctx = ExecutionContext.from_single_request(
num_tokens=100, context_len=512, is_prefill=True
)
# MoE should have different compute/memory characteristics
dense_flops = dense_metrics.get_num_flops(ctx)
moe_flops = moe_metrics.get_num_flops(ctx)
# 2 routed experts vs 1 dense.
assert moe_flops == dense_flops * 2
def test_unembed_metrics_scaling():
"""Test unembedding metrics scale with vocab size."""
small_vocab_hf_config = Qwen3Config(
hidden_size=2048,
vocab_size=32000,
)
small_vocab_config = create_mock_vllm_config(small_vocab_hf_config)
large_vocab_hf_config = Qwen3Config(
hidden_size=2048,
vocab_size=64000, # Double vocab size
)
large_vocab_config = create_mock_vllm_config(large_vocab_hf_config)
small_vocab_metrics = UnembedMetrics.from_vllm_config(small_vocab_config)
large_vocab_metrics = UnembedMetrics.from_vllm_config(large_vocab_config)
ctx = ExecutionContext.from_single_request(
num_tokens=100, context_len=512, is_prefill=True
)
# FLOPS should double when vocab size doubles
small_flops = small_vocab_metrics.get_num_flops(ctx)
large_flops = large_vocab_metrics.get_num_flops(ctx)
assert large_flops == 2 * small_flops
def test_prefill_vs_decode_differences():
"""Test that prefill and decode have different memory access patterns."""
hf_config = Qwen3Config(
hidden_size=2048,
num_attention_heads=16,
num_key_value_heads=16,
num_hidden_layers=1,
)
config = create_mock_vllm_config(hf_config)
metrics = AttentionMetrics.from_vllm_config(config)
prefill_ctx = ExecutionContext.from_single_request(
num_tokens=512, context_len=512, is_prefill=True
)
decode_ctx = ExecutionContext.from_single_request(
num_tokens=1, context_len=512, is_prefill=False
)
prefill_read = metrics.get_read_bytes(prefill_ctx)
decode_read = metrics.get_read_bytes(decode_ctx)
assert prefill_read != decode_read
def test_model_metrics_aggregation():
"""Test ModelMetrics correctly aggregates across components."""
hf_config = Qwen3Config(
hidden_size=2048,
num_attention_heads=16,
num_hidden_layers=12,
vocab_size=32000,
intermediate_size=8192,
)
config = create_mock_vllm_config(hf_config)
model_metrics = ModelMetrics(config)
ctx = ExecutionContext.from_single_request(
num_tokens=100, context_len=512, is_prefill=True
)
# Should have metrics for attention, ffn, and unembed
total_flops = model_metrics.get_num_flops(ctx)
breakdown = model_metrics.get_num_flops_breakdown(ctx)
# Breakdown should sum to total
assert total_flops == sum(breakdown.values())
def test_moe_expert_activation_proportional_scaling():
"""Test that routed expert metrics scale proportionally with num_experts_per_tok."""
base_moe_config = Qwen3MoeConfig(
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=12,
num_experts=64,
num_experts_per_tok=1, # 1 expert per token
moe_intermediate_size=8192,
n_shared_experts=2,
)
double_experts_config = Qwen3MoeConfig(
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=12,
num_experts=64,
num_experts_per_tok=2, # 2 experts per token (double)
moe_intermediate_size=8192,
n_shared_experts=2, # Same shared experts
)
triple_experts_config = Qwen3MoeConfig(
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=12,
num_experts=64,
num_experts_per_tok=3, # 3 experts per token (triple)
moe_intermediate_size=8192,
n_shared_experts=2, # Same shared experts
)
base_vllm_config = create_mock_vllm_config(base_moe_config)
double_vllm_config = create_mock_vllm_config(double_experts_config)
triple_vllm_config = create_mock_vllm_config(triple_experts_config)
base_metrics = FfnMetrics.from_vllm_config(base_vllm_config)
double_metrics = FfnMetrics.from_vllm_config(double_vllm_config)
triple_metrics = FfnMetrics.from_vllm_config(triple_vllm_config)
ctx = ExecutionContext.from_single_request(
num_tokens=100, context_len=512, is_prefill=True
)
# Get total metrics - the key insight is that differences should be proportional
base_flops = base_metrics.get_num_flops(ctx)
double_flops = double_metrics.get_num_flops(ctx)
triple_flops = triple_metrics.get_num_flops(ctx)
# The difference between double and base should equal one additional expert
one_expert_diff = double_flops - base_flops
# The difference between triple and base should equal two additional experts
two_expert_diff = triple_flops - base_flops
# Proportional scaling: 2 * (1 expert diff) should equal (2 expert diff)
assert two_expert_diff == 2 * one_expert_diff
# Same logic applies to memory operations
base_read = base_metrics.get_read_bytes(ctx)
double_read = double_metrics.get_read_bytes(ctx)
triple_read = triple_metrics.get_read_bytes(ctx)
one_expert_read_diff = double_read - base_read
two_expert_read_diff = triple_read - base_read
assert two_expert_read_diff == 2 * one_expert_read_diff
# Same for write bytes
base_write = base_metrics.get_write_bytes(ctx)
double_write = double_metrics.get_write_bytes(ctx)
triple_write = triple_metrics.get_write_bytes(ctx)
one_expert_write_diff = double_write - base_write
two_expert_write_diff = triple_write - base_write
assert two_expert_write_diff == 2 * one_expert_write_diff
def test_quantization_config_parser_fp8():
"""Test quantization parsers with fp8."""
class MockQuantConfig:
def get_name(self):
return "fp8"
hf_config = Qwen3Config(
hidden_size=2048, num_attention_heads=16, num_hidden_layers=1
)
vllm_config = create_mock_vllm_config(hf_config, quant_config=MockQuantConfig())
attn_result = AttentionMetrics.get_parser().parse(vllm_config)
assert attn_result.weight_byte_size == 1 # fp8
ffn_result = FfnMetrics.get_parser().parse(vllm_config)
assert ffn_result.weight_byte_size == 1 # fp8
def test_quantization_config_parser_mxfp4():
"""Test quantization parsers with mxfp4."""
class MockQuantConfig:
def get_name(self):
return "mxfp4"
hf_config = Qwen3Config(
hidden_size=2048, intermediate_size=8192, num_hidden_layers=1
)
vllm_config = create_mock_vllm_config(hf_config, quant_config=MockQuantConfig())
ffn_result = FfnMetrics.get_parser().parse(vllm_config)
assert ffn_result.weight_byte_size == 0.5 # mxfp4
#### Per-GPU Tests ####
def test_attention_per_gpu_with_tensor_parallelism():
"""Test attention metrics with tensor parallelism - per_gpu vs global."""
hf_config = Qwen3Config(
hidden_size=4096,
num_attention_heads=32,
num_key_value_heads=8,
num_hidden_layers=24,
)
# Test with TP=4
vllm_config = create_mock_vllm_config(hf_config, tensor_parallel_size=4)
metrics = AttentionMetrics.from_vllm_config(vllm_config)
ctx = ExecutionContext.from_single_request(
num_tokens=128, context_len=1024, is_prefill=True
)
# Get global and per-gpu metrics
global_flops = metrics.get_num_flops(ctx, per_gpu=False)
per_gpu_flops = metrics.get_num_flops(ctx, per_gpu=True)
# With TP=4, global flops should be 4x per-gpu flops (heads divided by 4)
assert global_flops == 4 * per_gpu_flops
# Same for read/write bytes
global_read = metrics.get_read_bytes(ctx, per_gpu=False)
per_gpu_read = metrics.get_read_bytes(ctx, per_gpu=True)
# Reads should scale similarly (weight reads are divided by TP)
assert global_read > per_gpu_read
global_write = metrics.get_write_bytes(ctx, per_gpu=False)
per_gpu_write = metrics.get_write_bytes(ctx, per_gpu=True)
assert global_write > per_gpu_write
def test_attention_per_gpu_with_pipeline_parallelism():
"""Test attention metrics with pipeline parallelism - per_gpu vs global."""
hf_config = Qwen3Config(
hidden_size=2048,
num_attention_heads=16,
num_hidden_layers=32,
)
# Test with PP=4
vllm_config = create_mock_vllm_config(hf_config, pipeline_parallel_size=4)
metrics = AttentionMetrics.from_vllm_config(vllm_config)
ctx = ExecutionContext.from_single_request(
num_tokens=100, context_len=512, is_prefill=False
)
# Get global and per-gpu metrics
global_flops = metrics.get_num_flops(ctx, per_gpu=False)
per_gpu_flops = metrics.get_num_flops(ctx, per_gpu=True)
# With PP=4, global flops should be 4x per-gpu flops (layers divided by 4)
assert global_flops == 4 * per_gpu_flops
global_read = metrics.get_read_bytes(ctx, per_gpu=False)
per_gpu_read = metrics.get_read_bytes(ctx, per_gpu=True)
assert global_read == 4 * per_gpu_read
def test_ffn_per_gpu_with_tensor_parallelism():
"""Test FFN metrics with tensor parallelism - per_gpu vs global."""
hf_config = Qwen3Config(
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
)
# Test with DP=2, TP=4 (ffn_tp_size will be 8)
vllm_config = create_mock_vllm_config(
hf_config,
data_parallel_size=2,
tensor_parallel_size=4,
)
metrics = FfnMetrics.from_vllm_config(vllm_config)
# ffn_tp_size should be dp_size * tp_size = 8 (when EP not enabled)
assert metrics.ffn_tp_size == 8
ctx = ExecutionContext.from_single_request(
num_tokens=128, context_len=2048, is_prefill=True
)
# Get global and per-gpu metrics
global_flops = metrics.get_num_flops(ctx, per_gpu=False)
per_gpu_flops = metrics.get_num_flops(ctx, per_gpu=True)
# With ffn_tp_size=8, global should be 8x per-gpu
assert global_flops == 8 * per_gpu_flops
def test_ffn_per_gpu_with_pipeline_parallelism():
"""Test FFN metrics with pipeline parallelism - per_gpu vs global."""
hf_config = Qwen3Config(
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=24,
)
# Test with PP=6
vllm_config = create_mock_vllm_config(hf_config, pipeline_parallel_size=6)
metrics = FfnMetrics.from_vllm_config(vllm_config)
ctx = ExecutionContext.from_single_request(
num_tokens=100, context_len=512, is_prefill=True
)
# Get global and per-gpu metrics
global_flops = metrics.get_num_flops(ctx, per_gpu=False)
per_gpu_flops = metrics.get_num_flops(ctx, per_gpu=True)
# With PP=6, global should be 6x per-gpu (layers divided by 6)
assert global_flops == 6 * per_gpu_flops
def test_moe_per_gpu_with_expert_parallelism():
"""
Test MoE metrics with expert parallelism - verifies num_activated_experts bug fix.
"""
hf_config = Qwen3MoeConfig(
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=24,
num_experts=64,
num_experts_per_tok=8,
moe_intermediate_size=14336,
n_shared_experts=2,
)
# Test with DP=2, TP=4, EP enabled (ffn_ep_size will be 8)
vllm_config = create_mock_vllm_config(
hf_config,
data_parallel_size=2,
tensor_parallel_size=4,
enable_expert_parallel=True,
)
metrics = FfnMetrics.from_vllm_config(vllm_config)
# When EP enabled, ffn_ep_size = dp_size * tp_size = 8
assert metrics.ffn_ep_size == 8
assert metrics.ffn_tp_size == 1
ctx = ExecutionContext.from_single_request(
num_tokens=100, context_len=512, is_prefill=True
)
# Get per-gpu metrics
per_gpu_read_breakdown = metrics.get_read_bytes_breakdown(ctx, per_gpu=True)
global_read_breakdown = metrics.get_read_bytes_breakdown(ctx, per_gpu=False)
# Verify that routed expert weight reads are reasonable
# With per_gpu=True, each GPU has 64/8 = 8 experts
# T=100, E_per_gpu=8/8=1, so T*E=100 expert activations
# num_activated_experts should be min(100, 8) = 8
# Check that weight reads scale appropriately
# Global has all 64 experts, per-gpu has 8 experts
# So weight reads should reflect this difference
if "routed_up_gate_weights" in per_gpu_read_breakdown:
per_gpu_weight_reads = per_gpu_read_breakdown["routed_up_gate_weights"]
global_weight_reads = global_read_breakdown["routed_up_gate_weights"]
# The ratio should reflect the expert count difference
# This verifies the bug fix works correctly
assert per_gpu_weight_reads < global_weight_reads
# Global should read more experts than per-gpu
# Exact ratio depends on num_activated_experts calculation
ratio = global_weight_reads / per_gpu_weight_reads
# Should be > 1 since global has more experts to read
assert ratio > 1
def test_moe_per_gpu_expert_activation_accounting():
"""
Test that MoE correctly accounts for expert activations with small batch sizes.
"""
hf_config = Qwen3MoeConfig(
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=12,
num_experts=64,
num_experts_per_tok=8,
moe_intermediate_size=14336,
n_shared_experts=0, # No shared experts for this test
)
# Test with EP=8
vllm_config = create_mock_vllm_config(
hf_config,
data_parallel_size=8,
enable_expert_parallel=True,
)
metrics = FfnMetrics.from_vllm_config(vllm_config)
# Small batch: T=10, E_per_gpu=8/8=1
# Each GPU: T*E = 10*1 = 10 activations
# Experts per GPU: 64/8 = 8
# So num_activated_experts should be min(10, 8) = 8
small_ctx = ExecutionContext.from_single_request(
num_tokens=10, context_len=512, is_prefill=True
)
small_read = metrics.get_read_bytes_breakdown(small_ctx, per_gpu=True)
# Large batch: T=1000, E_per_gpu=1
# Each GPU: T*E = 1000*1 = 1000 activations
# Experts per GPU: 8
# So num_activated_experts should be min(1000, 8) = 8 (all experts activated)
large_ctx = ExecutionContext.from_single_request(
num_tokens=1000, context_len=512, is_prefill=True
)
large_read = metrics.get_read_bytes_breakdown(large_ctx, per_gpu=True)
# Weight reads should be similar (both activate all 8 experts per GPU)
# But activation reads should differ (proportional to T*E)
if "routed_up_gate_weights" in small_read:
small_weight = small_read["routed_up_gate_weights"]
large_weight = large_read["routed_up_gate_weights"]
# Weight reads should be the same (both read all 8 experts)
assert small_weight == large_weight
# But input activation reads should scale with T*E
small_input = small_read["routed_up_gate_input"]
large_input = large_read["routed_up_gate_input"]
assert large_input == 100 * small_input # 1000/10 = 100x
def test_unembed_per_gpu_with_tensor_parallelism():
"""Test unembed metrics with tensor parallelism - per_gpu vs global."""
hf_config = Qwen3Config(
hidden_size=4096,
vocab_size=128000,
)
# Test with TP=8
vllm_config = create_mock_vllm_config(hf_config, tensor_parallel_size=8)
metrics = UnembedMetrics.from_vllm_config(vllm_config)
ctx = ExecutionContext.from_single_request(
num_tokens=100, context_len=512, is_prefill=True
)
# Get global and per-gpu metrics
global_flops = metrics.get_num_flops(ctx, per_gpu=False)
per_gpu_flops = metrics.get_num_flops(ctx, per_gpu=True)
# With TP=8, vocab is divided by 8, so global should be 8x per-gpu
assert global_flops == 8 * per_gpu_flops
# For read bytes, weight reads scale with TP but input reads don't (replicated)
global_read_breakdown = metrics.get_read_bytes_breakdown(ctx, per_gpu=False)
per_gpu_read_breakdown = metrics.get_read_bytes_breakdown(ctx, per_gpu=True)
# Input reads should be the same (replicated across TP ranks)
assert global_read_breakdown["input"] == per_gpu_read_breakdown["input"]
# Weight reads should scale 8x (divided by TP)
assert global_read_breakdown["weight"] == 8 * per_gpu_read_breakdown["weight"]
def test_model_metrics_per_gpu_aggregation():
"""Test ModelMetrics correctly aggregates per_gpu metrics across components."""
hf_config = Qwen3Config(
hidden_size=2048,
num_attention_heads=16,
num_hidden_layers=12,
vocab_size=32000,
intermediate_size=8192,
)
# Test with mixed parallelism: TP=2, PP=2
vllm_config = create_mock_vllm_config(
hf_config,
tensor_parallel_size=2,
pipeline_parallel_size=2,
)
model_metrics = ModelMetrics(vllm_config)
ctx = ExecutionContext.from_single_request(
num_tokens=100, context_len=512, is_prefill=True
)
# Get breakdowns for both modes
per_gpu_breakdown = model_metrics.get_num_flops_breakdown(ctx, per_gpu=True)
global_breakdown = model_metrics.get_num_flops_breakdown(ctx, per_gpu=False)
# Verify breakdown sums match totals
per_gpu_total = model_metrics.get_num_flops(ctx, per_gpu=True)
global_total = model_metrics.get_num_flops(ctx, per_gpu=False)
assert per_gpu_total == sum(per_gpu_breakdown.values())
assert global_total == sum(global_breakdown.values())
# Global should be larger than per-gpu due to parallelism
assert global_total > per_gpu_total
# With TP=2 and PP=2, the ratio depends on which parallelism applies to
# which component but we can verify that global is reasonably larger
ratio = global_total / per_gpu_total
assert ratio > 1 # Should be between PP and TP*PP depending on component mix
def test_attention_per_gpu_heads_not_evenly_divisible():
"""Test attention with heads not evenly divisible by TP."""
hf_config = Qwen3Config(
hidden_size=2048,
num_attention_heads=17, # Not divisible by 4
num_key_value_heads=5, # Not divisible by 4
num_hidden_layers=8,
)
vllm_config = create_mock_vllm_config(hf_config, tensor_parallel_size=4)
metrics = AttentionMetrics.from_vllm_config(vllm_config)
ctx = ExecutionContext.from_single_request(
num_tokens=64, context_len=256, is_prefill=True
)
# Should not crash and should handle max(1, ...) correctly
per_gpu_flops = metrics.get_num_flops(ctx, per_gpu=True)
global_flops = metrics.get_num_flops(ctx, per_gpu=False)
# Both should be positive
assert per_gpu_flops > 0
assert global_flops > 0
assert global_flops > per_gpu_flops
...@@ -64,6 +64,9 @@ class ObservabilityConfig: ...@@ -64,6 +64,9 @@ class ObservabilityConfig:
module in the model and attach informations such as input/output shapes to module in the model and attach informations such as input/output shapes to
nvtx range markers. Noted that this doesn't work with CUDA graphs enabled.""" nvtx range markers. Noted that this doesn't work with CUDA graphs enabled."""
enable_mfu_metrics: bool = False
"""Enable Model FLOPs Utilization (MFU) metrics."""
@cached_property @cached_property
def collect_model_forward_time(self) -> bool: def collect_model_forward_time(self) -> bool:
"""Whether to collect model forward time for the request.""" """Whether to collect model forward time for the request."""
......
...@@ -523,6 +523,7 @@ class EngineArgs: ...@@ -523,6 +523,7 @@ class EngineArgs:
enable_layerwise_nvtx_tracing: bool = ( enable_layerwise_nvtx_tracing: bool = (
ObservabilityConfig.enable_layerwise_nvtx_tracing ObservabilityConfig.enable_layerwise_nvtx_tracing
) )
enable_mfu_metrics: bool = ObservabilityConfig.enable_mfu_metrics
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
...@@ -1042,6 +1043,10 @@ class EngineArgs: ...@@ -1042,6 +1043,10 @@ class EngineArgs:
"--enable-layerwise-nvtx-tracing", "--enable-layerwise-nvtx-tracing",
**observability_kwargs["enable_layerwise_nvtx_tracing"], **observability_kwargs["enable_layerwise_nvtx_tracing"],
) )
observability_group.add_argument(
"--enable-mfu-metrics",
**observability_kwargs["enable_mfu_metrics"],
)
# Scheduler arguments # Scheduler arguments
scheduler_kwargs = get_kwargs(SchedulerConfig) scheduler_kwargs = get_kwargs(SchedulerConfig)
...@@ -1689,6 +1694,7 @@ class EngineArgs: ...@@ -1689,6 +1694,7 @@ class EngineArgs:
kv_cache_metrics_sample=self.kv_cache_metrics_sample, kv_cache_metrics_sample=self.kv_cache_metrics_sample,
cudagraph_metrics=self.cudagraph_metrics, cudagraph_metrics=self.cudagraph_metrics,
enable_layerwise_nvtx_tracing=self.enable_layerwise_nvtx_tracing, enable_layerwise_nvtx_tracing=self.enable_layerwise_nvtx_tracing,
enable_mfu_metrics=self.enable_mfu_metrics,
) )
# Compilation config overrides # Compilation config overrides
......
...@@ -244,6 +244,7 @@ if TYPE_CHECKING: ...@@ -244,6 +244,7 @@ if TYPE_CHECKING:
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
VLLM_USE_V2_MODEL_RUNNER: bool = False VLLM_USE_V2_MODEL_RUNNER: bool = False
VLLM_DEBUG_MFU_METRICS: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -1565,6 +1566,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1565,6 +1566,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_V2_MODEL_RUNNER": lambda: bool( "VLLM_USE_V2_MODEL_RUNNER": lambda: bool(
int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0")) int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0"))
), ),
# Debug logging for --enable-mfu-metrics
"VLLM_DEBUG_MFU_METRICS": lambda: bool(
int(os.getenv("VLLM_DEBUG_MFU_METRICS", "0"))
),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -43,6 +43,7 @@ from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_qu ...@@ -43,6 +43,7 @@ from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_qu
from vllm.v1.core.sched.utils import check_stop, remove_all from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.perf import ModelMetrics, PerfStats
from vllm.v1.metrics.stats import ( from vllm.v1.metrics.stats import (
PrefixCacheStats, PrefixCacheStats,
SchedulerStats, SchedulerStats,
...@@ -219,6 +220,10 @@ class Scheduler(SchedulerInterface): ...@@ -219,6 +220,10 @@ class Scheduler(SchedulerInterface):
self.use_pp = self.parallel_config.pipeline_parallel_size > 1 self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
self.perf_metrics: ModelMetrics | None = None
if self.log_stats and vllm_config.observability_config.enable_mfu_metrics:
self.perf_metrics = ModelMetrics(vllm_config)
def schedule(self) -> SchedulerOutput: def schedule(self) -> SchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm: # NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler. # There's no "decoding phase" nor "prefill phase" in the scheduler.
...@@ -1066,6 +1071,10 @@ class Scheduler(SchedulerInterface): ...@@ -1066,6 +1071,10 @@ class Scheduler(SchedulerInterface):
kv_connector_output = model_runner_output.kv_connector_output kv_connector_output = model_runner_output.kv_connector_output
cudagraph_stats = model_runner_output.cudagraph_stats cudagraph_stats = model_runner_output.cudagraph_stats
perf_stats: PerfStats | None = None
if self.perf_metrics and self.perf_metrics.is_enabled():
perf_stats = self.perf_metrics.get_step_perf_stats_per_gpu(scheduler_output)
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: SpecDecodingStats | None = None spec_decoding_stats: SpecDecodingStats | None = None
kv_connector_stats: KVConnectorStats | None = ( kv_connector_stats: KVConnectorStats | None = (
...@@ -1262,7 +1271,7 @@ class Scheduler(SchedulerInterface): ...@@ -1262,7 +1271,7 @@ class Scheduler(SchedulerInterface):
if ( if (
stats := self.make_stats( stats := self.make_stats(
spec_decoding_stats, kv_connector_stats, cudagraph_stats spec_decoding_stats, kv_connector_stats, cudagraph_stats, perf_stats
) )
) is not None: ) is not None:
# Return stats to only one of the front-ends. # Return stats to only one of the front-ends.
...@@ -1485,6 +1494,7 @@ class Scheduler(SchedulerInterface): ...@@ -1485,6 +1494,7 @@ class Scheduler(SchedulerInterface):
spec_decoding_stats: SpecDecodingStats | None = None, spec_decoding_stats: SpecDecodingStats | None = None,
kv_connector_stats: KVConnectorStats | None = None, kv_connector_stats: KVConnectorStats | None = None,
cudagraph_stats: CUDAGraphStat | None = None, cudagraph_stats: CUDAGraphStat | None = None,
perf_stats: PerfStats | None = None,
) -> SchedulerStats | None: ) -> SchedulerStats | None:
if not self.log_stats: if not self.log_stats:
return None return None
...@@ -1510,6 +1520,7 @@ class Scheduler(SchedulerInterface): ...@@ -1510,6 +1520,7 @@ class Scheduler(SchedulerInterface):
spec_decoding_stats=spec_stats, spec_decoding_stats=spec_stats,
kv_connector_stats=connector_stats_payload, kv_connector_stats=connector_stats_payload,
cudagraph_stats=cudagraph_stats, cudagraph_stats=cudagraph_stats,
perf_stats=perf_stats,
) )
def make_spec_decoding_stats( def make_spec_decoding_stats(
......
...@@ -19,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( ...@@ -19,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.plugins import STAT_LOGGER_PLUGINS_GROUP, load_plugins_by_group from vllm.plugins import STAT_LOGGER_PLUGINS_GROUP, load_plugins_by_group
from vllm.v1.engine import FinishReason from vllm.v1.engine import FinishReason
from vllm.v1.metrics.perf import PerfMetricsLogging
from vllm.v1.metrics.prometheus import unregister_vllm_metrics from vllm.v1.metrics.prometheus import unregister_vllm_metrics
from vllm.v1.metrics.stats import ( from vllm.v1.metrics.stats import (
CachingMetrics, CachingMetrics,
...@@ -118,6 +119,9 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -118,6 +119,9 @@ class LoggingStatLogger(StatLoggerBase):
self.engine_is_idle = False self.engine_is_idle = False
self.aggregated = False self.aggregated = False
if self._enable_perf_stats():
self.perf_metrics_logging = PerfMetricsLogging(vllm_config)
def _reset(self, now): def _reset(self, now):
self.last_log_time = now self.last_log_time = now
...@@ -127,6 +131,9 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -127,6 +131,9 @@ class LoggingStatLogger(StatLoggerBase):
self.num_corrupted_reqs: int = 0 self.num_corrupted_reqs: int = 0
self.num_preemptions: int = 0 self.num_preemptions: int = 0
def _enable_perf_stats(self) -> bool:
return self.vllm_config.observability_config.enable_mfu_metrics
def _track_iteration_stats(self, iteration_stats: IterationStats): def _track_iteration_stats(self, iteration_stats: IterationStats):
# Save tracked stats for token counters. # Save tracked stats for token counters.
self.num_prompt_tokens += iteration_stats.num_prompt_tokens self.num_prompt_tokens += iteration_stats.num_prompt_tokens
...@@ -175,6 +182,8 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -175,6 +182,8 @@ class LoggingStatLogger(StatLoggerBase):
self.cudagraph_logging.observe(scheduler_stats.cudagraph_stats) self.cudagraph_logging.observe(scheduler_stats.cudagraph_stats)
if not self.aggregated: if not self.aggregated:
self.last_scheduler_stats = scheduler_stats self.last_scheduler_stats = scheduler_stats
if (perf_stats := scheduler_stats.perf_stats) and self._enable_perf_stats():
self.perf_metrics_logging.observe(perf_stats)
if mm_cache_stats: if mm_cache_stats:
self.mm_caching_metrics.observe(mm_cache_stats) self.mm_caching_metrics.observe(mm_cache_stats)
...@@ -211,7 +220,7 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -211,7 +220,7 @@ class LoggingStatLogger(StatLoggerBase):
"Running: %d reqs", "Running: %d reqs",
"Waiting: %d reqs", "Waiting: %d reqs",
] ]
log_args = [ log_args: list[int | float | str] = [
self.last_prompt_throughput, self.last_prompt_throughput,
self.last_generation_throughput, self.last_generation_throughput,
self.last_scheduler_stats.num_running_reqs, self.last_scheduler_stats.num_running_reqs,
...@@ -254,6 +263,8 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -254,6 +263,8 @@ class LoggingStatLogger(StatLoggerBase):
self.kv_connector_logging.log(log_fn=log_fn) self.kv_connector_logging.log(log_fn=log_fn)
if self.cudagraph_logging is not None: if self.cudagraph_logging is not None:
self.cudagraph_logging.log(log_fn=log_fn) self.cudagraph_logging.log(log_fn=log_fn)
if self._enable_perf_stats():
self.perf_metrics_logging.log(log_fn=log_fn, log_prefix=self.log_prefix)
def log_engine_initialized(self): def log_engine_initialized(self):
if self.vllm_config.cache_config.num_gpu_blocks: if self.vllm_config.cache_config.num_gpu_blocks:
...@@ -282,6 +293,10 @@ class AggregatedLoggingStatLogger(LoggingStatLogger, AggregateStatLoggerBase): ...@@ -282,6 +293,10 @@ class AggregatedLoggingStatLogger(LoggingStatLogger, AggregateStatLoggerBase):
def log_prefix(self): def log_prefix(self):
return "{} Engines Aggregated: ".format(len(self.engine_indexes)) return "{} Engines Aggregated: ".format(len(self.engine_indexes))
def _enable_perf_stats(self) -> bool:
# Adding per_gpu perf stats across engines can lead to misleading numbers.
return False
def record( def record(
self, self,
scheduler_stats: SchedulerStats | None, scheduler_stats: SchedulerStats | None,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Analytic flops/memory estimation module for transformer components,
to help derive MFU (Model Flops Utilization) stats for a running model.
"""
import json
import time
from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import asdict, dataclass
from typing import Any, Protocol
import torch
from pydantic import BaseModel, Field, ValidationError, model_validator
from typing_extensions import Self
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE,
get_dtype_size,
get_kv_cache_torch_dtype,
)
from vllm.v1.core.sched.output import SchedulerOutput
logger = init_logger(__name__)
class InvalidComponent(Exception):
"""
Custom exception to indicate that a certain ComponentMetric is not
applicable to the given VllmConfig.
"""
pass
#### Basic Data Types ####
@dataclass
class DebugPerfStats:
## Stats for debugging the metrics calculation
calc_duration: float = 0.0 # time spent calculating these stats
num_prefill_requests: int = 0
num_decode_requests: int = 0
context_breakdown: dict[str, int] | None = None
num_flops_per_gpu_breakdown: dict[str, int] | None = None
num_read_bytes_per_gpu_breakdown: dict[str, int] | None = None
num_write_bytes_per_gpu_breakdown: dict[str, int] | None = None
@dataclass
class PerfStats:
num_flops_per_gpu: int = 0
num_read_bytes_per_gpu: int = 0
num_write_bytes_per_gpu: int = 0
debug_stats: DebugPerfStats | None = None
@dataclass
class ExecutionContext:
"""
Represents an execution context for a batch of requests.
This class aggregates statistics across multiple requests in a batch,
separately tracking prefill and decode phases.
Example)
- Batch with one full prefill (2048 tokens) and one decode (1 token, 8192 context):
ctx = ExecutionContext()
ctx.add(2048, 2048, is_prefill=True)
ctx.add(1, 8192, is_prefill=False)
"""
# Prefill phase statistics
num_prefill_requests: int = 0
prefill_num_tokens: int = 0 # sum of num_tokens for prefill requests
prefill_context_len: int = 0 # sum of context_len for prefill requests
prefill_token_context_product: int = 0 # sum of (num_tokens * context_len)
# Decode phase statistics
num_decode_requests: int = 0
decode_num_tokens: int = 0 # sum of num_tokens for decode requests
decode_context_len: int = 0 # sum of context_len for decode requests
decode_token_context_product: int = 0 # sum of (num_tokens * context_len)
def add(self, num_tokens: int, context_len: int, is_prefill: bool) -> None:
"""Add a single request's statistics to this batch context."""
if is_prefill:
self.num_prefill_requests += 1
self.prefill_num_tokens += num_tokens
self.prefill_context_len += context_len
self.prefill_token_context_product += num_tokens * context_len
else:
self.num_decode_requests += 1
self.decode_num_tokens += num_tokens
self.decode_context_len += context_len
self.decode_token_context_product += num_tokens * context_len
def total_num_tokens(self) -> int:
"""Total number of tokens across all requests in the batch."""
return self.prefill_num_tokens + self.decode_num_tokens
def total_token_context_product(self) -> int:
"""Total sum of (num_tokens * context_len) across all requests."""
return self.prefill_token_context_product + self.decode_token_context_product
@classmethod
def from_single_request(
cls, num_tokens: int, context_len: int, is_prefill: bool
) -> "ExecutionContext":
"""Create an ExecutionContext from a single request.
This is a convenience method primarily for testing.
"""
ctx = cls()
ctx.add(num_tokens, context_len, is_prefill)
return ctx
class ParsedArgs:
"""
Syntactic sugar so that Parsers can use dot notations
to access/update the parsed arguments.
e.g.)
args = ParsedArgs()
args.x = 3
args.y = args.x + 1
"""
def __getattr__(self, name: str) -> Any:
raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'")
def __setattr__(self, name: str, value: Any) -> None:
object.__setattr__(self, name, value)
def model_dump(self) -> dict[str, Any]:
return vars(self).copy()
#### Abstract ####
class Parser(Protocol):
def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
"""
Parse the vllm config and update the current ParsedArgs and pass it on.
If the parser isn't applicable to the vllm_config, it will do nothing.
"""
...
class ParserChain:
"""
Applies chain of parser in a sequential order.
Later parsers might overwrite results from previous parsers,
so parsers should be chained in the appropriate order if they
are not mutually exclusive.
"""
def __init__(self, *parsers: Parser) -> None:
self.parsers = list(parsers)
def add_parser(self, parser: Parser) -> None:
self.parsers.append(parser)
def parse(self, vllm_config: VllmConfig) -> ParsedArgs:
args = ParsedArgs()
for parser in self.parsers:
args = parser.parse(args, vllm_config)
return args
_COMPONENT_METRICS_REGISTRY: dict[str, type["ComponentMetrics"]] = {}
class ComponentMetrics(BaseModel, ABC):
"""
Each concrete ComponentMetrics class is associated with:
- fields that are required for metric derivation
(fields are specified/validated through pydantic model)
- parser to parse VllmConfig into fields
- metric methods that derive flops/bytes for a given execution context
"""
@classmethod
@abstractmethod
def component_type(cls) -> str: ...
@classmethod
@abstractmethod
def get_parser(cls) -> ParserChain:
"""
Return a ParserChain that provides values for all required fields.
The returned parser chain must populate ParsedArgs with values for every
field defined on this ComponentMetrics class. Missing fields will cause
a ValidationError when from_vllm_config() is called.
See individual Parser docstrings for which args they provide, and field
comments on ComponentMetrics subclasses for which parser provides each field.
"""
...
def __init_subclass__(cls):
_COMPONENT_METRICS_REGISTRY[cls.component_type()] = cls
@classmethod
def from_vllm_config(cls, vllm_config: VllmConfig) -> Self:
"""
Instantiate this class from VllmConfig.
Raises ValidationError if parsing fails.
"""
parser = cls.get_parser()
parsed_args = parser.parse(vllm_config)
try:
return cls.model_validate(parsed_args.model_dump())
except ValidationError as e:
raise InvalidComponent(f"Invalid {cls.component_type()} config: {e}") from e
@classmethod
def registered_metrics(cls) -> Iterable[type["ComponentMetrics"]]:
return iter(_COMPONENT_METRICS_REGISTRY.values())
@abstractmethod
def get_num_flops_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]: ...
@abstractmethod
def get_read_bytes_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]: ...
@abstractmethod
def get_write_bytes_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]: ...
def get_num_flops(self, ctx: ExecutionContext, per_gpu: bool = True) -> int:
return sum(self.get_num_flops_breakdown(ctx, per_gpu).values())
def get_read_bytes(self, ctx: ExecutionContext, per_gpu: bool = True) -> int:
return sum(self.get_read_bytes_breakdown(ctx, per_gpu).values())
def get_write_bytes(self, ctx: ExecutionContext, per_gpu: bool = True) -> int:
return sum(self.get_write_bytes_breakdown(ctx, per_gpu).values())
#### parsers ####
class BaseConfigParser(Parser):
"""
Parses base model configuration.
Provides: vocab_size, hidden_size, num_attention_heads, num_hidden_layers,
weight_byte_size, activation_byte_size, dp_size, tp_size, pp_size, enable_ep
"""
def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
model_config = vllm_config.model_config
args.vocab_size = model_config.get_vocab_size()
args.hidden_size = model_config.get_hidden_size()
# NOTE: model_config.get_attention_heads() divide by TP
# so we access field manually here to get total num_heads
args.num_attention_heads = get_required(
model_config.hf_text_config, "num_attention_heads"
)
args.num_hidden_layers = get_required(
model_config.hf_text_config, "num_hidden_layers"
)
model_dtype = vllm_config.model_config.dtype
if isinstance(model_dtype, torch.dtype):
torch_dtype = model_dtype
elif isinstance(model_dtype, str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE:
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
else:
# FIXME: handle this better
logger.warning(
"Unknown model_dtype %s, defaulting to bfloat16",
model_dtype,
)
torch_dtype = torch.bfloat16
args.weight_byte_size = get_dtype_size(torch_dtype)
# FIXME: handle this better by parsing whether activations use
# bf16, fp32, etc...
args.activation_byte_size = 2
args.dp_size = vllm_config.parallel_config.data_parallel_size
args.tp_size = vllm_config.parallel_config.tensor_parallel_size
args.pp_size = vllm_config.parallel_config.pipeline_parallel_size
args.enable_ep = vllm_config.parallel_config.enable_expert_parallel
return args
#### Attention ####
class BaseAttentionConfigParser(Parser):
"""
Parses attention-specific configuration.
Provides: num_key_value_heads, head_dim, cache_byte_size
"""
def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
model_config = vllm_config.model_config
args.num_key_value_heads = model_config.get_total_num_kv_heads()
args.head_dim = model_config.get_head_size()
model_dtype = vllm_config.model_config.dtype
cache_dtype = vllm_config.cache_config.cache_dtype
kv_cache_torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
args.cache_byte_size = get_dtype_size(kv_cache_torch_dtype)
return args
class AttentionQuantizationConfigParser(Parser):
"""
Parses quantization configuration for attention layers.
Overrides: weight_byte_size
"""
def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
cfg = vllm_config.quant_config
if cfg is None:
return args
quant_method = cfg.get_name()
if quant_method in ["fp8", "fbgemm_fp8"]:
# FIXME: This is a hacky coarse-grained fp8 quantization detection.
# FIXME: These configs also have concept of "ignored layers" and we
# need to solve the same problem as above.
args.weight_byte_size = 1
elif quant_method == "mxfp4":
# FIXME: Also has "ignored layers" issue above
args.weight_byte_size = 0.5
else:
# FIXME: Add more parsing logic for different quant methods.
raise InvalidComponent
return args
class AttentionMetrics(ComponentMetrics):
# From BaseConfigParser
num_hidden_layers: int = Field(..., gt=0)
hidden_size: int = Field(..., gt=0)
num_attention_heads: int = Field(..., gt=0)
activation_byte_size: int = Field(..., gt=0)
tp_size: int = Field(..., gt=0)
pp_size: int = Field(..., gt=0)
# From BaseAttentionConfigParser
num_key_value_heads: int = Field(..., gt=0)
head_dim: int = Field(..., gt=0)
cache_byte_size: int = Field(..., gt=0)
# From BaseConfig Parser, overridden by AttentionQuantizationConfigParser
weight_byte_size: int | float = Field(..., gt=0)
# TODO: discern cases where we have mixture of different attention layer types
# such as SWA, MLA, etc.
@classmethod
def component_type(cls) -> str:
return "attn"
@classmethod
def get_parser(cls) -> ParserChain:
return ParserChain(
BaseConfigParser(),
BaseAttentionConfigParser(),
AttentionQuantizationConfigParser(),
)
def get_num_flops_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]:
L, D, q, kv, d = (
self.num_hidden_layers,
self.hidden_size,
self.num_attention_heads,
self.num_key_value_heads,
self.head_dim,
)
T = ctx.total_num_tokens()
TC = ctx.total_token_context_product()
if per_gpu:
L //= self.pp_size
# tensor parallel along heads
q = max(1, q // self.tp_size)
kv = max(1, kv // self.tp_size)
return {
"qkv_proj": 2 * T * D * (q + 2 * kv) * d * L,
"attn_qk": 2 * q * TC * d * L,
"attn_av": 2 * q * TC * d * L,
"out_proj": 2 * T * D * q * d * L,
}
def get_read_bytes_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]:
L, D, q, kv, d = (
self.num_hidden_layers,
self.hidden_size,
self.num_attention_heads,
self.num_key_value_heads,
self.head_dim,
)
T = ctx.total_num_tokens()
if per_gpu:
L //= self.pp_size
# tensor parallel along heads
q = max(1, q // self.tp_size)
kv = max(1, kv // self.tp_size)
read_bytes = {}
read_bytes["qkv_input"] = T * D * self.activation_byte_size * L
read_bytes["qkv_weight"] = int(D * (q + 2 * kv) * d * self.weight_byte_size * L)
# Attention input reads differ between prefill and decode
# Prefill: read Q, K, V activations (all in activation_byte_size)
if ctx.prefill_num_tokens > 0:
read_bytes["attn_input"] = (
(ctx.prefill_num_tokens * q + 2 * ctx.prefill_context_len * kv)
* d
* self.activation_byte_size
* L
)
# Decode: read Q activations + read K, V from cache (in cache_byte_size)
if ctx.decode_num_tokens > 0:
read_bytes["attn_input"] = read_bytes.get("attn_input", 0) + (
ctx.decode_num_tokens * q * d * self.activation_byte_size * L
+ 2 * ctx.decode_context_len * kv * d * self.cache_byte_size * L
)
read_bytes["out_input"] = T * q * d * self.activation_byte_size * L
read_bytes["out_weight"] = int(q * d * D * self.weight_byte_size * L)
return read_bytes
def get_write_bytes_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]:
"""Calculate write memory traffic for attention layers."""
L, D, q, kv, d = (
self.num_hidden_layers,
self.hidden_size,
self.num_attention_heads,
self.num_key_value_heads,
self.head_dim,
)
T = ctx.total_num_tokens()
if per_gpu:
L //= self.pp_size
# tensor parallel along heads
q = max(1, q // self.tp_size)
kv = max(1, kv // self.tp_size)
return {
"qkv_output": T * (q + 2 * kv) * d * self.activation_byte_size * L,
"kv_cache": 2 * T * kv * d * self.cache_byte_size * L,
"out_output": T * D * self.activation_byte_size * L,
}
#### Ffn ####
class BaseFfnConfigParser(Parser):
"""
Parses FFN and MoE configuration.
Provides: intermediate_size, num_experts, num_experts_per_tok,
moe_intermediate_size, num_shared_experts, num_moe_layers
"""
def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
cfg = vllm_config.model_config.hf_config
if hasattr(cfg, "text_config") and cfg.text_config is not None:
cfg = cfg.text_config
args.intermediate_size = getattr(cfg, "intermediate_size", args.hidden_size * 4)
# Try different naming conventions.
args.num_experts = vllm_config.model_config.get_num_experts()
args.num_experts_per_tok = getattr_from_list(
cfg, ["num_experts_per_tok", "moe_topk"], 0
)
args.moe_intermediate_size = getattr_from_list(
cfg, ["moe_intermediate_size", "intermediate_size"], 0
)
args.num_shared_experts = getattr_from_list(
cfg, ["n_shared_experts", "num_shared_experts"], 0
)
is_moe = args.num_experts != 0
# Assume all MoE layers by default
args.num_moe_layers = args.num_hidden_layers if is_moe else 0
return args
class FfnParallelParser(Parser):
"""
Parses FFN parallelism configuration.
Provides: ffn_tp_size, ffn_ep_size
"""
def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
# NOTE: ffn tp_size does not equal the tp_size parameter directly.
# e.g.) If we use DP2TP4, ffn will use TP8 (or EP8 if EP is enabled.)
if args.enable_ep:
ffn_tp_size, ffn_ep_size = 1, args.dp_size * args.tp_size
else:
ffn_tp_size, ffn_ep_size = args.dp_size * args.tp_size, 1
args.ffn_tp_size = ffn_tp_size
args.ffn_ep_size = ffn_ep_size
return args
class InterleaveMoeLayerStepParser(Parser):
"""
Parses interleave_moe_layer_step field for models like Llama4.
Overrides: num_moe_layers
"""
def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
cfg = vllm_config.model_config.hf_config
if hasattr(cfg, "text_config") and cfg.text_config is not None:
cfg = cfg.text_config
if (
hasattr(cfg, "interleave_moe_layer_step")
and cfg.interleave_moe_layer_step > 0
):
args.num_moe_layers = len(
[
layer
for layer in range(args.num_hidden_layers)
if (layer + 1) % cfg.interleave_moe_layer_step == 0
]
)
return args
class MoeLayerFreqParser(Parser):
"""
Parses moe_layer_freq and first_k_dense_replace fields for models like Deepseek.
Overrides: num_moe_layers
"""
def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
cfg = vllm_config.model_config.hf_config
if hasattr(cfg, "text_config") and cfg.text_config is not None:
cfg = cfg.text_config
if hasattr(cfg, "moe_layer_freq") and hasattr(cfg, "first_k_dense_replace"):
args.num_moe_layers = len(
[
layer
for layer in range(args.num_hidden_layers)
if layer >= cfg.first_k_dense_replace
and layer % cfg.moe_layer_freq == 0
]
)
return args
class FfnQuantizationConfigParser(Parser):
"""
Parses quantization configuration for FFN layers.
Overrides: weight_byte_size
"""
def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
cfg = vllm_config.quant_config
if cfg is None:
return args
quant_method = cfg.get_name()
if quant_method in ["fp8", "fbgemm_fp8"]:
# FIXME: This is a hacky coarse-grained fp8 quantization detection.
# (there might be more quantization methods for fp8).
# FIXME: These configs also have concept of "ignored layers" and we
# need to solve the same problem as above.
args.weight_byte_size = 1
pass
elif quant_method == "mxfp4":
# FIXME: Also has "ignored layers" issue above
args.weight_byte_size = 0.5
else:
# FIXME: Add more parsing logic for different quant methods.
raise InvalidComponent
return args
class FfnMetrics(ComponentMetrics):
# From BaseConfigParser
num_hidden_layers: int = Field(..., gt=0)
hidden_size: int = Field(..., gt=0)
activation_byte_size: int = Field(..., gt=0)
pp_size: int = Field(..., gt=0)
# From FfnParallelParser
ffn_tp_size: int = Field(..., gt=0)
ffn_ep_size: int = Field(..., gt=0)
# From BaseFfnConfigParser
intermediate_size: int = Field(..., gt=0)
num_experts: int = Field(0)
num_experts_per_tok: int = Field(1)
moe_intermediate_size: int = Field(0)
num_shared_experts: int = Field(0)
# From BaseConfigParser, can be overridden InterleaveMoeLayerStep or MoeLayerFreq
num_moe_layers: int = Field(..., ge=0)
# FIXME: might have to make this more granular
# (i.e. dense_weight_byte_size, moe_routed_weight_byte_size,
# moe_shared_weight_byte_size)
# since it can differ from byte size of other components (e.g. attn)
# and can differ even from each other.
# From BaseConfigParser, can be overridden by FfnQuantizationConfigParser
weight_byte_size: int | float = Field(..., gt=0)
@model_validator(mode="after")
def validate_moe_fields(self) -> Self:
"""Validate that MoE-related fields are properly set when num_moe_layers > 0."""
if self.num_moe_layers > 0:
assert self.num_experts, f"{self.num_experts=}"
assert self.num_experts_per_tok, f"{self.num_experts_per_tok=}"
assert self.moe_intermediate_size, f"{self.moe_intermediate_size=}"
return self
@classmethod
def component_type(cls) -> str:
return "ffn"
@classmethod
def get_parser(cls) -> ParserChain:
return ParserChain(
BaseConfigParser(),
FfnParallelParser(),
BaseFfnConfigParser(),
InterleaveMoeLayerStepParser(),
MoeLayerFreqParser(),
FfnQuantizationConfigParser(),
)
def get_num_flops_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]:
"""Calculate flops breakdown for FFN layers."""
L, D, DI = self.num_hidden_layers, self.hidden_size, self.intermediate_size
Lm, E, MI, S = (
self.num_moe_layers,
self.num_experts_per_tok,
self.moe_intermediate_size,
self.num_shared_experts,
)
T = ctx.total_num_tokens()
Ld = L - Lm
num_activated_tokens = T * E if E else 0
if per_gpu:
Ld //= self.pp_size
Lm //= self.pp_size
DI //= self.ffn_tp_size
if MI is not None:
MI //= self.ffn_tp_size
if E:
num_activated_tokens //= self.ffn_ep_size
flops = {}
# Dense FFN layers (SwiGLU: 3 linear layers: up, gate, down)
if Ld:
flops["dense_ffn"] = 2 * D * 3 * DI * T * Ld
# MoE routed experts (each token activates E experts)
if Lm and E:
flops["routed_ffn"] = 2 * D * 3 * MI * num_activated_tokens * Lm
# MoE shared experts (all S shared experts run for every token)
if Lm and S:
flops["shared_ffn"] = 2 * D * 3 * MI * S * T * Lm
return flops
def get_read_bytes_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]:
"""Calculate read memory traffic for FFN layers."""
L, D, DI = self.num_hidden_layers, self.hidden_size, self.intermediate_size
Lm, E, MI, S = (
self.num_moe_layers,
self.num_experts_per_tok,
self.moe_intermediate_size,
self.num_shared_experts,
)
T = ctx.total_num_tokens()
num_experts = self.num_experts
Ld = L - Lm
num_activated_tokens = T * E if E else 0
if per_gpu:
Ld //= self.pp_size
Lm //= self.pp_size
DI //= self.ffn_tp_size
if MI is not None:
MI //= self.ffn_tp_size
if E:
num_activated_tokens //= self.ffn_ep_size
if num_experts is not None:
num_experts //= self.ffn_ep_size
read_bytes = {}
# Dense FFN layers (3 GEMMs: up, gate, down projections + SiLU activation)
if Ld:
read_bytes["dense_up_gate_input"] = int(
T * D * self.activation_byte_size * Ld
)
read_bytes["dense_up_gate_weights"] = int(
2 * D * DI * self.weight_byte_size * Ld
)
read_bytes["dense_silu_input"] = int(
2 * T * DI * self.activation_byte_size * Ld
)
read_bytes["dense_down_input"] = int(
T * DI * self.activation_byte_size * Ld
)
read_bytes["dense_down_weights"] = int(D * DI * self.weight_byte_size * Ld)
if Lm:
# MoE routed expert reads
if E:
# FIXME: Assume perfect load balancing for now.
num_activated_experts = min(num_activated_tokens, num_experts)
read_bytes["routed_up_gate_input"] = int(
num_activated_tokens * D * self.activation_byte_size * Lm
)
read_bytes["routed_up_gate_weights"] = int(
2 * D * MI * num_activated_experts * self.weight_byte_size * Lm
)
read_bytes["routed_silu_input"] = int(
2 * num_activated_tokens * MI * self.activation_byte_size * Lm
)
read_bytes["routed_down_input"] = int(
num_activated_tokens * MI * self.activation_byte_size * Lm
)
read_bytes["routed_down_weights"] = int(
D * MI * num_activated_experts * self.weight_byte_size * Lm
)
# MoE shared expert reads
if S:
read_bytes["shared_up_gate_input"] = int(
T * D * self.activation_byte_size * Lm
)
read_bytes["shared_up_gate_weights"] = int(
2 * D * MI * S * self.weight_byte_size * Lm
)
read_bytes["shared_silu_input"] = int(
2 * T * MI * S * self.activation_byte_size * Lm
)
read_bytes["shared_down_input"] = int(
T * MI * self.activation_byte_size * Lm
)
read_bytes["shared_down_weights"] = int(
D * MI * S * self.weight_byte_size * Lm
)
return read_bytes
def get_write_bytes_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]:
"""Calculate write memory traffic for FFN layers."""
L, D, DI = self.num_hidden_layers, self.hidden_size, self.intermediate_size
Lm, E, MI, S = (
self.num_moe_layers,
self.num_experts_per_tok,
self.moe_intermediate_size,
self.num_shared_experts,
)
T = ctx.total_num_tokens()
Ld = L - Lm
num_activated_tokens = T * E if E else 0
if per_gpu:
Ld //= self.pp_size
Lm //= self.pp_size
DI //= self.ffn_tp_size
if MI is not None:
MI //= self.ffn_tp_size
if E:
num_activated_tokens //= self.ffn_ep_size
write_bytes = {}
# Dense FFN layers
if Ld:
write_bytes["dense_up_gate_output"] = int(
2 * T * DI * self.activation_byte_size * Ld
)
write_bytes["dense_silu_output"] = int(
T * DI * self.activation_byte_size * Ld
)
write_bytes["dense_down_output"] = int(
T * D * self.activation_byte_size * Ld
)
# MoE outputs
if Lm:
if E:
write_bytes["routed_up_gate_output"] = int(
2 * num_activated_tokens * MI * self.activation_byte_size * Lm
)
write_bytes["routed_silu_output"] = int(
num_activated_tokens * MI * self.activation_byte_size * Lm
)
write_bytes["routed_down_output"] = int(
num_activated_tokens * D * self.activation_byte_size * Lm
)
if S:
write_bytes["shared_up_gate_output"] = int(
2 * T * S * MI * self.activation_byte_size * Lm
)
write_bytes["shared_silu_output"] = int(
T * S * MI * self.activation_byte_size * Lm
)
write_bytes["shared_down_output"] = int(
T * S * D * self.activation_byte_size * Lm
)
return write_bytes
#### Unembed ####
class UnembedMetrics(ComponentMetrics):
# From BaseConfigParser
hidden_size: int = Field(..., gt=0)
vocab_size: int = Field(..., gt=0)
weight_byte_size: int = Field(..., gt=0)
activation_byte_size: int = Field(..., gt=0)
tp_size: int
@classmethod
def component_type(cls) -> str:
return "unembed"
@classmethod
def get_parser(cls) -> ParserChain:
return ParserChain(
BaseConfigParser(),
)
def get_num_flops_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]:
"""Calculate flops breakdown for unembedding layer."""
D, V = self.hidden_size, self.vocab_size
T = ctx.total_num_tokens()
if per_gpu:
V //= self.tp_size
return {
"unembed": 2 * T * D * V,
}
def get_read_bytes_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]:
"""Calculate read memory traffic for unembedding layer."""
D, V = self.hidden_size, self.vocab_size
T = ctx.total_num_tokens()
if per_gpu:
V //= self.tp_size
return {
"input": T * D * self.activation_byte_size,
"weight": D * V * self.weight_byte_size,
}
def get_write_bytes_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]:
"""Calculate write memory traffic for unembedding layer."""
V = self.vocab_size
T = ctx.total_num_tokens()
if per_gpu:
V //= self.tp_size
return {
"output": T * V * self.activation_byte_size,
}
#### ModelMetrics ####
class ModelMetrics:
def __init__(self, vllm_config: VllmConfig) -> None:
"""
Parse vllm_config to instantiate metrics for each component.
is_enabled() will return False if no component metrics could be instantiated.
"""
self.vllm_config = vllm_config
self.metrics: list[ComponentMetrics] = []
for metric_cls in ComponentMetrics.registered_metrics():
try:
metric = metric_cls.from_vllm_config(vllm_config)
self.metrics.append(metric)
logger.info(
"Instantiated ComponentMetrics [%s] with (%s)",
metric.component_type(),
str(metric),
)
except InvalidComponent as e:
logger.debug(
"Failed to instantiate %s from %s",
metric_cls.component_type(),
str(e),
)
def is_enabled(self) -> bool:
return len(self.metrics) > 0
def get_num_flops(self, ctx: ExecutionContext, per_gpu: bool = True) -> int:
return sum(metric.get_num_flops(ctx, per_gpu) for metric in self.metrics)
def get_read_bytes(self, ctx: ExecutionContext, per_gpu: bool = True) -> int:
return sum(metric.get_read_bytes(ctx, per_gpu) for metric in self.metrics)
def get_write_bytes(self, ctx: ExecutionContext, per_gpu: bool = True) -> int:
return sum(metric.get_write_bytes(ctx, per_gpu) for metric in self.metrics)
def get_num_flops_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]:
total = {}
for metric in self.metrics:
breakdown = metric.get_num_flops_breakdown(ctx, per_gpu)
component = metric.component_type()
prefixed = {f"{component}.{key}": val for key, val in breakdown.items()}
total.update(prefixed)
return total
def get_read_bytes_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]:
total = {}
for metric in self.metrics:
breakdown = metric.get_read_bytes_breakdown(ctx, per_gpu)
component = metric.component_type()
prefixed = {f"{component}.{key}": val for key, val in breakdown.items()}
total.update(prefixed)
return total
def get_write_bytes_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]:
total = {}
for metric in self.metrics:
breakdown = metric.get_write_bytes_breakdown(ctx, per_gpu)
component = metric.component_type()
prefixed = {f"{component}.{key}": val for key, val in breakdown.items()}
total.update(prefixed)
return total
def get_step_perf_stats_per_gpu(
self, scheduler_output: SchedulerOutput
) -> PerfStats:
"""
Calculate perf stats for the current step based on scheduled tokens.
"""
t0 = time.monotonic()
# Build a single batch context
ctx = ExecutionContext()
# Process new requests (these are in prefill phase)
for new_req in scheduler_output.scheduled_new_reqs:
req_id = new_req.req_id
num_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0)
if num_tokens == 0:
continue
# For new requests, context_len = num_computed_tokens + num_tokens
# num_computed_tokens represents previously computed tokens in the sequence
context_len = new_req.num_computed_tokens + num_tokens
ctx.add(num_tokens, context_len, is_prefill=True)
# Process cached requests (continuing requests)
cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0)
if num_tokens == 0:
continue
# For cached requests, we have the current num_computed_tokens
num_computed_tokens = cached_reqs.num_computed_tokens[i]
context_len = num_computed_tokens + num_tokens
# Cached requests are typically in decode phase (num_tokens == 1)
# unless they're doing chunked prefill (num_tokens > 1)
is_prefill = num_tokens > 1
ctx.add(num_tokens, context_len, is_prefill)
num_flops_breakdown = self.get_num_flops_breakdown(ctx, True)
read_bytes_breakdown = self.get_read_bytes_breakdown(ctx, True)
write_bytes_breakdown = self.get_write_bytes_breakdown(ctx, True)
perf_stats = PerfStats(
sum(num_flops_breakdown.values()),
sum(read_bytes_breakdown.values()),
sum(write_bytes_breakdown.values()),
)
if envs.VLLM_DEBUG_MFU_METRICS:
perf_stats.debug_stats = DebugPerfStats(
time.monotonic() - t0,
ctx.num_prefill_requests,
ctx.num_decode_requests,
asdict(ctx),
num_flops_breakdown,
read_bytes_breakdown,
write_bytes_breakdown,
)
return perf_stats
#### Logging ####
class PerfMetricsDebugLogging:
def __init__(self):
self.reset()
def reset(self):
self.total_calc_duration: float = 0.0
self.total_num_prefill_requests: int = 0
self.total_num_decode_requests: int = 0
self.total_num_batches: int = 0
self.total_context_breakdown: dict[str, int] = {}
self.total_num_flops_per_gpu_breakdown: dict[str, int] = {}
self.total_read_bytes_per_gpu_breakdown: dict[str, int] = {}
self.total_write_bytes_per_gpu_breakdown: dict[str, int] = {}
def observe(self, debug_stats: DebugPerfStats) -> None:
self.total_calc_duration += debug_stats.calc_duration
self.total_num_prefill_requests += debug_stats.num_prefill_requests
self.total_num_decode_requests += debug_stats.num_decode_requests
self.total_num_batches += 1
for dst, src in zip(
[
self.total_context_breakdown,
self.total_num_flops_per_gpu_breakdown,
self.total_read_bytes_per_gpu_breakdown,
self.total_write_bytes_per_gpu_breakdown,
],
[
debug_stats.context_breakdown,
debug_stats.num_flops_per_gpu_breakdown,
debug_stats.num_read_bytes_per_gpu_breakdown,
debug_stats.num_write_bytes_per_gpu_breakdown,
],
):
assert isinstance(src, dict)
for key, val in src.items():
dst[key] = dst.get(key, 0) + val
def log(self, log_fn, log_prefix: str, delta_time: float):
# pretty print breakdowns
total_num_flops_per_gpu_breakdown = {
k: f"{v / 1e12:.1f}TF"
for k, v in self.total_num_flops_per_gpu_breakdown.items()
}
total_read_bytes_per_gpu_breakdown = {
k: f"{v / 1e9:.1f}GB"
for k, v in self.total_read_bytes_per_gpu_breakdown.items()
}
total_write_bytes_per_gpu_breakdown = {
k: f"{v / 1e9:.1f}GB"
for k, v in self.total_write_bytes_per_gpu_breakdown.items()
}
logger.debug(
"%sMFU details: %s",
log_prefix,
json.dumps(
{
"prefill_reqs": self.total_num_prefill_requests,
"decode_reqs": self.total_num_decode_requests,
"num_batches": self.total_num_batches,
"context_breakdown": self.total_context_breakdown,
"flops_breakdown": total_num_flops_per_gpu_breakdown,
"num_read_bytes_breakdown": total_read_bytes_per_gpu_breakdown,
"num_write_bytes_breakdown": (total_write_bytes_per_gpu_breakdown),
"duration": f"{delta_time:.1f}s",
"mfu_calc_overhead": (
f"{self.total_calc_duration / delta_time:.1%}"
),
},
indent=2,
),
)
class PerfMetricsLogging:
def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config
self.pp_size = vllm_config.parallel_config.pipeline_parallel_size
self.debug_logging: PerfMetricsDebugLogging | None = None
if envs.VLLM_DEBUG_MFU_METRICS:
self.debug_logging = PerfMetricsDebugLogging()
self.reset()
def reset(self):
self.last_log_time = time.monotonic()
self.total_num_flops_per_gpu: int = 0
self.total_read_bytes_per_gpu: int = 0
self.total_write_bytes_per_gpu: int = 0
if self.debug_logging:
self.debug_logging.reset()
def observe(self, perf_stats: PerfStats) -> None:
self.total_num_flops_per_gpu += perf_stats.num_flops_per_gpu
self.total_read_bytes_per_gpu += perf_stats.num_read_bytes_per_gpu
self.total_write_bytes_per_gpu += perf_stats.num_write_bytes_per_gpu
if self.debug_logging:
assert perf_stats.debug_stats is not None
self.debug_logging.observe(perf_stats.debug_stats)
def log(self, log_fn=logger.info, log_prefix: str = "") -> None:
if not (
self.total_num_flops_per_gpu
or self.total_read_bytes_per_gpu
or self.total_write_bytes_per_gpu
):
return
now = time.monotonic()
delta_time = now - self.last_log_time
if delta_time <= 0.0:
avg_tflops_per_gpu = 0.0
avg_gbps_per_gpu = 0.0
else:
avg_tflops_per_gpu = self.total_num_flops_per_gpu / delta_time / 1e12
avg_gbps_per_gpu = (
(self.total_read_bytes_per_gpu + self.total_write_bytes_per_gpu)
/ delta_time
/ 1e9
)
log_fn(
"%sMFU: %.1f TF/s/GPU %.1f GB/s/GPU",
log_prefix,
avg_tflops_per_gpu,
avg_gbps_per_gpu,
)
if self.debug_logging:
self.debug_logging.log(log_fn, log_prefix, delta_time)
self.reset()
## util functions
def get_required(obj: object, attr: str):
"""Get an attr from an object, or throw a InvalidComponentError if it's not set."""
if not hasattr(obj, attr):
raise InvalidComponent(f"Missing required attr {attr} in config")
return getattr(obj, attr)
def getattr_from_list(obj: object, attrs: list[str], default: object = None):
"""Try to get the first attr that exists in the object
from a list of attrs. Otherwise return None."""
for attr in attrs:
if hasattr(obj, attr):
return getattr(obj, attr)
return default
...@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any ...@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.cuda_graph import CUDAGraphStat from vllm.compilation.cuda_graph import CUDAGraphStat
from vllm.v1.metrics.perf import PerfStats
from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.spec_decode.metrics import SpecDecodingStats
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -186,6 +187,8 @@ class SchedulerStats: ...@@ -186,6 +187,8 @@ class SchedulerStats:
cudagraph_stats: CUDAGraphStat | None = None cudagraph_stats: CUDAGraphStat | None = None
perf_stats: PerfStats | None = None
@dataclass @dataclass
class RequestStateStats: class RequestStateStats:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment