"...graphbolt/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "efe28493d3f5c9911b0648dba79021d139044d62"
Commit c70b6f4d authored by ceerrep's avatar ceerrep
Browse files

fix: use 'cuda:0' by default if torch_device is 'cuda'

parent ee24eb8d
...@@ -130,6 +130,7 @@ class KTransformersInterface(TransformersInterface): ...@@ -130,6 +130,7 @@ class KTransformersInterface(TransformersInterface):
logger.debug(f"input_ids: {input_ids.shape}") logger.debug(f"input_ids: {input_ids.shape}")
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0") device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
device = "cuda:0" if device == "cuda" else device
if is_new: if is_new:
self.ever_generated_ids.clear() self.ever_generated_ids.clear()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment