jax_0.3.14_ROCM.patch 606 Bytes
Newer Older
zhuwenwen's avatar
zhuwenwen committed
1
2
3
4
5
6
7
8
9
10
11
12
diff -Naur jax-jax-0.3.14-dtk22.04.2/jax/_src/lib/xla_bridge.py jax-jax-0.3.14-dtk22.04.2.chg/jax/_src/lib/xla_bridge.py
--- jax-jax-0.3.14-dtk22.04.2/jax/_src/lib/xla_bridge.py	2022-07-19 17:16:32.000000000 +0800
+++ jax-jax-0.3.14-dtk22.04.2.chg/jax/_src/lib/xla_bridge.py	2022-07-20 10:23:15.000000000 +0800
@@ -261,7 +261,7 @@
 
   b = backends()
   for p in platforms:
-    if p in b.keys():
+    if p in b.keys() and p!='cuda':
       return p
   raise RuntimeError(f"Unknown backend: '{platform}' requested, but no "
                      f"platforms that are instances of {platform} are present. "