Commit 739c6565 authored by Kshitij Lakhani, committed by Kshitij Janardan Lakhani

[JAX] Fix imports in test for deprecated jax.experimental.pjit (#2274)



* Fix imports in test for deprecated jax.experimental.pjit
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
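
(For context: `jax.experimental.pjit.pjit` is deprecated in favor of `jax.jit`, which accepts the same `in_shardings`/`out_shardings` arguments. A minimal sketch of the migration, assuming an illustrative 1-D mesh axis name "dp" and a toy function, neither of which is from the test:

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # Before (deprecated): from jax.experimental.pjit import pjit
    # After: jax.jit takes in_shardings/out_shardings directly.
    mesh = Mesh(np.array(jax.devices()), ("dp",))
    sharding = NamedSharding(mesh, PartitionSpec("dp"))

    jitted = jax.jit(lambda x: x * 2, in_shardings=(sharding,), out_shardings=sharding)
    out = jitted(jnp.ones((jax.device_count(), 4)))

Since `jax.jit` is a drop-in replacement here, the diff below only renames the wrappers and builds them with `jax.jit`.)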

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix: Pass NamedSharding instead of PartitionSpec to compare_ops(), so that the in/out shardings used to create the jitted function carry the mesh info
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
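
(A `PartitionSpec` only names mesh axes; it carries no reference to an actual device mesh, so `jax.jit` cannot resolve it to devices outside a mesh context. A `NamedSharding` binds the spec to a concrete `Mesh`. A minimal sketch, again with an illustrative 1-D mesh:

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    pspec = PartitionSpec("dp", None)              # axis names only, no mesh attached
    mesh = Mesh(np.array(jax.devices()), ("dp",))
    named = NamedSharding(mesh, pspec)             # spec + mesh = complete sharding

    # `named` tells jax.jit exactly which devices to place data on; a bare
    # `pspec` would leave the jitted function without mesh info.
    f = jax.jit(lambda x: x + 1.0, in_shardings=(named,), out_shardings=named)
    y = f(jnp.zeros((jax.device_count(), 2)))

This is why the tests below construct `NamedSharding` objects and pass those to compare_ops().)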

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
parent 966a5b9b
@@ -8,7 +8,7 @@ from itertools import product
 import pytest
 import jax
-from jax.experimental.pjit import pjit, _UNSPECIFIED
+from jax._src.sharding_impls import UNSPECIFIED as _UNSPECIFIED
 from transformer_engine.jax.sharding import MeshResource
@@ -154,13 +154,15 @@ def compare_ops(
     grad_args = tuple(range(len(inputs)))
     target_grad_func = jax.value_and_grad(target_func, argnums=grad_args)
-    target_pjitter = pjit(target_grad_func, in_shardings=in_shardings, out_shardings=out_shardings)
-    target_fwd, target_grads = target_pjitter(*inputs, **kwargs)
-    target_hlo = target_pjitter.lower(*inputs, **kwargs).compile().as_text()
+    target_jitter = jax.jit(
+        target_grad_func, in_shardings=in_shardings, out_shardings=out_shardings
+    )
+    target_fwd, target_grads = target_jitter(*inputs, **kwargs)
+    target_hlo = target_jitter.lower(*inputs, **kwargs).compile().as_text()
     ref_grad_func = jax.value_and_grad(ref_func, argnums=grad_args)
-    ref_pjitter = pjit(ref_grad_func, in_shardings=in_shardings, out_shardings=out_shardings)
-    ref_fwd, ref_grads = ref_pjitter(*inputs, **kwargs)
+    ref_jitter = jax.jit(ref_grad_func, in_shardings=in_shardings, out_shardings=out_shardings)
+    ref_fwd, ref_grads = ref_jitter(*inputs, **kwargs)
     assert_allclose(target_fwd, ref_fwd, dtype=metric_fwd_dtype)
@@ -134,9 +134,12 @@ class TestDistributedLayernorm:
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
         with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
-            x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
-            gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
-            beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec))
+            x_named_sharding = NamedSharding(mesh, x_pspec)
+            g_named_sharding = NamedSharding(mesh, g_pspec)
+            b_named_sharding = NamedSharding(mesh, b_pspec)
+            x_ = jax.device_put(x, x_named_sharding)
+            gamma_ = jax.device_put(gamma, g_named_sharding)
+            beta_ = jax.device_put(beta, b_named_sharding)
             with warnings.catch_warnings(record=True) as warns:
                 try:
@@ -148,8 +151,11 @@ class TestDistributedLayernorm:
                         grad_args=(0, 1, 2),
                         metric_fwd_dtype=q_dtype,
                         metric_bwd_dtype=q_dtype,
-                        in_shardings=(x_pspec, g_pspec, b_pspec),
-                        out_shardings=(None, (x_pspec, g_pspec, b_pspec)),
+                        in_shardings=(x_named_sharding, g_named_sharding, b_named_sharding),
+                        out_shardings=(
+                            None,
+                            (x_named_sharding, g_named_sharding, b_named_sharding),
+                        ),
                     )
                 except AssertionError as err:
                     # Layernorm should still produce the correct numerical result with
@@ -210,8 +216,10 @@ class TestDistributedLayernorm:
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
         with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
-            x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
-            gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
+            x_named_sharding = NamedSharding(mesh, x_pspec)
+            g_named_sharding = NamedSharding(mesh, g_pspec)
+            x_ = jax.device_put(x, x_named_sharding)
+            gamma_ = jax.device_put(gamma, g_named_sharding)
             with warnings.catch_warnings(record=True) as warns:
                 try:
@@ -223,8 +231,8 @@ class TestDistributedLayernorm:
                         grad_args=(0, 1),
                         metric_fwd_dtype=q_dtype,
                         metric_bwd_dtype=q_dtype,
-                        in_shardings=(x_pspec, g_pspec),
-                        out_shardings=(None, (x_pspec, g_pspec)),
+                        in_shardings=(x_named_sharding, g_named_sharding),
+                        out_shardings=(None, (x_named_sharding, g_named_sharding)),
                     )
                 except AssertionError as err:
                     # RmsNorm should still produce the correct numerical result with
@@ -103,8 +103,10 @@ class TestDistributedSoftmax:
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
         with mesh, autocast(mesh_resource=mesh_resource):
-            x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
-            mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec))
+            x_named_sharding = NamedSharding(mesh, x_pspec)
+            mask_named_sharding = NamedSharding(mesh, mask_pspec)
+            x_ = jax.device_put(x, x_named_sharding)
+            mask_ = jax.device_put(mask, mask_named_sharding)
             with warnings.catch_warnings(record=True) as warns:
                 try:
@@ -116,8 +118,8 @@ class TestDistributedSoftmax:
                         grad_args=(0,),
                         metric_fwd_dtype=dtype,
                         metric_bwd_dtype=dtype,
-                        in_shardings=(x_pspec, mask_pspec),
-                        out_shardings=(None, (x_pspec,)),
+                        in_shardings=(x_named_sharding, mask_named_sharding),
+                        out_shardings=(None, x_named_sharding),
                     )
                 except AssertionError as err:
                     # Softmax should still produce the correct numerical result with
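
Taken together, the updated tests build each `NamedSharding` once and reuse it both for `jax.device_put` and for the `jax.jit` sharding arguments. A condensed sketch of that flow (the axis name "dp", shapes, and loss function are illustrative, not the tests' actual values):

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    mesh = Mesh(np.array(jax.devices()), ("dp",))
    x_named_sharding = NamedSharding(mesh, PartitionSpec("dp", None))

    x = jnp.ones((jax.device_count() * 2, 8))
    x_ = jax.device_put(x, x_named_sharding)      # shard the input across the mesh

    grad_func = jax.value_and_grad(lambda x: jnp.sum(x * x))
    jitter = jax.jit(
        grad_func,
        in_shardings=(x_named_sharding,),
        out_shardings=(None, x_named_sharding),   # scalar loss unconstrained, grad sharded like x
    )
    fwd, grads = jitter(x_)
    hlo = jitter.lower(x_).compile().as_text()    # inspect compiled HLO, as compare_ops() does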