    Add support for FP8 KV cache scales (#2628) · eab07f74
    Daniël de Kok authored
    * Add support for FP8 KV cache scales
    
    Since FP8 only has limited dynamic range, we can scale keys/values
    before storing them into the cache (and unscale them in attention). To
    avoid rescaling the cache as the absmax values change, good scales are
usually determined per layer using calibration data and stored
    in the checkpoint.
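
    A minimal sketch (not the repository's implementation) of what storing and reading K/V with a fixed, calibrated per-layer scale looks like; the helper names are illustrative:

    ```python
    import torch

    FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn


    def quantize_to_cache(x: torch.Tensor, scale: float) -> torch.Tensor:
        # Scale into the representable range, clamp, then narrow to FP8.
        x_scaled = (x / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
        return x_scaled.to(torch.float8_e4m3fn)


    def dequantize_from_cache(
        x_fp8: torch.Tensor, scale: float, dtype: torch.dtype = torch.float16
    ) -> torch.Tensor:
        # "Unscale" when reading keys/values back for attention.
        return x_fp8.to(dtype) * scale
    ```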
    
This change adds support for using key-value scales and loading them
    from checkpoints in the two most common formats:
    
    - Separate per-layer `k_scale` and `v_scale` scalars.
    - Per-layer `kv_scale` scalar (older format).
    
Currently, scales are only used with a `float8_e4m3fn` cache.
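
    A sketch of how both checkpoint formats could be resolved; the key names follow the two formats above, but the accessor and prefix layout are assumptions, not the exact loader code:

    ```python
    import torch


    def load_kv_scales(
        state_dict: dict[str, torch.Tensor], prefix: str
    ) -> tuple[float, float]:
        if f"{prefix}.k_scale" in state_dict and f"{prefix}.v_scale" in state_dict:
            # Newer format: separate per-layer scalars for keys and values.
            k_scale = float(state_dict[f"{prefix}.k_scale"])
            v_scale = float(state_dict[f"{prefix}.v_scale"])
        elif f"{prefix}.kv_scale" in state_dict:
            # Older format: a single per-layer scalar shared by keys and values.
            k_scale = v_scale = float(state_dict[f"{prefix}.kv_scale"])
        else:
            # No calibrated scales in the checkpoint: fall back to identity scaling.
            k_scale = v_scale = 1.0
        return k_scale, v_scale
    ```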
    
    Besides adding support for key/value scales, the `fp8_quantize` function
    is also extended to support quantization with a kernel vendored from
    vLLM. This is slightly faster than the PyTorch implementation, and it also
    computes the scale in FP32, potentially improving accuracy.
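
    A minimal sketch of absmax-based FP8 quantization with the scale computed in FP32, in the spirit of the vendored kernel; this is plain PyTorch for illustration, not the vendored CUDA kernel itself:

    ```python
    import torch

    FP8_E4M3_MAX = 448.0


    def fp8_quantize_dynamic(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Compute the scale in FP32 to avoid precision loss in the reduction.
        absmax = x.abs().max().to(torch.float32)
        scale = absmax.clamp(min=1e-12) / FP8_E4M3_MAX
        # Quantize in FP32, clamp to the representable range, narrow to FP8.
        x_fp8 = (
            (x.to(torch.float32) / scale)
            .clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
            .to(torch.float8_e4m3fn)
        )
        return x_fp8, scale
    ```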
    
    * Update FP8 KV cache test to use checkpoint with scales
    
    * `can_scale`: check that the attention is flashinfer
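
    An illustrative guard only; the exact dtype/backend arguments are assumptions here, not the real `can_scale` signature:

    ```python
    import torch


    def can_scale(kv_cache_dtype: torch.dtype, attention_impl: str) -> bool:
        # Scales only apply to an FP8 (e4m3) cache, and only the flashinfer
        # attention backend consumes them here.
        return kv_cache_dtype == torch.float8_e4m3fn and attention_impl == "flashinfer"
    ```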