1. 09 Aug, 2024 2 commits
    • Nicolas Patry's avatar
      Using an enum for flash backens (paged/flashdecoding/flashinfer) (#2385) · 7a48a847
      Nicolas Patry authored
      * Using an enum for flash backens (paged/flashdecoding/flashinfer)
      
      * Early exit on server too.
      
      * Clippy.
      
      * Fix clippy and fmt.
      7a48a847
    • Daniël de Kok's avatar
      Add FlashInfer support (#2354) · 7830de15
      Daniël de Kok authored
      This change adds support for FlashInfer. FlashInfer can be enabled using
      `FLASH_INFER=1` and is currently only implemented in `FlashCausalLM`.
      Since this functionality is currently only for testing, FlashInfer is
      not installed anywhere yet.
      
      The FlashInfer API is quite different from FlashAttention/vLLM in that
      it requires more global bookkeeping:
      
      * A wrapper class needs to be contstructed (which we just call *state*).
        Since this is fairly expensive (due to pinned host memory allocation),
        we only do this once in a FlashCausalLM instance or for each CUDA
        Graph size.
      * Each model forward call needs to be wrapped in `begin_forward` and
        `end_forward`. This sets up data structures that can be reused for all
        calls to attention for that forward call.
      
      When calling attention, we need access to the state object. To avoid
      passing an argument down the call chain (which would require changes to
      all models), we use a context variable.
      
      Each model forward call is wrapped using a context manager that does all
      the bookkeeping for such a call:
      
      * Set the context variable to the forward call's state.
      * Call `begin_forward` on the state.
      * Yield.
      * Call `end_forward` on the state.
      * Reset the context variable.
      
      We cannot use a single shared global variable for this, since e.g. CUDA
      Graphs of different sizes each have their own state.
      7830de15
  2. 01 Jul, 2024 1 commit
    • Nicolas Patry's avatar
      [Major Change][Undecided yet] Move to FlashDecoding instead of PagedAttention kernel. (#1940) · 4327210e
      Nicolas Patry authored
      * Using flash decoding
      
      Conditional flashdecoding.
      
      Fix max_q.
      
      Working kvcache
      
      Working version with flash decoding.
      
      Make it work for mistral.
      
      Fix after rebase..
      
      Less intrusive.
      
      REvert changes in modeling.
      
      Speedup flashdecoding.
      
      HHachweew
      Hack to make other models work.
      
      Fixing non flash decoding llama path.
      
      Router logic knows about page size.
      
      Missing 2 models.
      
      Missing cohere.
      
      Fixing cohere flash decoding.
      
      Revamped all this architecture.
      
      Fix cohere.
      
      Fixing falcon.
      
      Enabling custom block size schedule.
      
      Update router/src/infer.rs
      
      Not sending preallocated output.
      
      * Making it work on non flash decoding.
      
      * Fix Cohere.
      
      * Fix non decoding paths.
      
      * Rebased.
      
      * No need for cache_manager anymore.
      
      * Update?
      
      * "ipex" -> "cpu"
      
      * These do not belong.
      
      * Factoring cu_seqlen_qk for better abstracting over every model.
      
      * Fixing non flash tests/imports.
      
      * Changing return everywhere.
      
      * Update mistral past.
      
      * Fixing Mi{s,x}tral (non functional in Flash Decoding mode though).
      
      * Fixup mistral clamping (had issues with cuda graphs).
      
      * No need to recreate anything actually.
      4327210e