# Speculative Decoding

[Medusa](https://arxiv.org/abs/2401.10774) and [EAGLE](https://arxiv.org/pdf/2401.15077)
training and model export are supported (fast decoding is supported through TensorRT-LLM).
To run the examples, follow [README.md](README.md) to set up the containerized environment
and `NGC_CLI_API_KEY`, then run:
```sh
TP=8 bash medusa_sft.sh meta-llama/Llama-3.1-8B-Instruct
```
EAGLE training is similar. Just replace `medusa_sft.sh` with `eagle_sft.sh`
(requires `nvidia-modelopt>=0.20.0`).

Medusa head top-1 accuracy is reported per step (**NOTE:** this accuracy does not
translate to the acceptance rate described in the writeup. The top-1 accuracy of the 1st head
can, however, signal whether training has converged). When the example finishes, the
results are stored in the following locations.
```sh
/tmp/megatron_workspace/meta-llama/
├── Llama-3.1-8B-Instruct_medusa
│   ├── iter_0000001
│   └── ...
├── Llama-3.1-8B-Instruct_medusa_quant
│   ├── iter_0000001
│   └── ...
└── Llama-3.1-8B-Instruct_medusa_quant_trtllm_export
```
`Llama-3.1-8B-Instruct_medusa_quant_trtllm_export` is the TensorRT-LLM checkpoint. To
deploy, check the TensorRT-LLM section below. 

> **IMPORTANT:** The sample flow `medusa_sft.sh` does not include synthetic data generation.
> To achieve the best acceptance rate, check the whole recipe and its options in the following sections.

## Table of Contents

[[_TOC_]]

## Training and Export Workflow

In practice, speculative decoding should be combined with quantization (weights and kv-cache)
to achieve the highest tokens-per-second-per-user (or TPS) without changing the quality of
the model. We provide a quantization-aware training (QAT) recipe with self-distillation below.


### Model Conversion

To ensure no quality degradation, the base model is frozen and the draft model is attached as a
transformation. When `--export-num-medusa-heads` or `--export-num-eagle-layers` is provided,
the resulting model stored in `${MLM_MODEL_SAVE}` will have randomly initialized draft model weights.

```sh
python examples/post_training_opt/convert_gpt.py \
    --export-num-medusa-heads 4 \
    --load ${MLM_MODEL_CKPT} --save ${MLM_MODEL_SAVE} ${OTHER_MLM_ARGS}
```

> **NOTE:** `MLM_MODEL_SAVE=Llama-3.1-8B-Instruct_medusa` in the example.

### Synthetic Data Generation

Rather than learning the language and syntax, the draft model is trained to mimic the base
model output. As a result, self-synthesized data is crucial for the draft model accuracy
and acceptance rate (AR). In EAGLE training, hidden state and logits distillation are also
applied.

For simplicity and efficiency, we use `vllm serve --quantization modelopt` to host a quantized
endpoint and feed it multi-turn conversation data to synthesize the assistant output.
See ModelOpt's example (https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/speculative_decoding)
for more details. The final output is stored as jsonlines in an OpenAI chat completion format.
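
The re-synthesis step can be sketched roughly as follows. This is a minimal illustration and not the ModelOpt script: `generate` is a hypothetical callable standing in for a request to the hosted endpoint, and the `{"messages": [...]}` record layout is an assumption about what "OpenAI chat completion format" means here.

```python
import json
from typing import Callable, Dict, List

Message = Dict[str, str]


def resynthesize(conversation: List[Message],
                 generate: Callable[[List[Message]], str]) -> List[Message]:
    """Replace every assistant turn with the base model's own answer.

    `generate` is a hypothetical callable: it receives the chat history so
    far and returns the assistant reply produced by the hosted endpoint.
    """
    history: List[Message] = []
    for message in conversation:
        if message["role"] == "assistant":
            # Drop the original reply and ask the base model instead.
            history.append({"role": "assistant", "content": generate(history)})
        else:
            history.append(dict(message))
    return history


def to_jsonl_line(messages: List[Message]) -> str:
    # Assumed record layout: one JSON object per line with a "messages" list.
    return json.dumps({"messages": messages}, ensure_ascii=False)
```

Each later turn is generated on top of the model's own earlier replies, so the stored conversation reflects what the base model would actually produce at inference time.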


### Quantization-Aware Training (QAT)

For quantization-aware training (QAT), the process is `bf16 training`, `fake quantization`, `qat`.
Since the base model weights are frozen, the initial training is mainly to get a more accurate
range of the draft model activations and weights. We then store a new checkpoint in which the model
has additional quantization scalars for both the base and draft models, and launch
finetuning again to continue training with fake quantization until convergence.

```sh
python examples/post_training_opt/finetune_gpt.py \
    --export-num-medusa-heads 4 \
    --load ${MLM_MODEL_SAVE} --save ${MLM_MODEL_SAVE} ${OTHER_MLM_ARGS}
python examples/post_training_opt/text_generation_ptq.py \
    --export-quant-cfg fp8 \
    --decoder llama \
    --export-num-medusa-heads 4 \
    --load ${MLM_MODEL_SAVE} --save ${MLM_QUANT_SAVE} ${OTHER_MLM_ARGS}
python examples/post_training_opt/finetune_gpt.py \
    --export-num-medusa-heads 4 \
    --load ${MLM_QUANT_SAVE} --save ${MLM_QUANT_SAVE} ${OTHER_MLM_ARGS}
```

> **NOTE:** `MLM_QUANT_SAVE=Llama-3.1-8B-Instruct_medusa_quant` in the example.
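
To make the role of the calibrated range more concrete, here is a toy quantize-then-dequantize round trip. It is only a sketch: it uses a signed integer grid as a stand-in for the `fp8` format selected above, and `fake_quantize` and `calibrate_amax` are illustrative names, not ModelOpt APIs.

```python
from typing import List


def calibrate_amax(values: List[float]) -> float:
    # Simplest possible calibration: the largest magnitude observed.
    return max(abs(v) for v in values)


def fake_quantize(values: List[float], amax: float, num_bits: int = 8) -> List[float]:
    """Quantize to a signed integer grid, then dequantize back to float.

    `amax` is the calibrated range; values beyond it are clipped.
    Illustrative stand-in only: the flow above uses an fp8 config.
    """
    if amax <= 0:
        raise ValueError("amax must be positive")
    qmax = 2 ** (num_bits - 1) - 1  # 127 for 8 bits
    scale = amax / qmax             # step size of the grid
    out = []
    for v in values:
        q = max(-qmax, min(qmax, round(v / scale)))  # round, then clip
        out.append(q * scale)
    return out
```

A range that is too small clips large values, while one that is too large wastes resolution on values that never occur, which is why the first training pass matters before the scalars are stored.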

### Export TensorRT-LLM Checkpoint

To finally export a TensorRT-LLM checkpoint, we leverage the same script by providing
`${TRTLLM_CKPT}` and the inference `${TP}`.

```sh
python examples/post_training_opt/text_generation_ptq.py \
    --export-dir ${TRTLLM_CKPT} \
    --inference-tensor-parallel ${TP} \
    --export-quant-cfg None \
    --decoder llama \
    --export-num-medusa-heads 4 \
    --load ${MLM_QUANT_SAVE} ${OTHER_MLM_ARGS}
```

> **NOTE:** `TRTLLM_CKPT=Llama-3.1-8B-Instruct_medusa_quant_trtllm_export` in the example.

**TensorRT-LLM deployment:** To build (`trtllm-build`) and run the TensorRT-LLM engine, follow the steps at
https://github.com/NVIDIA/TensorRT-Model-Optimizer#installation--docker to prepare the container.

For `tensorrt-llm>0.12`, the builder can detect that this is a Medusa checkpoint directly:
```sh
trtllm-build --checkpoint_dir Llama-3.1-8B-Instruct_medusa_quant_trtllm_export --output_dir /tmp/trtllm_engine ${other args}
```

The `run.py` (https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/run.py) and `gptManagerBenchmark` (https://github.com/NVIDIA/TensorRT-LLM/tree/main/benchmarks/cpp)
both support Medusa decoding through the `--medusa_choices` argument. This argument describes the sparse attention tree structure used in the Medusa writeup. For example,
the following option is a tree with 63 nodes, which represent 63 draft tokens proposed by the 4 Medusa heads.
```sh
--medusa_choices="[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]"
```
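
A small helper can sanity-check such a string before it is passed on. The sketch below is not part of TensorRT-LLM. It assumes each inner list is a path from the root, where the path length is the head index and each entry is the rank of the candidate picked from that head. The parent check is likewise an assumption about what makes a tree well-formed.

```python
import json
from typing import Dict, List


def summarize_choices(medusa_choices: str) -> Dict[str, int]:
    """Parse a --medusa_choices string and report its size and depth."""
    paths: List[List[int]] = json.loads(medusa_choices)
    if not paths:
        raise ValueError("medusa_choices is empty")
    seen = {tuple(p) for p in paths}
    if len(seen) != len(paths):
        raise ValueError("duplicate path in medusa_choices")
    for path in seen:
        # Assumption: a node is only usable if its parent is also listed.
        if len(path) > 1 and path[:-1] not in seen:
            raise ValueError(f"missing parent for {list(path)}")
    return {
        "draft_tokens": len(paths),                   # one tree node per path
        "heads_used": max(len(p) for p in paths),     # deepest path
        "max_top_k": max(max(p) for p in paths) + 1,  # highest candidate rank + 1
    }
```

For the list above, this should report 63 draft tokens and 4 heads, matching the description.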

> **ADVANCED USAGE:** When training, we typically train `4` heads if memory is sufficient, and by default the max draft length is `63`.
> Optionally, users can change these values to something smaller in the TensorRT-LLM checkpoint's `config.json` before calling `trtllm-build`.
> For example, it is possible to use only 2 heads with a maximum of 7 draft tokens if this is a sweet spot. You must also change
> `--medusa_choices` so that it does not access draft tokens from the 3rd and 4th heads, and shorten the list to
> length 7.
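
Trimming the choices list by hand is error-prone, so a helper along these lines may be useful. It is a sketch under the same reading of the format (path length equals head index), `trim_choices` is a hypothetical name, and it additionally assumes that a node should only be kept when its parent is kept. The matching values in `config.json` still have to be edited separately.

```python
from typing import List


def trim_choices(paths: List[List[int]], num_heads: int,
                 max_draft_len: int) -> List[List[int]]:
    """Keep at most `max_draft_len` paths that use only the first `num_heads` heads."""
    kept: List[List[int]] = []
    kept_set = set()
    for path in paths:
        if len(kept) == max_draft_len:
            break
        if len(path) > num_heads:
            continue  # would read a draft token from a removed head
        parent = tuple(path[:-1])
        if parent and parent not in kept_set:
            continue  # assumption: a node needs its parent in the tree
        kept.append(list(path))
        kept_set.add(tuple(path))
    return kept
```

The original order is preserved, so when the list is sorted from most to least useful, the most useful shallow paths survive the cut.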