## Attention Prefill
--routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag "Llama-3.1-70B"

# Ragged prefill for DeepSeep-R1
--routine BatchPrefillWithRaggedKVCacheWrapper --backends fa2 fa3 cutlass cudnn --batch_size 16 --s_qo 1024 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_qk 192 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag "DeepSeek-R1"

## Attention Decode
--routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag "Llama-3.1-70B"

## Attention MLA
# DeepSeek-R1
--routine BatchMLAPagedAttentionWrapper --backends trtllm-gen-native fa2 fa3 --page_size 32 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_ckv 512 --head_dim_kpe 64 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --generate_repro_command --case_tag "DeepSeek-R1"

## FP8 bmm
--routine bmm_fp8 --batch_size 64 --m 4 --n 1024 --k 7168 --input_dtype fp8_e4m3 --mat2_dtype fp8_e4m3 --out_dtype bfloat16 --backends cudnn cublas cutlass --refcheck -vv --generate_repro_command

## FP8 GEMM with groupwise scaling
--routine gemm_fp8_nt_groupwise --m 16 --n 1024 --k 7168 --mma_sm 1 --scale_major_mode MN --backends cutlass --refcheck -vv --generate_repro_command

## FP8 group GEMM with groupwise scaling
--routine group_gemm_fp8_nt_groupwise --m 16 --n 1024 --k 7168 --mma_sm 1 --group_size 2 --scale_major_mode MN --refcheck -vv --generate_repro_command

## FP4 GEMM
# non-autotuned
--routine mm_fp4 --m 512 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --refcheck -vv --generate_repro_command

# autotuned
--routine mm_fp4 --m 512 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --autotune --refcheck -vv --generate_repro_command

## MoE
--routine trtllm_fp4_block_scale_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 1024 --num_experts 256 --top_k 8 --n_group 8 --topk_group 4 --routed_scaling_factor 2.5 --use_routing_bias --routing_method deepseek_v3 --use_shuffled_weight -vv --generate_repro_command --case_tag "trtllm_moe_sample"
--routine trtllm_fp4_block_scale_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 1024 --num_experts 128 --top_k 8 --routing_method renormalize_naive --use_shuffled_weight -vv --generate_repro_command --case_tag "trtllm_moe_sample"
--routine trtllm_fp8_block_scale_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 1024 --num_experts 256 --top_k 8 --n_group 8 --topk_group 4 --routed_scaling_factor 2.5 --use_routing_bias --routing_method deepseek_v3 --use_shuffled_weight -vv --generate_repro_command --case_tag "trtllm_moe_sample"
--routine trtllm_fp8_per_tensor_scale_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 1024 --num_experts 128 --top_k 1 --routed_scaling_factor 2.5 --use_routing_bias --routing_method llama4 --use_routing_scales_on_input -vv --generate_repro_command --case_tag "trtllm_moe_sample"
--routine trtllm_fp8_block_scale_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 1024 --num_experts 128 --top_k 1 --routing_method renormalize --use_shuffled_weight -vv --generate_repro_command --case_tag "trtllm_moe_sample"

# CUTLASS MoE API
--routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 2 --top_k 2 --cutlass_variant base --input_dtype float16 -vv --generate_repro_command --case_tag "cutlass_moe_base"
--routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 2 --top_k 2 --cutlass_variant fp8 --input_dtype float16 -vv --generate_repro_command --case_tag "cutlass_moe_fp8_scale"
--routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 2 --top_k 2 --cutlass_variant nvfp4 --input_dtype float16 -vv --generate_repro_command --case_tag "cutlass_moe_nvfp4_weights"
--routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 2 --top_k 2 --cutlass_variant nvfp4 --quantized_input --input_dtype float16 -vv --generate_repro_command --case_tag "cutlass_moe_nvfp4_weights_quantized"
--routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 8 --top_k 2 --cutlass_variant base --input_dtype float16 --tp_size 2 --tp_rank 0 --ep_size 4 --ep_rank 0 -vv --generate_repro_command --case_tag "cutlass_moe_nvfp4_ep_tp"

## MoE Communication (requires mpirun, e.g.: mpirun -np 8 python benchmarks/flashinfer_benchmark.py ...)
# Basic A2A dispatch+combine without quantization
#--routine moe_a2a_dispatch_combine --num_tokens 1024 --hidden_size 7168 --num_experts 256 --top_k 8 -vv --generate_repro_command --case_tag "moe_a2a_basic"
# With FP8 per-tensor quantization
#--routine moe_a2a_dispatch_combine --num_tokens 1024 --hidden_size 7168 --num_experts 256 --top_k 8 --quant_dtype fp8 -vv --generate_repro_command --case_tag "moe_a2a_fp8"
# With NVFP4 block-scale quantization
#--routine moe_a2a_dispatch_combine --num_tokens 1024 --hidden_size 7168 --num_experts 256 --top_k 8 --quant_dtype nvfp4 -vv --generate_repro_command --case_tag "moe_a2a_nvfp4"
# With FP8 block-scale quantization
#--routine moe_a2a_dispatch_combine --num_tokens 1024 --hidden_size 7168 --num_experts 256 --top_k 8 --quant_dtype fp8_block_scale -vv --generate_repro_command --case_tag "moe_a2a_fp8_block_scale"
# With real MoE kernel (NVFP4)
#--routine moe_a2a_dispatch_combine --num_tokens 1024 --hidden_size 7168 --num_experts 256 --top_k 8 --quant_dtype nvfp4 --real_math --intermediate_size 18432 --per_phase_timing -vv --generate_repro_command --case_tag "moe_a2a_nvfp4_real_math"
# With real MoE kernel (FP8 block-scale)
#--routine moe_a2a_dispatch_combine --num_tokens 1024 --hidden_size 7168 --num_experts 256 --top_k 8 --quant_dtype fp8_block_scale --real_math --intermediate_size 18432 --per_phase_timing -vv --generate_repro_command --case_tag "moe_a2a_fp8_bs_real_math"

## RMSNorm
# Basic RMSNorm with 2D input shape
--routine rmsnorm --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "rmsnorm_llama_hidden"
--routine rmsnorm --batch_size 64 --hidden_size 8192 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "rmsnorm_large_hidden"

# RMSNorm with 3D input shape (batch, num_heads, head_dim)
--routine rmsnorm --batch_size 32 --num_heads 32 --hidden_size 128 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "rmsnorm_3d_gqa"
--routine rmsnorm --batch_size 16 --num_heads 64 --hidden_size 128 --input_dtype float16 --refcheck -vv --generate_repro_command --case_tag "rmsnorm_3d_mha"

# RMSNorm with PDL enabled
--routine rmsnorm --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --enable_pdl --refcheck -vv --generate_repro_command --case_tag "rmsnorm_pdl"

## RMSNorm with Quantized Output
# RMSNorm with FP8 e4m3 output
--routine rmsnorm_quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 1.0 --refcheck -vv --generate_repro_command --case_tag "rmsnorm_quant_fp8_e4m3"
--routine rmsnorm_quant --batch_size 64 --hidden_size 8192 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 1.0 --refcheck -vv --generate_repro_command --case_tag "rmsnorm_quant_large"

# RMSNorm with FP8 e5m2 output
--routine rmsnorm_quant --batch_size 32 --hidden_size 4096 --input_dtype float16 --out_dtype fp8_e5m2 --scale 1.0 --refcheck -vv --generate_repro_command --case_tag "rmsnorm_quant_fp8_e5m2"

## Fused Add + RMSNorm with Quantized Output
# Fused add + RMSNorm with FP8 e4m3 output
--routine fused_add_rmsnorm_quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 1.0 --refcheck -vv --generate_repro_command --case_tag "fused_add_rmsnorm_quant_fp8_e4m3"
--routine fused_add_rmsnorm_quant --batch_size 64 --hidden_size 8192 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 1.0 --refcheck -vv --generate_repro_command --case_tag "fused_add_rmsnorm_quant_large"

# Fused add + RMSNorm with PDL enabled
--routine fused_add_rmsnorm_quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 1.0 --enable_pdl --refcheck -vv --generate_repro_command --case_tag "fused_add_rmsnorm_quant_pdl"

## RMSNorm with FP4 Quantization (Blackwell SM10.0+ only, cute-dsl backend)
# NVFP4 format (block_size=16, e4m3 scale factors) - nvfp4 is default out_dtype
--routine rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "rmsnorm_fp4quant_nvfp4"
--routine rmsnorm_fp4quant --batch_size 64 --hidden_size 8192 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "rmsnorm_fp4quant_nvfp4_large"

# NVFP4 with global scale
--routine rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --use_global_scale -vv --generate_repro_command --case_tag "rmsnorm_fp4quant_nvfp4_global"

# NVFP4 with swizzled scale factor layout for tensor core GEMM
--routine rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --is_sf_swizzled_layout -vv --generate_repro_command --case_tag "rmsnorm_fp4quant_nvfp4_swizzled"

# MXFP4 format (block_size=32, ue8m0 scale factors)
--routine rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype mxfp4 -vv --generate_repro_command --case_tag "rmsnorm_fp4quant_mxfp4"

# 3D input shape (batch, num_heads, head_dim)
--routine rmsnorm_fp4quant --batch_size 32 --num_heads 32 --hidden_size 128 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "rmsnorm_fp4quant_3d"

## Fused Add + RMSNorm with FP4 Quantization (Blackwell SM10.0+ only, cute-dsl backend)
# NVFP4 format (block_size=16, e4m3 scale factors) - nvfp4 is default out_dtype
--routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_nvfp4"
--routine add_rmsnorm_fp4quant --batch_size 64 --hidden_size 8192 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_nvfp4_large"

# NVFP4 with global scale
--routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --use_global_scale -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_nvfp4_global"

# NVFP4 with swizzled scale factor layout for tensor core GEMM
--routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --is_sf_swizzled_layout -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_nvfp4_swizzled"

# MXFP4 format (block_size=32, ue8m0 scale factors)
--routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype mxfp4 -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_mxfp4"

# 3D input shape (batch, num_heads, head_dim)
--routine add_rmsnorm_fp4quant --batch_size 32 --num_heads 32 --hidden_size 128 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_3d"

# Output both swizzled and unswizzled scale factors (for dual-use scenarios)
--routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --output_both_sf_layouts -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_both_sf"
--routine add_rmsnorm_fp4quant --batch_size 64 --hidden_size 8192 --input_dtype bfloat16 --output_both_sf_layouts -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_both_sf_large"

# Both SF layouts with global scale
--routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --use_global_scale --output_both_sf_layouts -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_both_sf_global"

# Both SF layouts with MXFP4 format
--routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype mxfp4 --output_both_sf_layouts -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_mxfp4_both_sf"

## Quantization (Blackwell SM10.0+ only)
# MxFP8 Quantization - basic
--routine mxfp8_quantize --m 1024 --k 4096 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "mxfp8_quantize_basic"
--routine mxfp8_quantize --m 2048 --k 8192 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "mxfp8_quantize_large"

# MxFP8 Quantization - float16 input
--routine mxfp8_quantize --m 1024 --k 4096 --input_dtype float16 -vv --generate_repro_command --case_tag "mxfp8_quantize_fp16"

# MxFP8 Quantization - with swizzled layout disabled
--routine mxfp8_quantize --m 1024 --k 4096 --input_dtype bfloat16 --no_sf_swizzled_layout -vv --generate_repro_command --case_tag "mxfp8_quantize_no_swizzle"

# MxFP8 Quantization - with PDL enabled
--routine mxfp8_quantize --m 2048 --k 8192 --input_dtype bfloat16 --enable_pdl -vv --generate_repro_command --case_tag "mxfp8_quantize_pdl"

# MxFP8 Quantization - with refcheck (round-trip verification)
--routine mxfp8_quantize --m 1024 --k 4096 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "mxfp8_quantize_refcheck"

# MxFP4 Quantization (Blackwell SM10.0+ only)
--routine mxfp4_quantize --m 1024 --k 4096 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "mxfp4_quantize_basic"
--routine mxfp4_quantize --m 2048 --k 8192 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "mxfp4_quantize_large"
--routine mxfp4_quantize --m 1024 --k 4096 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "mxfp4_quantize_refcheck"

# NVFP4 Quantization (Blackwell SM10.0+ only)
# With 128x4 layout (default, for large tileN GEMMs)
--routine nvfp4_quantize --m 1024 --k 4096 --input_dtype bfloat16 --global_scale 1.0 --sf_layout 128x4 -vv --generate_repro_command --case_tag "nvfp4_quantize_128x4"
--routine nvfp4_quantize --m 2048 --k 8192 --input_dtype bfloat16 --global_scale 1.0 --sf_layout 128x4 -vv --generate_repro_command --case_tag "nvfp4_quantize_128x4_large"

# With 8x4 layout (for small tileN GEMMs)
--routine nvfp4_quantize --m 1024 --k 4096 --input_dtype bfloat16 --global_scale 1.0 --sf_layout 8x4 -vv --generate_repro_command --case_tag "nvfp4_quantize_8x4"

# With shuffle for TRTLLM backend
--routine nvfp4_quantize --m 1024 --k 4096 --input_dtype bfloat16 --global_scale 1.0 --do_shuffle -vv --generate_repro_command --case_tag "nvfp4_quantize_shuffle"

# With PDL enabled
--routine nvfp4_quantize --m 1024 --k 4096 --input_dtype bfloat16 --global_scale 1.0 --enable_pdl -vv --generate_repro_command --case_tag "nvfp4_quantize_pdl"

# NVFP4 Batched Quantization (Blackwell SM10.0+ only)
--routine nvfp4_batched_quantize --batch_size 4 --m 1024 --k 4096 --input_dtype bfloat16 --global_scale 1.0 -vv --generate_repro_command --case_tag "nvfp4_batched_basic"
--routine nvfp4_batched_quantize --batch_size 8 --m 2048 --k 8192 --input_dtype bfloat16 --global_scale 1.0 -vv --generate_repro_command --case_tag "nvfp4_batched_large"
--routine nvfp4_batched_quantize --batch_size 4 --m 1024 --k 4096 --input_dtype float16 --global_scale 1.0 -vv --generate_repro_command --case_tag "nvfp4_batched_fp16"

## Sampling
# Basic softmax with temperature
--routine softmax --batch_size 32 --vocab_size 32000 --temperature 1.0 --input_dtype float32 -vv --generate_repro_command --case_tag "softmax_llama"
--routine softmax --batch_size 64 --vocab_size 128256 --temperature 0.8 --input_dtype float32 -vv --generate_repro_command --case_tag "softmax_llama3_temp"

# Sampling from probabilities
--routine sampling_from_probs --batch_size 32 --vocab_size 32000 -vv --generate_repro_command --case_tag "sampling_from_probs_llama"
--routine sampling_from_probs --batch_size 64 --vocab_size 128256 -vv --generate_repro_command --case_tag "sampling_from_probs_llama3"

# Sampling from logits (fused softmax + sampling)
--routine sampling_from_logits --batch_size 32 --vocab_size 32000 --input_dtype float32 -vv --generate_repro_command --case_tag "sampling_from_logits_llama"
--routine sampling_from_logits --batch_size 64 --vocab_size 128256 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "sampling_from_logits_llama3"

# Top-K sampling
--routine top_k_sampling_from_probs --batch_size 32 --vocab_size 32000 --top_k 50 -vv --generate_repro_command --case_tag "top_k_sampling_k50"
--routine top_k_sampling_from_probs --batch_size 32 --vocab_size 128256 --top_k 100 -vv --generate_repro_command --case_tag "top_k_sampling_k100"

# Top-P (nucleus) sampling
--routine top_p_sampling_from_probs --batch_size 32 --vocab_size 32000 --top_p 0.9 -vv --generate_repro_command --case_tag "top_p_sampling_p09"
--routine top_p_sampling_from_probs --batch_size 32 --vocab_size 128256 --top_p 0.95 -vv --generate_repro_command --case_tag "top_p_sampling_p095"

# Combined Top-K + Top-P sampling
--routine top_k_top_p_sampling_from_probs --batch_size 32 --vocab_size 32000 --top_k 50 --top_p 0.9 --filter_apply_order top_k_first -vv --generate_repro_command --case_tag "top_k_top_p_probs"
--routine top_k_top_p_sampling_from_logits --batch_size 32 --vocab_size 32000 --top_k 50 --top_p 0.9 --filter_apply_order top_k_first --input_dtype float32 -vv --generate_repro_command --case_tag "top_k_top_p_logits"

# Min-P sampling
--routine min_p_sampling_from_probs --batch_size 32 --vocab_size 32000 --min_p 0.1 -vv --generate_repro_command --case_tag "min_p_sampling_p01"
--routine min_p_sampling_from_probs --batch_size 32 --vocab_size 128256 --min_p 0.05 -vv --generate_repro_command --case_tag "min_p_sampling_p005"

# Top-K renormalization
--routine top_k_renorm_probs --batch_size 32 --vocab_size 32000 --top_k 50 --input_dtype float32 -vv --generate_repro_command --case_tag "top_k_renorm"

# Top-P renormalization
--routine top_p_renorm_probs --batch_size 32 --vocab_size 32000 --top_p 0.9 -vv --generate_repro_command --case_tag "top_p_renorm"

# Top-K mask logits
--routine top_k_mask_logits --batch_size 32 --vocab_size 32000 --top_k 50 --input_dtype float32 -vv --generate_repro_command --case_tag "top_k_mask"

# Chain speculative sampling (for speculative decoding)
--routine chain_speculative_sampling --batch_size 16 --vocab_size 32000 --num_speculate_tokens 5 -vv --generate_repro_command --case_tag "chain_spec_sampling_5"
--routine chain_speculative_sampling --batch_size 32 --vocab_size 128256 --num_speculate_tokens 8 -vv --generate_repro_command --case_tag "chain_spec_sampling_8"

# Top-K selection (radix-based)
--routine top_k --batch_size 32 --vocab_size 32000 --top_k 50 --input_dtype float32 -vv --generate_repro_command --case_tag "top_k_radix"
--routine top_k --batch_size 64 --vocab_size 128256 --top_k 100 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "top_k_radix_large"

# Top-K with page table transform
--routine top_k_page_table_transform --batch_size 16 --num_rows 16 --max_len 4096 --top_k 64 --input_dtype float32 -vv --generate_repro_command --case_tag "top_k_page_table"

# Top-K with ragged transform
--routine top_k_ragged_transform --batch_size 16 --num_rows 16 --max_len 4096 --top_k 64 --input_dtype float32 -vv --generate_repro_command --case_tag "top_k_ragged"

## RoPE (Rotary Positional Embeddings)
# Basic RoPE with indptr/offsets
--routine apply_rope --batch_size 16 --seq_len 1024 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --input_dtype float16 -vv --generate_repro_command --case_tag "apply_rope_llama"
--routine apply_rope --batch_size 32 --seq_len 2048 --num_qo_heads 64 --num_kv_heads 8 --head_dim 128 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "apply_rope_llama70b"

# RoPE with position IDs
--routine apply_rope_pos_ids --batch_size 16 --seq_len 1024 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --input_dtype float16 -vv --generate_repro_command --case_tag "apply_rope_pos_ids"
--routine apply_rope_pos_ids --batch_size 32 --seq_len 2048 --num_qo_heads 64 --num_kv_heads 8 --head_dim 128 --input_dtype bfloat16 --interleave -vv --generate_repro_command --case_tag "apply_rope_pos_ids_interleave"

# Llama 3.1 style RoPE
--routine apply_llama31_rope --batch_size 16 --seq_len 1024 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --rope_theta 500000.0 --rope_scale 1.0 --low_freq_factor 1.0 --high_freq_factor 4.0 --old_context_len 8192 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "apply_llama31_rope"
--routine apply_llama31_rope_pos_ids --batch_size 16 --seq_len 1024 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --rope_theta 500000.0 --rope_scale 1.0 --low_freq_factor 1.0 --high_freq_factor 4.0 --old_context_len 8192 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "apply_llama31_rope_pos_ids"

# RoPE with cos/sin cache
--routine apply_rope_with_cos_sin_cache --batch_size 16 --seq_len 1024 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --input_dtype float16 -vv --generate_repro_command --case_tag "apply_rope_cos_sin_cache"
--routine apply_rope_with_cos_sin_cache --batch_size 32 --seq_len 2048 --num_qo_heads 64 --num_kv_heads 8 --head_dim 128 --input_dtype bfloat16 --interleave -vv --generate_repro_command --case_tag "apply_rope_cos_sin_cache_interleave"

# MLA RoPE with FP8 quantization (SM8.9+ required)
--routine mla_rope_quantize_fp8 --batch_size 16 --seq_len 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim 192 --no_rope_dim 64 --input_dtype bfloat16 --quant_dtype fp8_e4m3 -vv --generate_repro_command --case_tag "mla_rope_fp8_deepseek"

# RoPE with FP8 quantization (SM8.9+ required)
--routine rope_quantize_fp8 --batch_size 16 --seq_len 1024 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --input_dtype bfloat16 --quant_dtype fp8_e4m3 -vv --generate_repro_command --case_tag "rope_fp8_llama"
--routine rope_quantize_fp8 --batch_size 32 --seq_len 2048 --num_qo_heads 64 --num_kv_heads 8 --head_dim 128 --input_dtype bfloat16 --quant_dtype fp8_e4m3 -vv --generate_repro_command --case_tag "rope_fp8_llama70b"

# RoPE with FP8 quantization and paged KV cache append (SM8.9+ required)
--routine rope_quantize_fp8_append_paged_kv_cache --batch_size 16 --seq_len 64 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --page_size 16 --kv_layout NHD --input_dtype bfloat16 --quant_dtype fp8_e4m3 -vv --generate_repro_command --case_tag "rope_fp8_paged_kv"
--routine rope_quantize_fp8_append_paged_kv_cache --batch_size 32 --seq_len 64 --num_qo_heads 64 --num_kv_heads 8 --head_dim 128 --page_size 16 --kv_layout HND --input_dtype bfloat16 --quant_dtype fp8_e4m3 -vv --generate_repro_command --case_tag "rope_fp8_paged_kv_hnd"

## Mamba (Selective State Space Models)
# Single-token prediction (STP), Mamba-2 style (nheads/ngroups=8)
--routine selective_state_update --batch_size 64 --nheads 64 --dim 128 --dstate 128 --ngroups 8 --input_dtype bfloat16 --backends flashinfer triton --refcheck -vv --generate_repro_command --case_tag "mamba2_stp_bf16"

# STP with float32 state
--routine selective_state_update --batch_size 64 --nheads 64 --dim 128 --dstate 128 --ngroups 8 --state_dtype float32 --backends flashinfer triton --refcheck -vv --generate_repro_command --case_tag "mamba2_stp_state_fp32"

# STP with z gating enabled
--routine selective_state_update --batch_size 64 --nheads 64 --dim 128 --dstate 128 --ngroups 8 --has_z --backends flashinfer triton --refcheck -vv --generate_repro_command --case_tag "mamba2_stp_with_z"

# STP with dt_softplus enabled
--routine selective_state_update --batch_size 64 --nheads 64 --dim 128 --dstate 128 --ngroups 8 --dt_softplus --backends flashinfer triton --refcheck -vv --generate_repro_command --case_tag "mamba2_stp_softplus"

# STP with z gating and dt_softplus
--routine selective_state_update --batch_size 64 --nheads 64 --dim 128 --dstate 128 --ngroups 8 --has_z --dt_softplus --backends flashinfer triton --refcheck -vv --generate_repro_command --case_tag "mamba2_stp_z_softplus"

# STP with nheads/ngroups=1 (all heads in one group)
--routine selective_state_update --batch_size 32 --nheads 16 --dim 128 --dstate 128 --ngroups 16 --backends flashinfer triton --refcheck -vv --generate_repro_command --case_tag "mamba2_stp_ratio1"

# STP with nheads/ngroups=16
--routine selective_state_update --batch_size 32 --nheads 128 --dim 128 --dstate 128 --ngroups 8 --backends flashinfer triton --refcheck -vv --generate_repro_command --case_tag "mamba2_stp_ratio16"

# Multi-token prediction (MTP), cache_steps=1
--routine selective_state_update --batch_size 64 --nheads 64 --dim 128 --dstate 128 --ngroups 8 --cache_steps 1 --backends flashinfer triton --refcheck -vv --generate_repro_command --case_tag "mamba2_mtp1"

# MTP with cache_steps=2
--routine selective_state_update --batch_size 64 --nheads 64 --dim 128 --dstate 128 --ngroups 8 --cache_steps 2 --backends flashinfer triton --refcheck -vv --generate_repro_command --case_tag "mamba2_mtp2"

# Large batch size STP
--routine selective_state_update --batch_size 256 --nheads 64 --dim 128 --dstate 128 --ngroups 8 --backends flashinfer triton -vv --generate_repro_command --case_tag "mamba2_stp_large_batch"

# FlashInfer-only (no refcheck, perf focus)
--routine selective_state_update --batch_size 128 --nheads 64 --dim 128 --dstate 128 --ngroups 8 --backends flashinfer -vv --generate_repro_command --case_tag "mamba2_stp_perf"
