    Fix calling cuda() on load_in_8bit (#1153) · 7402a355
    momonga authored
    
    
    This PR addresses an issue where calling `model = model.cuda()` throws a
    ValueError when `quantize` is set to "bitsandbytes", since 8-bit quantized
    models are already placed on the correct device at load time.
    
    ```
    > File "/opt/conda/lib/python3.9/site-packages/text_generation_server/server.py", line 147, in serve_inner
        model = get_model(
      File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/__init__.py", line 295, in get_model
        return CausalLM(
      File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/causal_lm.py", line 515, in __init__
        model = model.cuda()
      File "/opt/conda/lib/python3.9/site-packages/transformers/modeling_utils.py", line 1998, in cuda
        raise ValueError(
    ValueError: Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.
    ```
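    A minimal sketch of the kind of guard that avoids this error, assuming a
    hypothetical helper `maybe_move_to_cuda` and the `quantize` value seen in the
    traceback above; the exact condition used in `causal_lm.py` may differ:

    ```python
    import torch
    from typing import Optional


    def maybe_move_to_cuda(model, quantize: Optional[str]):
        """Move the model to the GPU only when it is safe to do so.

        bitsandbytes-quantized (8-bit/4-bit) models are already dispatched to
        the correct devices at load time, so calling `.cuda()` on them raises
        the ValueError shown above.
        """
        if (
            torch.cuda.is_available()
            and torch.cuda.device_count() == 1
            and quantize != "bitsandbytes"
        ):
            model = model.cuda()
        return model
    ```

    With such a guard in place, the model initialization could call
    `model = maybe_move_to_cuda(model, quantize)` instead of calling
    `model.cuda()` unconditionally.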
    Co-authored-by: mmnga <mmnga1mmnga@gmail.com>