# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 """Generate the KV event density heatmap (Figure 1) for the Flash Indexer blog post. Renders a diverging heatmap showing Store vs Remove event density across 16 workers over time. Uses real Mooncake FAST'25 trace data when available, otherwise falls back to synthetic Gamma/Zipf event patterns. Prerequisites: pip3 install plotly kaleido numpy pyyaml Usage: python3 gen_heatmap.py # synthetic data python3 gen_heatmap.py --real-data PATH # real Mooncake trace """ from __future__ import annotations import argparse import gzip import json import sys from pathlib import Path import numpy as np import plotly.graph_objects as go from plotly.subplots import make_subplots sys.path.insert(0, str(Path(__file__).parent)) from plotly_dynamo import dynamo_template, load_tokens TOKENS = load_tokens() C = TOKENS["colors"] T = TOKENS["typography"] BG = C["background"]["primary"] SURFACE = C["background"]["surface"] BORDER = C["border"]["subtle"] TXT = C["text"]["primary"] TXT2 = C["text"]["secondary"] STORE = "#76b900" REMOVE = "#fac200" SANS = T["font_family"] MONO = T["font_family_mono"] OUT = Path(__file__).resolve().parent.parent / "images" def generate_kv_events( num_workers: int = 16, duration_ms: float = 100.0, rps_per_worker: float = 35.0, block_size: int = 16, cache_hit_ratio: float = 0.30, cache_capacity_blocks: int = 24, seed: int = 42, ) -> tuple[list[np.ndarray], list[np.ndarray]]: """Synthetic KV Store/Remove events. Arrivals: Gamma(alpha=1.5) -- bursty. Prompt lengths: Zipf(a=1.3) * 384 + 64 -- heavy-tailed. Low cache capacity forces frequent eviction sweeps. """ rng = np.random.default_rng(seed) gamma_alpha = 1.5 all_stores: list[np.ndarray] = [] all_removes: list[np.ndarray] = [] for _ in range(num_workers): mean_iat_ms = 1000.0 / rps_per_worker gamma_beta = gamma_alpha / mean_iat_ms stores, removes = [], [] t = rng.exponential(mean_iat_ms / 3) occupied = 0 while t < duration_ms: prompt_tokens = int(rng.zipf(1.3) * 384 + 64) prompt_tokens = min(prompt_tokens, 8192) total_blocks = max(1, prompt_tokens // block_size) new_blocks = max(1, int(total_blocks * (1 - cache_hit_ratio))) if occupied > cache_capacity_blocks: evict_n = occupied - int(cache_capacity_blocks * 0.6) evict_n = max(4, evict_n) evict_t = t + rng.uniform(0, 0.06) removes.extend(evict_t + rng.uniform(0, 0.06, size=evict_n)) occupied -= evict_n store_t = t + rng.uniform(0.12, 0.25) stores.extend(store_t + rng.uniform(0, 0.08, size=new_blocks)) occupied += new_blocks t += rng.gamma(gamma_alpha, 1.0 / gamma_beta) all_stores.append(np.sort(stores)) all_removes.append(np.sort(removes)) return all_stores, all_removes def save(fig: go.Figure, name: str, w: int = 1200, h: int = 600) -> None: """Write figure as both PNG (3x) and SVG.""" OUT.mkdir(parents=True, exist_ok=True) png = OUT / f"{name}.png" svg = OUT / f"{name}.svg" fig.write_image(str(png), width=w, height=h, scale=3) fig.write_image(str(svg), width=w, height=h) print(f" {png.name} ({w}x{h})") print(f" {svg.name} ({w}x{h})") def _ax(**kw) -> dict: """Clean axis defaults with minimal grid.""" base = { "zeroline": False, "showgrid": True, "gridcolor": BORDER, "gridwidth": 0.3, "linecolor": BORDER, "linewidth": 0.5, "tickfont": {"family": SANS, "size": 12, "color": TXT2}, "title_font": {"family": SANS, "size": 14, "color": TXT2}, } base.update(kw) return base def load_real_events( path: Path, t_start_s: float = 5.0, t_end_s: float = 10.0, ) -> tuple[int, np.ndarray, np.ndarray, np.ndarray, float, int, int]: """Load real KV events from JSON, return binned matrices for the time window. Returns (num_workers, s_mat, r_mat, bins, bin_width, total_stores, total_removes). Times are converted to ms relative to t_start. """ p = Path(path) opener = gzip.open if p.suffix == ".gz" else p.open with opener(p, "rt") as f: data = json.load(f) t_start_us = t_start_s * 1e6 t_end_us = t_end_s * 1e6 duration_ms = (t_end_s - t_start_s) * 1000 worker_ids = sorted(int(k) for k in data) n = len(worker_ids) bw = 10.0 bins = np.arange(0, duration_ms + bw, bw) nb = len(bins) - 1 s_mat = np.zeros((n, nb)) r_mat = np.zeros((n, nb)) total_stores = 0 total_removes = 0 for idx, wid in enumerate(worker_ids): events = data[str(wid)] for ev in events: ts_us = ev["timestamp_us"] if ts_us < t_start_us or ts_us > t_end_us: continue ts_ms = (ts_us - t_start_us) / 1000.0 bin_idx = min(int(ts_ms / bw), nb - 1) d = ev["event"]["data"] if "stored" in d: s_mat[idx, bin_idx] += 1 total_stores += 1 elif "removed" in d: r_mat[idx, bin_idx] += 1 total_removes += 1 return n, s_mat, r_mat, bins, bw, total_stores, total_removes def make_heatmap( stores: list[np.ndarray], removes: list[np.ndarray], real_data_path: Path | None = None, ) -> None: if real_data_path is not None: n, s_mat, r_mat, bins, bw, ts_count, tr_count = load_real_events(real_data_path) nb = len(bins) - 1 duration_ms = bins[-1] total_events = ts_count + tr_count subtitle = f"Mooncake trace (5%) \u00b7 16 Mocker workers \u00b7 2048 GPU blocks/worker \u00b7 {total_events:,} events in 5.0 s" else: n = len(stores) bw = 0.5 duration_ms = 100.0 bins = np.arange(0, duration_ms + bw, bw) nb = len(bins) - 1 s_mat = np.zeros((n, nb)) r_mat = np.zeros((n, nb)) for i in range(n): if len(stores[i]): s_mat[i] = np.histogram(stores[i], bins)[0] if len(removes[i]): r_mat[i] = np.histogram(removes[i], bins)[0] ts_count = sum(len(s) for s in stores) tr_count = sum(len(r) for r in removes) subtitle = f"16 workers \u00b7 TP1 \u00b7 block_size 16 \u00b7 35 RPS/worker \u00b7 30% cache hit \u00b7 {(ts_count + tr_count) // 1000}K events in 100 ms" combined = s_mat - r_mat total = combined.sum(axis=0) disp_pw = np.where(s_mat >= r_mat, s_mat, -r_mat).astype(float) disp_pw = np.where((s_mat == 0) & (r_mat == 0), 0.0, disp_pw) s_total = s_mat.sum(axis=0) r_total = r_mat.sum(axis=0) disp_tot = np.where(s_total >= r_total, s_total, -r_total).astype(float) disp_tot = np.where((s_total == 0) & (r_total == 0), 0.0, disp_tot) pw_zmax = 10.0 colorscale = [ [0.0, REMOVE], [0.15, "#c89a00"], [0.5, BG], [0.85, "#5aaa00"], [1.0, STORE], ] fig = make_subplots( rows=2, cols=1, shared_xaxes=True, row_heights=[0.8, 0.2], vertical_spacing=0.05, ) fig.update_layout(template=dynamo_template) tick_vals = [-10, -5, 0, 5, 10] tick_text = ["\u221210", "\u22125", "0", "5", "10"] fig.add_trace( go.Heatmap( z=disp_pw, x0=0, dx=bw, y0=0, dy=1, colorscale=colorscale, zmid=0, zmin=-pw_zmax, zmax=pw_zmax, colorbar={ "title": {"text": ""}, "tickfont": {"family": SANS, "size": 12, "color": TXT2}, "thickness": 12, "len": 0.62, "y": 0.62, "x": 1.027, "tickvals": tick_vals, "ticktext": tick_text, }, customdata=combined, hovertemplate="W%{y} \u00b7 t=%{x:.0f}ms
net %{customdata:.0f}", xgap=0.5, ygap=0.5, ), row=1, col=1, ) tot_zmax = 100.0 tick_vals_tot = [-100, 0, 100] tick_text_tot = ["\u2212100", "0", "100"] fig.add_trace( go.Heatmap( z=[disp_tot], x0=0, dx=bw, colorscale=colorscale, zmid=0, zmin=-tot_zmax, zmax=tot_zmax, colorbar={ "title": {"text": ""}, "tickfont": {"family": SANS, "size": 12, "color": TXT2}, "thickness": 12, "len": 0.16, "y": 0.095, "x": 1.027, "tickvals": tick_vals_tot, "ticktext": tick_text_tot, }, customdata=[total], hovertemplate="t=%{x:.0f}ms
\u03a3 net %{customdata:.0f}", xgap=0.5, showscale=True, ), row=2, col=1, ) fig.update_layout( title={"text": ""}, margin={"l": 65, "r": 120, "t": 110, "b": 55}, plot_bgcolor=BG, paper_bgcolor=BG, ) fig.add_annotation( text="KV Events", xref="paper", yref="paper", x=1.015, y=0.62, xanchor="center", yanchor="middle", textangle=-90, font={"family": SANS, "size": 13, "color": TXT2}, showarrow=False, ) fig.add_annotation( text="KV Events", xref="paper", yref="paper", x=1.015, y=0.095, xanchor="center", yanchor="middle", textangle=-90, font={"family": SANS, "size": 13, "color": TXT2}, showarrow=False, ) fig.add_annotation( text="Store", xref="paper", yref="paper", x=1.06, y=0.935, xanchor="center", yanchor="bottom", font={"family": SANS, "size": 13, "color": TXT2}, showarrow=False, ) fig.add_annotation( text="Remove", xref="paper", yref="paper", x=1.06, y=0.305, xanchor="center", yanchor="top", font={"family": SANS, "size": 13, "color": TXT2}, showarrow=False, ) fig.add_annotation( text="Store", xref="paper", yref="paper", x=1.06, y=0.18, xanchor="center", yanchor="bottom", font={"family": SANS, "size": 13, "color": TXT2}, showarrow=False, ) fig.add_annotation( text="Remove", xref="paper", yref="paper", x=1.06, y=0.005, xanchor="center", yanchor="top", font={"family": SANS, "size": 13, "color": TXT2}, showarrow=False, ) fig.add_annotation( text="KV CACHE EVENT DENSITY", xref="paper", yref="paper", x=0.0, y=1.16, xanchor="left", yanchor="bottom", font={"family": SANS, "size": 18, "color": TXT}, showarrow=False, ) fig.add_annotation( text=subtitle, xref="paper", yref="paper", x=0.0, y=1.09, xanchor="left", yanchor="bottom", font={"family": SANS, "size": 12, "color": TXT2}, showarrow=False, ) fig.add_annotation( text=( f"\u25a0 Store (prefill) " f"\u25a0 Remove (eviction)" ), xref="paper", yref="paper", x=1.0131, y=1.09, xanchor="right", yanchor="bottom", font={"family": SANS, "size": 13, "color": TXT2}, showarrow=False, ) fig.update_xaxes(row=1, col=1, **_ax(range=[0, duration_ms])) fig.update_xaxes(row=2, col=1, **_ax(title="Time (ms)", range=[0, duration_ms])) fig.update_yaxes( row=1, col=1, **_ax(title="Worker", tickvals=list(range(0, 16, 5)) + [15], showgrid=False), ) fig.update_yaxes( row=2, col=1, **_ax( title="\u03a3 Workers", showgrid=False, tickvals=[], showticklabels=False, showline=True, linecolor=BORDER, linewidth=0.5, ), ) save(fig, "fig-1-kv-event-density", w=950, h=580) def main() -> None: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--real-data", type=Path, default=None, help="Path to kv_events JSON or .json.gz (Mooncake trace). Uses synthetic data if omitted.", ) args = parser.parse_args() print("Generating synthetic KV events (Gamma arrivals \u00d7 Zipf bursts)...") stores, removes = generate_kv_events() ts = sum(len(s) for s in stores) tr = sum(len(r) for r in removes) print(f" {ts:,} Store events, {tr:,} Remove events over 100ms\n") print("Rendering heatmap...") if args.real_data and args.real_data.exists(): print(f" (using real event data from {args.real_data})") make_heatmap(stores, removes, real_data_path=args.real_data) else: if args.real_data: print(f" WARNING: {args.real_data} not found, using synthetic data") make_heatmap(stores, removes) print(f"\nDone \u2192 {OUT}") if __name__ == "__main__": main()