"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "f0a9404881777ba0496e56e62d682ebb3896e91c"
Unverified Commit 23f4864d authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

Replace deprecated sharding API in JAX test (#332)



Replace deprecated sharding API
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f0ddab82
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
import jax import jax
import numpy as np import numpy as np
import pytest import pytest
from jax.experimental import maps
from utils import is_devices_enough from utils import is_devices_enough
from transformer_engine.jax.flax import extend_logical_axis_rules from transformer_engine.jax.flax import extend_logical_axis_rules
...@@ -79,7 +78,7 @@ class TestGeneralFunc: ...@@ -79,7 +78,7 @@ class TestGeneralFunc:
sharding_type): sharding_type):
devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape)
with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)): with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)):
with maps.Mesh(devices, mesh_names): with jax.sharding.Mesh(devices, mesh_names):
assert infer_major_sharding_type() is sharding_type.value[0] assert infer_major_sharding_type() is sharding_type.value[0]
@pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG) @pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
...@@ -150,7 +149,7 @@ class TestShardingMetaGenerator: ...@@ -150,7 +149,7 @@ class TestShardingMetaGenerator:
devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape)
with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)): with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)):
with maps.Mesh(devices, mesh_names): with jax.sharding.Mesh(devices, mesh_names):
test_sm = get_fp8_meta_sharding_meta( test_sm = get_fp8_meta_sharding_meta(
sharding_type, sharding_type,
num_of_fp8_meta, num_of_fp8_meta,
...@@ -240,7 +239,7 @@ class TestShardingMetaGenerator: ...@@ -240,7 +239,7 @@ class TestShardingMetaGenerator:
devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape)
with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)): with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)):
with maps.Mesh(devices, mesh_names): with jax.sharding.Mesh(devices, mesh_names):
test_sm = get_dot_sharding_meta( test_sm = get_dot_sharding_meta(
sharding_type, sharding_type,
a_shape, a_shape,
...@@ -319,7 +318,7 @@ class TestShardingMetaGenerator: ...@@ -319,7 +318,7 @@ class TestShardingMetaGenerator:
devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape)
with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)): with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)):
with maps.Mesh(devices, mesh_names): with jax.sharding.Mesh(devices, mesh_names):
ref_sm, need_assert = get_ref_sm() ref_sm, need_assert = get_ref_sm()
try: try:
test_sm = get_elementwise_sharding_meta( test_sm = get_elementwise_sharding_meta(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment