Unverified Commit 96966f53 authored by Filippo Momentè's avatar Filippo Momentè Committed by GitHub
Browse files

Add device arg to model_args passed to LLM object in VLLM model class (#2879)

* fix: pass device arg in model_ar in vllm_causallms

* casting device arg to str in vLLM model args
parent 4dbd5ec9
...@@ -100,6 +100,7 @@ class VLLM(TemplateLM): ...@@ -100,6 +100,7 @@ class VLLM(TemplateLM):
"swap_space": int(swap_space), "swap_space": int(swap_space),
"quantization": quantization, "quantization": quantization,
"seed": int(seed), "seed": int(seed),
"device": str(device),
} }
self.model_args.update(kwargs) self.model_args.update(kwargs)
self.batch_size = ( self.batch_size = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment