"...git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "94b5edeb5384ea2a46533a11dd5938b2c859bf5c"
Unverified Commit df6f347f authored by zlsh80826's avatar zlsh80826 Committed by GitHub
Browse files

[JAX] Move jax.experimental.maps.Mesh to jax.sharding.Mesh (#276)



Move jax.experimental.maps.Mesh to jax.sharding.Mesh
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>
parent 487871e2
......@@ -8,7 +8,6 @@ import flax
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import maps
from utils import assert_allclose
from transformer_engine.common.recipe import DelayedScaling
......@@ -218,7 +217,7 @@ class TestFP8Functions(unittest.TestCase):
# TODO (Ming Huang): Suport multi-GPUs testing. # pylint: disable=fixme
mesh_shape = (1, 1)
devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape)
with maps.Mesh(devices, ('dp', 'tp')):
with jax.sharding.Mesh(devices, ('dp', 'tp')):
for sr, mst in srs:
with fp8_autocast(enabled=True, fp8_recipe=ds, sharding_resource=sr):
self.assertTrue(FP8Helper.is_fp8_enabled())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment