    Add FlashInfer support (#2354)
    This change adds support for FlashInfer. FlashInfer can be enabled using
    `FLASH_INFER=1` and is currently only implemented in `FlashCausalLM`.
    Since this functionality is currently only for testing, FlashInfer is
    not installed anywhere yet.
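
    As a rough illustration, the gate could be a check like the following
    (a hypothetical sketch; the change only specifies the `FLASH_INFER=1`
    environment variable):

    ```python
    import os

    # Hypothetical gate: the commit only specifies that FLASH_INFER=1
    # enables the FlashInfer code path in FlashCausalLM.
    FLASH_INFER: bool = os.environ.get("FLASH_INFER", "") == "1"
    ```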
    
    The FlashInfer API is quite different from FlashAttention/vLLM in that
    it requires more global bookkeeping:
    
    * A wrapper class needs to be constructed (which we just call *state*).
      Since construction is fairly expensive (due to pinned host memory
      allocation), we only do it once per FlashCausalLM instance or per
      CUDA Graph size; see the sketch after this list.
    * Each model forward call needs to be wrapped in `begin_forward` and
      `end_forward`. This sets up data structures that can be reused by
      all attention calls within that forward pass.
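
    A minimal sketch of both points, assuming FlashInfer's
    `BatchDecodeWithPagedKVCacheWrapper` API (whose constructor allocates
    the pinned host buffers); the workspace size and helper name are
    illustrative, not the exact code from this change:

    ```python
    import flashinfer
    import torch

    def create_decode_state(device: torch.device):
        # Constructed once per FlashCausalLM instance or per CUDA Graph
        # size: allocating the wrapper's pinned host memory is expensive.
        workspace_buffer = torch.empty(
            128 * 1024 * 1024, dtype=torch.uint8, device=device
        )
        return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
            workspace_buffer, kv_layout="NHD"
        )
    ```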
    
    When calling attention, we need access to the state object. To avoid
    passing an argument down the call chain (which would require changes to
    all models), we use a context variable.
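
    For example, a module-level context variable (the name is
    hypothetical) that attention implementations read instead of
    receiving the state as an argument:

    ```python
    from contextvars import ContextVar
    from typing import Any

    # Hypothetical name; attention code looks up the current forward
    # call's FlashInfer state here rather than taking it as an argument.
    decode_state: ContextVar[Any] = ContextVar("decode_state")

    def get_decode_state() -> Any:
        # Raises LookupError when called outside a wrapped forward call.
        return decode_state.get()
    ```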
    
    Each model forward call is wrapped using a context manager that does
    all the bookkeeping for such a call (see the sketch after this list):
    
    * Set the context variable to the forward call's state.
    * Call `begin_forward` on the state.
    * Yield.
    * Call `end_forward` on the state.
    * Reset the context variable.
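
    A sketch of such a context manager, using the `decode_state` variable
    from the sketch above; `begin_forward`'s real argument list is elided
    behind keyword arguments:

    ```python
    from contextlib import contextmanager

    @contextmanager
    def attention_state(state, **begin_forward_kwargs):
        # 1. Set the context variable to this forward call's state.
        token = decode_state.set(state)
        try:
            # 2. Set up data structures reused by all attention calls.
            state.begin_forward(**begin_forward_kwargs)
            # 3. Run the wrapped model forward.
            yield
        finally:
            # 4./5. Tear down the per-call structures and reset the
            # context variable, even if the forward call raised.
            state.end_forward()
            decode_state.reset(token)
    ```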
    
    We cannot use a single shared global variable for this, since e.g. CUDA
    Graphs of different sizes each have their own state.
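
    Concretely, the per-size bookkeeping could look like this
    (hypothetical; it reuses `create_decode_state` from the sketch above,
    and the captured batch sizes are illustrative):

    ```python
    import torch

    device = torch.device("cuda")
    # One state per captured CUDA Graph batch size; a single shared
    # global would be clobbered whenever a different graph runs.
    graph_states = {
        batch_size: create_decode_state(device)
        for batch_size in (1, 2, 4, 8, 16, 32)
    }
    ```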