    [Kernel] FP8 support for MoE kernel / Mixtral (#4244) · eace8bf0
    Philipp Moritz authored
    This PR is the first step towards fixing https://github.com/vllm-project/vllm/pull/3208
    
    It implements dynamic per-tensor scaling (see https://github.com/vllm-project/vllm/pull/4118), so users need neither compute activation scales on a calibration dataset nor convert their model checkpoints. Specifying the `quantization="fp8"` argument is enough. You can try out the PR like this:
    
    ```python
    from vllm import LLM, SamplingParams
    
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
    
    llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=2, quantization="fp8")
    
    outputs = llm.generate(prompts, sampling_params)
    
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    ```
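As a rough illustration of what dynamic per-tensor scaling means, the following is a minimal pure-Python sketch, not the kernel code from this PR. It assumes the common formulation in which the scale is the tensor's absolute maximum divided by the largest finite FP8 E4M3 value (448), and it simulates E4M3 rounding in software. The helper names `round_to_e4m3`, `quantize_per_tensor_dynamic`, and `dequantize` are hypothetical and exist only for this example.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite value in the FP8 E4M3 format


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest value on a simulated FP8 E4M3 grid (saturating)."""
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    mag = min(abs(x), FP8_E4M3_MAX)  # saturate instead of overflowing
    # math.frexp gives mag = m * 2**e with 0.5 <= m < 1, so floor(log2(mag)) = e - 1.
    _, e = math.frexp(mag)
    exponent = max(e - 1, -6)  # values below 2**-6 fall on the subnormal grid
    step = 2.0 ** (exponent - 3)  # 3 mantissa bits -> 8 steps per power of two
    return sign * min(round(mag / step) * step, FP8_E4M3_MAX)


def quantize_per_tensor_dynamic(values):
    """Compute one scale for the whole tensor at runtime, then quantize.

    Assumption for this sketch: scale = max(|x|) / FP8_E4M3_MAX.
    """
    amax = max(abs(v) for v in values)
    scale = amax / FP8_E4M3_MAX if amax > 0 else 1.0
    quantized = [round_to_e4m3(v / scale) for v in values]
    return quantized, scale


def dequantize(quantized, scale):
    """Map simulated FP8 values back to the original range."""
    return [q * scale for q in quantized]


if __name__ == "__main__":
    activations = [0.02, -1.5, 3.0, 0.75]
    q, scale = quantize_per_tensor_dynamic(activations)
    print("scale:", scale)
    print("quantized:", q)
    print("restored:", dequantize(q, scale))
```

Because the scale is derived from the tensor itself at runtime, no calibration dataset or converted checkpoint is needed, which matches the motivation given above. A real implementation would do this on the GPU with a native FP8 dtype rather than a Python loop.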
    
    **Performance**: This PR focuses on keeping the code clean while still getting reasonable performance. A number of optimizations that significantly improve performance (similar to the numbers in https://github.com/vllm-project/vllm/pull/3954) will be submitted in a follow-up PR. With this PR, the results are as follows:
    
    <img width="725" alt="Screenshot 2024-04-21 at 1 31 50 PM" src="https://github.com/vllm-project/vllm/assets/113316/d8fe1118-07a0-4d4e-8530-37a77d465a03">
    
    
    **Accuracy**: The accuracy with this PR on MMLU on `mistralai/Mixtral-8x7B-v0.1` is as follows:
    
    ```
    |      Groups      |Version|Filter|n-shot|Metric|Value |   |Stderr|
    |------------------|-------|------|-----:|------|-----:|---|-----:|
    |mmlu              |N/A    |none  |     0|acc   |0.7018|±  |0.0036|
    | - humanities     |N/A    |none  |     5|acc   |0.6472|±  |0.0065|
    | - other          |N/A    |none  |     5|acc   |0.7673|±  |0.0072|
    | - social_sciences|N/A    |none  |     5|acc   |0.8099|±  |0.0070|
    | - stem           |N/A    |none  |     5|acc   |0.6131|±  |0.0083|
    ```
    This compares favorably with the fp16 results, which are:
    ```
    |      Groups      |Version|Filter|n-shot|Metric|Value |   |Stderr|
    |------------------|-------|------|-----:|------|-----:|---|-----:|
    |mmlu              |N/A    |none  |     0|acc   |0.7020|±  |0.1313|
    | - humanities     |N/A    |none  |     5|acc   |0.6425|±  |0.1349|
    | - other          |N/A    |none  |     5|acc   |0.7744|±  |0.1038|
    | - social_sciences|N/A    |none  |     5|acc   |0.8131|±  |0.0695|
    | - stem           |N/A    |none  |     5|acc   |0.6108|±  |0.1383|
    ```
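To make the comparison concrete, here is a small sketch that computes the per-group accuracy difference (fp8 minus fp16) from the two tables above. Using the fp8 standard error as the yardstick is an assumption made for this example, not a criterion stated in the PR; the fp16 standard errors in the table are much larger and are ignored here. The name `accuracy_deltas` is hypothetical.

```python
# (fp8 accuracy, fp8 stderr, fp16 accuracy), copied from the two tables above.
RESULTS = {
    "mmlu": (0.7018, 0.0036, 0.7020),
    "humanities": (0.6472, 0.0065, 0.6425),
    "other": (0.7673, 0.0072, 0.7744),
    "social_sciences": (0.8099, 0.0070, 0.8131),
    "stem": (0.6131, 0.0083, 0.6108),
}


def accuracy_deltas(results):
    """Return {group: (fp8 - fp16, True if |delta| <= fp8 stderr)}."""
    out = {}
    for group, (fp8_acc, fp8_stderr, fp16_acc) in results.items():
        delta = fp8_acc - fp16_acc
        out[group] = (delta, abs(delta) <= fp8_stderr)
    return out


if __name__ == "__main__":
    for group, (delta, within) in accuracy_deltas(RESULTS).items():
        print(f"{group:16s} delta={delta:+.4f} within_one_stderr={within}")
```

Under that assumption, every group's difference falls within one fp8 standard error, and the overall MMLU difference is 0.0002.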
    
    Happy hacking!