Commit 601057c9 authored by zhuwenwen's avatar zhuwenwen
Browse files

update rocm.py

parent b72de6bd
......@@ -144,7 +144,8 @@ class RocmPlatform(Platform):
def get_device_name(cls, device_id: int = 0) -> str:
physical_device_id = device_id_to_physical_device_id(device_id)
handle = amdsmi_get_processor_handles()[physical_device_id]
return amdsmi_get_gpu_asic_info(handle)["market_name"]
# return amdsmi_get_gpu_asic_info(handle)["market_name"]
return torch.cuda.get_device_name(device_id)
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
......@@ -226,8 +227,9 @@ class RocmPlatform(Platform):
device: Optional[torch.types.Device] = None
) -> float:
torch.cuda.reset_peak_memory_stats(device)
return torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info(
device)[0]
# return torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info(
# device)[0]
return torch.cuda.max_memory_allocated(device)
@classmethod
def get_device_communicator_cls(cls) -> str:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment