Unverified Commit 8dbe0c52 authored by hfan's avatar hfan Committed by GitHub
Browse files

[Misc] Add TPU usage report when using tpu_inference. (#27423)


Signed-off-by: default avatarHongmin Fan <fanhongmin@google.com>
parent 5cc6bddb
...@@ -176,6 +176,32 @@ class UsageMessage: ...@@ -176,6 +176,32 @@ class UsageMessage:
self._report_usage_once(model_architecture, usage_context, extra_kvs) self._report_usage_once(model_architecture, usage_context, extra_kvs)
self._report_continuous_usage() self._report_continuous_usage()
def _report_tpu_inference_usage(self) -> bool:
try:
from tpu_inference import tpu_info, utils
self.gpu_count = tpu_info.get_num_chips()
self.gpu_type = tpu_info.get_tpu_type()
self.gpu_memory_per_device = utils.get_device_hbm_limit()
self.cuda_runtime = "tpu_inference"
return True
except Exception:
return False
def _report_torch_xla_usage(self) -> bool:
try:
import torch_xla
self.gpu_count = torch_xla.runtime.world_size()
self.gpu_type = torch_xla.tpu.get_tpu_type()
self.gpu_memory_per_device = torch_xla.core.xla_model.get_memory_info()[
"bytes_limit"
]
self.cuda_runtime = "torch_xla"
return True
except Exception:
return False
def _report_usage_once( def _report_usage_once(
self, self,
model_architecture: str, model_architecture: str,
...@@ -192,16 +218,10 @@ class UsageMessage: ...@@ -192,16 +218,10 @@ class UsageMessage:
) )
if current_platform.is_cuda(): if current_platform.is_cuda():
self.cuda_runtime = torch.version.cuda self.cuda_runtime = torch.version.cuda
if current_platform.is_tpu(): if current_platform.is_tpu(): # noqa: SIM102
try: if (not self._report_tpu_inference_usage()) and (
import torch_xla not self._report_torch_xla_usage()
):
self.gpu_count = torch_xla.runtime.world_size()
self.gpu_type = torch_xla.tpu.get_tpu_type()
self.gpu_memory_per_device = torch_xla.core.xla_model.get_memory_info()[
"bytes_limit"
]
except Exception:
logger.exception("Failed to collect TPU information") logger.exception("Failed to collect TPU information")
self.provider = _detect_cloud_provider() self.provider = _detect_cloud_provider()
self.architecture = platform.machine() self.architecture = platform.machine()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment