"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7630c11f326025ef67a99db913de9fafcfc0704d"
Unverified Commit 7ef309ca authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Flax] Add jax flax to env command (#12251)

* fix_torch_device_generate_test

* remove @

* add commands for flax/jax
parent e3cb7a0b
...@@ -16,7 +16,7 @@ import platform ...@@ -16,7 +16,7 @@ import platform
from argparse import ArgumentParser from argparse import ArgumentParser
from .. import __version__ as version from .. import __version__ as version
from ..file_utils import is_tf_available, is_torch_available from ..file_utils import is_flax_available, is_tf_available, is_torch_available
from . import BaseTransformersCLICommand from . import BaseTransformersCLICommand
...@@ -52,12 +52,29 @@ class EnvironmentCommand(BaseTransformersCLICommand): ...@@ -52,12 +52,29 @@ class EnvironmentCommand(BaseTransformersCLICommand):
# returns list of devices, convert to bool # returns list of devices, convert to bool
tf_cuda_available = bool(tf.config.list_physical_devices("GPU")) tf_cuda_available = bool(tf.config.list_physical_devices("GPU"))
flax_version = "not installed"
jax_version = "not installed"
jaxlib_version = "not installed"
jax_backend = "NA"
if is_flax_available():
import flax
import jax
import jaxlib
flax_version = flax.__version__
jax_version = jax.__version__
jaxlib_version = jaxlib.__version__
jax_backend = jax.lib.xla_bridge.get_backend().platform
info = { info = {
"`transformers` version": version, "`transformers` version": version,
"Platform": platform.platform(), "Platform": platform.platform(),
"Python version": platform.python_version(), "Python version": platform.python_version(),
"PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})", "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
"Tensorflow version (GPU?)": f"{tf_version} ({tf_cuda_available})", "Tensorflow version (GPU?)": f"{tf_version} ({tf_cuda_available})",
"Flax version (CPU?/GPU?/TPU?)": f"{flax_version} ({jax_backend})",
"Jax version": f"{jax_version}",
"JaxLib version": f"{jaxlib_version}",
"Using GPU in script?": "<fill in>", "Using GPU in script?": "<fill in>",
"Using distributed or parallel set-up in script?": "<fill in>", "Using distributed or parallel set-up in script?": "<fill in>",
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment