Unverified Commit 7402a355 authored by momonga, committed by GitHub

Fix calling cuda() on load_in_8bit (#1153)



This PR addresses an issue where calling `model = model.cuda()` raises a ValueError when `quantize` is set to "bitsandbytes", because `transformers` refuses to move a model that was loaded in 8-bit:

```
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/server.py", line 147, in serve_inner
    model = get_model(
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/__init__.py", line 295, in get_model
    return CausalLM(
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/causal_lm.py", line 515, in __init__
    model = model.cuda()
  File "/opt/conda/lib/python3.9/site-packages/transformers/modeling_utils.py", line 1998, in cuda
    raise ValueError(
ValueError: Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.
```
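
For context, the failure can be reproduced with plain `transformers` outside the server. A minimal sketch (the model id is an arbitrary example; assumes a CUDA GPU and the `bitsandbytes` package are installed):

```python
from transformers import AutoModelForCausalLM

# Load an 8-bit quantized model, the same path CausalLM takes when
# quantize == "bitsandbytes" (the model id is a hypothetical example).
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-560m",
    load_in_8bit=True,
    device_map="auto",
)

# transformers has already placed the quantized weights on the correct
# device with the correct dtype, so moving the model again is rejected:
model = model.cuda()  # ValueError: Calling `cuda()` is not supported ...
```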
Co-authored-by: mmnga <mmnga1mmnga@gmail.com>
parent 3af1a114
server/text_generation_server/models/causal_lm.py
```diff
@@ -511,7 +511,7 @@ class CausalLM(Model):
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
-        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
+        if torch.cuda.is_available() and torch.cuda.device_count() == 1 and quantize != "bitsandbytes":
             model = model.cuda()
         if tokenizer.pad_token_id is None:
```
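
With this guard, a bitsandbytes-quantized model is used as-is on the device `transformers` assigned during loading, while unquantized single-GPU models are still moved to the GPU as before.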