Unverified Commit e54ebc2f authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[doc] fix doc build error caused by msgspec (#7659)

parent 67e02fa8
...@@ -3,6 +3,7 @@ sphinx-book-theme==1.0.1 ...@@ -3,6 +3,7 @@ sphinx-book-theme==1.0.1
sphinx-copybutton==0.5.2 sphinx-copybutton==0.5.2
myst-parser==2.0.0 myst-parser==2.0.0
sphinx-argparse==0.4.0 sphinx-argparse==0.4.0
msgspec
# packages to install to build the documentation # packages to install to build the documentation
pydantic pydantic
......
import torch
from .interface import Platform, PlatformEnum, UnspecifiedPlatform from .interface import Platform, PlatformEnum, UnspecifiedPlatform
current_platform: Platform current_platform: Platform
# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because
# they only indicate the build configuration, not the runtime environment.
# For example, people can install a cuda build of pytorch but run on tpu.
is_tpu = False
try:
import torch_xla.core.xla_model as xm
xm.xla_device(devkind="TPU")
is_tpu = True
except Exception:
pass
is_cuda = False
try:
import pynvml
pynvml.nvmlInit()
try:
if pynvml.nvmlDeviceGetCount() > 0:
is_cuda = True
finally:
pynvml.nvmlShutdown()
except Exception:
pass
is_rocm = False
try: try:
import libtpu import amdsmi
except ImportError: amdsmi.amdsmi_init()
libtpu = None try:
if len(amdsmi.amdsmi_get_processor_handles()) > 0:
is_rocm = True
finally:
amdsmi.amdsmi_shut_down()
except Exception:
pass
if libtpu is not None: if is_tpu:
# people might install pytorch built with cuda but run on tpu # people might install pytorch built with cuda but run on tpu
# so we need to check tpu first # so we need to check tpu first
from .tpu import TpuPlatform from .tpu import TpuPlatform
current_platform = TpuPlatform() current_platform = TpuPlatform()
elif torch.version.cuda is not None: elif is_cuda:
from .cuda import CudaPlatform from .cuda import CudaPlatform
current_platform = CudaPlatform() current_platform = CudaPlatform()
elif torch.version.hip is not None: elif is_rocm:
from .rocm import RocmPlatform from .rocm import RocmPlatform
current_platform = RocmPlatform() current_platform = RocmPlatform()
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment