1. 08 Feb, 2024 2 commits
    • OlivierDehaene's avatar
      09b7c26b
    • drbh's avatar
      Impl simple mamba model (#1480) · bd405e03
      drbh authored
      This draft PR is a work in progress implementation of the mamba model.
      This PR currently loads weights, and produces correct logits after a
      single pass.
      
      This PR still needs to correctly integrate this model so it produces
      tokens as expected, and apply optimization to avoid all copies during
      runtime/unnecessary operations.
      
      #### Helpful resources
      [Mamba: Linear-Time Sequence Modeling with Selective State Spaces
      (Albert Gu and Tri Dao)](https://arxiv.org/abs/2312.00752)
      https://github.com/johnma2006/mamba-minimal
      
      https://github.com/huggingface/candle/blob/main/candle-examples/examples/mamba-minimal/model.rs
      https://github.com/huggingface/transformers/pull/28094
      
      
      
      Notes: this dev work is currently targeting `state-spaces/mamba-130m`,
      so if you want to test please use that model. Additionally when starting
      the router the prefill needs to be limited: `cargo run --
      --max-batch-prefill-tokens 768 --max-input-length 768`
      
      
      ## Update / Current State
      
      Integration tests have been added and basic functionality such as model
      loading is supported.
      
      ```bash
      cd integration-tests
      pytest -vv models/test_fused_kernel_mamba.py
      ```
      - [x] add tests
      - [x] load model
      - [x] make simple request 
      - [ ] resolve warmup issue
      - [ ] resolve output issues
      
      
      fetching models tested during dev
      ```bash
      text-generation-server download-weights state-spaces/mamba-130m
      text-generation-server download-weights state-spaces/mamba-1.4b
      text-generation-server download-weights state-spaces/mamba-2.8b
      ```
      
      The server can be run 
      ```bash
      cd server
       MASTER_ADDR=127.0.0.1 MASTER_PORT=5555 python text_generation_server/cli.py serve state-spaces/mamba-2.8b
      ```
      
      router
      ```bash
      cargo run
      ```
      
      make a request
      ```bash
      curl -s localhost:3000/generate \
          -X POST \
          -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
          -H 'Content-Type: application/json' | jq
      ```
      
      response
      ```json
      {
        "generated_text": "\n\nDeep learning is a machine learning technique that uses a deep neural network to learn from data."
      }
      ```
      
      ---------
      Co-authored-by: default avatarNicolas Patry <patry.nicolas@protonmail.com>
      bd405e03