Unverified Commit e9669a40 authored by OlivierDehaene's avatar OlivierDehaene Committed by GitHub
Browse files

feat(server): do not use device_map auto on single GPU (#362)

parent cfaa8580
...@@ -468,9 +468,12 @@ class CausalLM(Model): ...@@ -468,9 +468,12 @@ class CausalLM(Model):
model_id, model_id,
revision=revision, revision=revision,
torch_dtype=dtype, torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None, device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
load_in_8bit=quantize == "bitsandbytes", load_in_8bit=quantize == "bitsandbytes",
) )
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
model = model.cuda()
tokenizer.pad_token_id = ( tokenizer.pad_token_id = (
model.config.pad_token_id model.config.pad_token_id
if model.config.pad_token_id is not None if model.config.pad_token_id is not None
......
...@@ -518,9 +518,12 @@ class Seq2SeqLM(Model): ...@@ -518,9 +518,12 @@ class Seq2SeqLM(Model):
model_id, model_id,
revision=revision, revision=revision,
torch_dtype=dtype, torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None, device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
load_in_8bit=quantize == "bitsandbytes", load_in_8bit=quantize == "bitsandbytes",
) )
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
model = model.cuda()
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left" model_id, revision=revision, padding_side="left", truncation_side="left"
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment