Unverified Commit a9284245 authored by Gregory Shtrasberg's avatar Gregory Shtrasberg Committed by GitHub
Browse files

[Bugfix][ROCm] Using device_type because on ROCm the API is still torch.cuda (#17601)


Signed-off-by: default avatarGregory Shtrasberg <Gregory.Shtrasberg@amd.com>
parent c8386fa6
...@@ -406,12 +406,12 @@ class Platform: ...@@ -406,12 +406,12 @@ class Platform:
"""Raises if this request is unsupported on this platform""" """Raises if this request is unsupported on this platform"""
def __getattr__(self, key: str): def __getattr__(self, key: str):
device = getattr(torch, self.device_name, None) device = getattr(torch, self.device_type, None)
if device is not None and hasattr(device, key): if device is not None and hasattr(device, key):
return getattr(device, key) return getattr(device, key)
else: else:
logger.warning("Current platform %s does not have '%s'" \ logger.warning("Current platform %s does not have '%s'" \
" attribute.", self.device_name, key) " attribute.", self.device_type, key)
return None return None
@classmethod @classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment