Unverified Commit 0d251991 authored by Frédéric Bastien's avatar Frédéric Bastien Committed by GitHub
Browse files

WAR a JAX issue (#221)


Signed-off-by: default avatarFrederic Bastien <fbastien@nvidia.com>
parent b20c0531
...@@ -46,9 +46,9 @@ def is_fp8_available(gpu_id=None) -> Tuple[bool, str]: ...@@ -46,9 +46,9 @@ def is_fp8_available(gpu_id=None) -> Tuple[bool, str]:
global _is_fp8_available, _reason_for_no_fp8 global _is_fp8_available, _reason_for_no_fp8
if _is_fp8_available is None: if _is_fp8_available is None:
_is_fp8_available = True _is_fp8_available = True
devices = jax.local_devices() # JAX doesn't provide the local GPU id.
for gpu in devices: for local_gpu_id in range(len(jax.local_devices())):
ret, msg = _check_fp8_support(gpu.id) ret, msg = _check_fp8_support(local_gpu_id)
if ret is False: if ret is False:
_is_fp8_available = ret _is_fp8_available = ret
_reason_for_no_fp8 = msg _reason_for_no_fp8 = msg
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment