Unverified Commit df6f347f authored by zlsh80826's avatar zlsh80826 Committed by GitHub
Browse files

[JAX] Move jax.experimental.maps.Mesh to jax.sharding.Mesh (#276)



Move jax.experimental.maps.Mesh to jax.sharding.Mesh
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>
parent 487871e2
...@@ -8,7 +8,6 @@ import flax ...@@ -8,7 +8,6 @@ import flax
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np import numpy as np
from jax.experimental import maps
from utils import assert_allclose from utils import assert_allclose
from transformer_engine.common.recipe import DelayedScaling from transformer_engine.common.recipe import DelayedScaling
...@@ -218,7 +217,7 @@ class TestFP8Functions(unittest.TestCase): ...@@ -218,7 +217,7 @@ class TestFP8Functions(unittest.TestCase):
# TODO (Ming Huang): Suport multi-GPUs testing. # pylint: disable=fixme # TODO (Ming Huang): Suport multi-GPUs testing. # pylint: disable=fixme
mesh_shape = (1, 1) mesh_shape = (1, 1)
devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape)
with maps.Mesh(devices, ('dp', 'tp')): with jax.sharding.Mesh(devices, ('dp', 'tp')):
for sr, mst in srs: for sr, mst in srs:
with fp8_autocast(enabled=True, fp8_recipe=ds, sharding_resource=sr): with fp8_autocast(enabled=True, fp8_recipe=ds, sharding_resource=sr):
self.assertTrue(FP8Helper.is_fp8_enabled()) self.assertTrue(FP8Helper.is_fp8_enabled())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment