Unverified Commit 9dd61922 authored by Kshitij Lakhani's avatar Kshitij Lakhani Committed by GitHub
Browse files

[JAX] Fix imports in test for deprecated jax.experimental.pjit (#2274)



* Fix imports in test for deprecated jax.experimental.pjit
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix: Pass NamedSharding instead of PartitionSpec to compare_ops() so that when the in and out sharding is used to create a jitted function, it has the mesh info
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarKshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
parent 5624dbb4
...@@ -8,7 +8,7 @@ from itertools import product ...@@ -8,7 +8,7 @@ from itertools import product
import pytest import pytest
import jax import jax
from jax.experimental.pjit import pjit, _UNSPECIFIED from jax._src.sharding_impls import UNSPECIFIED as _UNSPECIFIED
from transformer_engine.jax.sharding import MeshResource from transformer_engine.jax.sharding import MeshResource
...@@ -154,13 +154,15 @@ def compare_ops( ...@@ -154,13 +154,15 @@ def compare_ops(
grad_args = tuple(range(len(inputs))) grad_args = tuple(range(len(inputs)))
target_grad_func = jax.value_and_grad(target_func, argnums=grad_args) target_grad_func = jax.value_and_grad(target_func, argnums=grad_args)
target_pjitter = pjit(target_grad_func, in_shardings=in_shardings, out_shardings=out_shardings) target_jitter = jax.jit(
target_fwd, target_grads = target_pjitter(*inputs, **kwargs) target_grad_func, in_shardings=in_shardings, out_shardings=out_shardings
target_hlo = target_pjitter.lower(*inputs, **kwargs).compile().as_text() )
target_fwd, target_grads = target_jitter(*inputs, **kwargs)
target_hlo = target_jitter.lower(*inputs, **kwargs).compile().as_text()
ref_grad_func = jax.value_and_grad(ref_func, argnums=grad_args) ref_grad_func = jax.value_and_grad(ref_func, argnums=grad_args)
ref_pjitter = pjit(ref_grad_func, in_shardings=in_shardings, out_shardings=out_shardings) ref_jitter = jax.jit(ref_grad_func, in_shardings=in_shardings, out_shardings=out_shardings)
ref_fwd, ref_grads = ref_pjitter(*inputs, **kwargs) ref_fwd, ref_grads = ref_jitter(*inputs, **kwargs)
assert_allclose(target_fwd, ref_fwd, dtype=metric_fwd_dtype) assert_allclose(target_fwd, ref_fwd, dtype=metric_fwd_dtype)
......
...@@ -134,9 +134,12 @@ class TestDistributedLayernorm: ...@@ -134,9 +134,12 @@ class TestDistributedLayernorm:
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource): with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec)) x_named_sharding = NamedSharding(mesh, x_pspec)
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec)) g_named_sharding = NamedSharding(mesh, g_pspec)
beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec)) b_named_sharding = NamedSharding(mesh, b_pspec)
x_ = jax.device_put(x, x_named_sharding)
gamma_ = jax.device_put(gamma, g_named_sharding)
beta_ = jax.device_put(beta, b_named_sharding)
with warnings.catch_warnings(record=True) as warns: with warnings.catch_warnings(record=True) as warns:
try: try:
...@@ -148,8 +151,11 @@ class TestDistributedLayernorm: ...@@ -148,8 +151,11 @@ class TestDistributedLayernorm:
grad_args=(0, 1, 2), grad_args=(0, 1, 2),
metric_fwd_dtype=q_dtype, metric_fwd_dtype=q_dtype,
metric_bwd_dtype=q_dtype, metric_bwd_dtype=q_dtype,
in_shardings=(x_pspec, g_pspec, b_pspec), in_shardings=(x_named_sharding, g_named_sharding, b_named_sharding),
out_shardings=(None, (x_pspec, g_pspec, b_pspec)), out_shardings=(
None,
(x_named_sharding, g_named_sharding, b_named_sharding),
),
) )
except AssertionError as err: except AssertionError as err:
# Layernorm should still produce the correct numerical result with # Layernorm should still produce the correct numerical result with
...@@ -210,8 +216,10 @@ class TestDistributedLayernorm: ...@@ -210,8 +216,10 @@ class TestDistributedLayernorm:
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource): with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec)) x_named_sharding = NamedSharding(mesh, x_pspec)
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec)) g_named_sharding = NamedSharding(mesh, g_pspec)
x_ = jax.device_put(x, x_named_sharding)
gamma_ = jax.device_put(gamma, g_named_sharding)
with warnings.catch_warnings(record=True) as warns: with warnings.catch_warnings(record=True) as warns:
try: try:
...@@ -223,8 +231,8 @@ class TestDistributedLayernorm: ...@@ -223,8 +231,8 @@ class TestDistributedLayernorm:
grad_args=(0, 1), grad_args=(0, 1),
metric_fwd_dtype=q_dtype, metric_fwd_dtype=q_dtype,
metric_bwd_dtype=q_dtype, metric_bwd_dtype=q_dtype,
in_shardings=(x_pspec, g_pspec), in_shardings=(x_named_sharding, g_named_sharding),
out_shardings=(None, (x_pspec, g_pspec)), out_shardings=(None, (x_named_sharding, g_named_sharding)),
) )
except AssertionError as err: except AssertionError as err:
# RmsNorm should still produce the correct numerical result with # RmsNorm should still produce the correct numerical result with
......
...@@ -103,8 +103,10 @@ class TestDistributedSoftmax: ...@@ -103,8 +103,10 @@ class TestDistributedSoftmax:
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, autocast(mesh_resource=mesh_resource): with mesh, autocast(mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec)) x_named_sharding = NamedSharding(mesh, x_pspec)
mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) mask_named_sharding = NamedSharding(mesh, mask_pspec)
x_ = jax.device_put(x, x_named_sharding)
mask_ = jax.device_put(mask, mask_named_sharding)
with warnings.catch_warnings(record=True) as warns: with warnings.catch_warnings(record=True) as warns:
try: try:
...@@ -116,8 +118,8 @@ class TestDistributedSoftmax: ...@@ -116,8 +118,8 @@ class TestDistributedSoftmax:
grad_args=(0,), grad_args=(0,),
metric_fwd_dtype=dtype, metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype, metric_bwd_dtype=dtype,
in_shardings=(x_pspec, mask_pspec), in_shardings=(x_named_sharding, mask_named_sharding),
out_shardings=(None, (x_pspec,)), out_shardings=(None, x_named_sharding),
) )
except AssertionError as err: except AssertionError as err:
# Softmax should still produce the correct numerical result with # Softmax should still produce the correct numerical result with
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment