speculative_decode.yaml 1.88 KB
Newer Older
raojy's avatar
raojy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Speculative decoding benchmark configuration.
# Sweeps reorder_batch_threshold over speculative-decode batch shapes for each
# listed MLA backend.

model:
  name: "deepseek-v3"
  # DeepSeek-V3's published config.json has 61 hidden layers (60 was
  # DeepSeek-V2's count), so use 61 to match the named model.
  num_layers: 61
  num_q_heads: 128
  # MLA: a single shared compressed latent KV head serves all query heads.
  num_kv_heads: 1
  # Latent head width = kv_lora_rank + qk_rope_head_dim (512 + 64 = 576).
  head_dim: 576
  kv_lora_rank: 512
  qk_nope_head_dim: 128
  qk_rope_head_dim: 64
  v_head_dim: 128

batch_specs:
  # --- Speculative verification only, 1k KV context ---
  - 'q2s1k'   # 2-token spec
  - 'q4s1k'   # 4-token spec
  - 'q8s1k'   # 8-token spec
  - 'q16s1k'  # 16-token spec

  # --- Speculative verification at longer KV contexts ---
  - 'q4s2k'  # 4-token spec, 2k KV
  - 'q4s4k'  # 4-token spec, 4k KV
  - 'q8s2k'  # 8-token spec, 2k KV
  - 'q8s4k'  # 8-token spec, 4k KV

  # --- Speculative requests mixed with regular (1-token) decode ---
  - '32q4s1k'          # 32 spec requests
  - '16q4s1k_16q1s1k'  # 16 spec + 16 regular
  - '8q8s2k_24q1s2k'   # 8 spec (8-token) + 24 regular

  # --- Speculative requests mixed with prefill and regular decode ---
  - '2q1k_16q4s1k_16q1s1k'  # 2 prefill + 16 spec + 16 decode
  - '4q2k_32q4s2k_32q1s2k'  # 4 prefill + 32 spec + 32 decode

  # --- Large batches under speculation ---
  - '64q4s1k'   # 64 spec requests
  - '32q8s2k'   # 32 spec (8-token)
  - '16q16s4k'  # 16 spec (16-token)

# Only backends that accept query lengths > 1 are benchmarked.
backends:
  - 'FLASH_ATTN_MLA'  # reorder_batch_threshold = 512
  - 'FLASHMLA'        # reorder_batch_threshold = 1 (tunable)
  # Disabled: FlashInfer-MLA also supports uniform spec-as-decode, but through
  # a different mechanism.
  # - FLASHINFER_MLA

# --- Benchmark run settings ---
device: 'cuda:0'
# Extra repeats per measurement for statistical significance.
repeats: 10
warmup_iters: 5
profile_memory: false

# reorder_batch_threshold values to try for every backend.
# NOTE(review): every batch spec in this file presumably has a query length of
# either <= 16 or >= 1k tokens, so thresholds 32..512 may all classify requests
# the same way -- confirm whether those extra points are needed.
parameter_sweep:
  param_name: 'reorder_batch_threshold'
  values:
    - 1
    - 2
    - 4
    - 8
    - 16
    - 32
    - 64
    - 128
    - 256
    - 512
  include_auto: false
  # Must stay quoted: a plain scalar starting with '{' parses as a flow mapping.
  label_format: '{backend}_threshold_{value}'