Unverified Commit c8861376 authored by Huazhong Ji's avatar Huazhong Ji Committed by GitHub
Browse files

Improve `transformers-cli env` reporting (#31003)

* Improve `transformers-cli env` reporting

* move the line `"Using GPU in script?": "<fill in>"` to in if conditional
statement

* same option for npu
parent c3044ec2
...@@ -26,6 +26,7 @@ from ..utils import ( ...@@ -26,6 +26,7 @@ from ..utils import (
is_safetensors_available, is_safetensors_available,
is_tf_available, is_tf_available,
is_torch_available, is_torch_available,
is_torch_npu_available,
) )
from . import BaseTransformersCLICommand from . import BaseTransformersCLICommand
...@@ -88,6 +89,7 @@ class EnvironmentCommand(BaseTransformersCLICommand): ...@@ -88,6 +89,7 @@ class EnvironmentCommand(BaseTransformersCLICommand):
pt_version = torch.__version__ pt_version = torch.__version__
pt_cuda_available = torch.cuda.is_available() pt_cuda_available = torch.cuda.is_available()
pt_npu_available = is_torch_npu_available()
tf_version = "not installed" tf_version = "not installed"
tf_cuda_available = "NA" tf_cuda_available = "NA"
...@@ -129,9 +131,15 @@ class EnvironmentCommand(BaseTransformersCLICommand): ...@@ -129,9 +131,15 @@ class EnvironmentCommand(BaseTransformersCLICommand):
"Flax version (CPU?/GPU?/TPU?)": f"{flax_version} ({jax_backend})", "Flax version (CPU?/GPU?/TPU?)": f"{flax_version} ({jax_backend})",
"Jax version": f"{jax_version}", "Jax version": f"{jax_version}",
"JaxLib version": f"{jaxlib_version}", "JaxLib version": f"{jaxlib_version}",
"Using GPU in script?": "<fill in>",
"Using distributed or parallel set-up in script?": "<fill in>", "Using distributed or parallel set-up in script?": "<fill in>",
} }
if pt_cuda_available:
info["Using GPU in script?"] = "<fill in>"
info["GPU type"] = torch.cuda.get_device_name()
elif pt_npu_available:
info["Using NPU in script?"] = "<fill in>"
info["NPU type"] = torch.npu.get_device_name()
info["CANN version"] = torch.version.cann
print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n") print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
print(self.format_dict(info)) print(self.format_dict(info))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment