Unverified Commit a9284245 authored by Gregory Shtrasberg's avatar Gregory Shtrasberg Committed by GitHub
Browse files

[Bugfix][ROCm] Using device_type because on ROCm the API is still torch.cuda (#17601)


Signed-off-by: default avatarGregory Shtrasberg <Gregory.Shtrasberg@amd.com>
parent c8386fa6
......@@ -406,12 +406,12 @@ class Platform:
"""Raises if this request is unsupported on this platform"""
def __getattr__(self, key: str):
device = getattr(torch, self.device_name, None)
device = getattr(torch, self.device_type, None)
if device is not None and hasattr(device, key):
return getattr(device, key)
else:
logger.warning("Current platform %s does not have '%s'" \
" attribute.", self.device_name, key)
" attribute.", self.device_type, key)
return None
@classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment