Commit c70b6f4d authored by ceerrep's avatar ceerrep
Browse files

fix: use 'cuda:0' by default if torch_device is 'cuda'

parent ee24eb8d
......@@ -130,6 +130,7 @@ class KTransformersInterface(TransformersInterface):
logger.debug(f"input_ids: {input_ids.shape}")
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
device = "cuda:0" if device == "cuda" else device
if is_new:
self.ever_generated_ids.clear()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment