    Impl simple mamba model (#1480) · bd405e03
    drbh authored
    This draft PR is a work-in-progress implementation of the Mamba model. It
    currently loads weights and produces correct logits after a single forward
    pass.
    
    Remaining work: integrate the model so that it generates tokens as
    expected, and optimize the runtime to avoid unnecessary copies and
    operations.
    
    #### Helpful resources
    - [Mamba: Linear-Time Sequence Modeling with Selective State Spaces
      (Albert Gu and Tri Dao)](https://arxiv.org/abs/2312.00752)
    - https://github.com/johnma2006/mamba-minimal
    - https://github.com/huggingface/candle/blob/main/candle-examples/examples/mamba-minimal/model.rs
    - https://github.com/huggingface/transformers/pull/28094
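For intuition, the core of the model is a linear-time state-space recurrence. A minimal sketch (my own simplification, ignoring Mamba's selection mechanism where `B`, `C`, and `Δ` depend on the input) of the diagonal recurrence:

```python
import numpy as np

def ssm_scan(x, A, B, C, dt):
    """Minimal diagonal state-space recurrence over a 1-D input sequence.

    h_t = exp(dt * A) * h_{t-1} + (dt * B) * x_t
    y_t = C . h_t

    Shapes: x is (L,); A, B, C are (N,) state vectors; dt is a scalar.
    """
    h = np.zeros_like(A)
    dA = np.exp(dt * A)   # discretized state transition
    dB = dt * B           # simple Euler discretization of the input projection
    ys = []
    for xt in x:
        h = dA * h + dB * xt
        ys.append(float(C @ h))
    return np.array(ys)
```

The real model additionally makes `B`, `C`, and `Δ` functions of the input (the "selective" part) and evaluates the recurrence with a fused scan kernel rather than a Python loop.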
    
    
    
    Note: this dev work currently targets `state-spaces/mamba-130m`, so please
    use that model for testing. Additionally, when starting the router, the
    prefill needs to be limited: `cargo run --
    --max-batch-prefill-tokens 768 --max-input-length 768`
    
    
    ## Update / Current State
    
    Integration tests have been added and basic functionality such as model
    loading is supported.
    
    ```bash
    cd integration-tests
    pytest -vv models/test_fused_kernel_mamba.py
    ```
    - [x] add tests
    - [x] load model
    - [x] make a simple request
    - [ ] resolve warmup issue
    - [ ] resolve output issues
    
    
    Fetch the models tested during development:
    ```bash
    text-generation-server download-weights state-spaces/mamba-130m
    text-generation-server download-weights state-spaces/mamba-1.4b
    text-generation-server download-weights state-spaces/mamba-2.8b
    ```
    
    Run the server:
    ```bash
    cd server
    MASTER_ADDR=127.0.0.1 MASTER_PORT=5555 python text_generation_server/cli.py serve state-spaces/mamba-2.8b
    ```
    
    Start the router:
    ```bash
    cargo run
    ```
    
    Make a request:
    ```bash
    curl -s localhost:3000/generate \
        -X POST \
        -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
        -H 'Content-Type: application/json' | jq
    ```
    
    Response:
    ```json
    {
      "generated_text": "\n\nDeep learning is a machine learning technique that uses a deep neural network to learn from data."
    }
    ```
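The same request can be made from Python with only the standard library; a small sketch, assuming the router is listening on `localhost:3000` as above:

```python
import json
import urllib.request

def build_payload(prompt: str, max_new_tokens: int = 20) -> str:
    """Build the same JSON body the curl example sends."""
    return json.dumps({
        "inputs": prompt,
        "parameters": {"max_new_tokens": max_new_tokens},
    })

def generate(prompt: str, url: str = "http://localhost:3000/generate") -> dict:
    """POST the prompt to the router and return the parsed JSON response."""
    req = urllib.request.Request(
        url,
        data=build_payload(prompt).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())
```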
    
    ---------
    Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>