Unverified commit bc99a88d authored by jberchtold-nvidia, committed by GitHub

[JAX] Error checking for mesh resource and update GemmPrimitive to use global_mesh_resource().fsdp_resource (#2088)

* Enforce global MeshResource is set
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Use global_mesh_resource().fsdp_resource in gemm primitive
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Update tests
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Update gemm.py
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Update test_layer.py
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent 51f19fdc
@@ -219,7 +219,9 @@ def train_and_evaluate(args):
     else:
         fp8_recipe = None

-    with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe):
+    with te.fp8_autocast(
+        enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
+    ):
         encoder = Net(num_embed)
         # We use nn.Embed, thus inputs need to be in int
         inputs = jnp.zeros(input_shape, dtype=jnp.int32)
@@ -193,7 +193,9 @@ def train_and_evaluate(args):
     else:
         fp8_recipe = None

-    with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe):
+    with te.fp8_autocast(
+        enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
+    ):
         cnn = Net(args.use_te)
         var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16))
         tx = optax.sgd(args.lr, args.momentum)
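Both example updates above run on a single GPU, so they pass an empty MeshResource explicitly. On multiple devices the same argument is where mesh axis names get registered. The following is only a minimal sketch of that multi-device usage; the 8-device grid, the 'dp'/'tp' axis names, and the field choices are illustrative assumptions, not part of this change:

```python
import numpy as np
import jax
from jax.sharding import Mesh
import transformer_engine.jax as te

# Assumed 2x4 device grid with data- and tensor-parallel axes; names are arbitrary.
devices = np.asarray(jax.devices()).reshape(2, 4)
with Mesh(devices, axis_names=("dp", "tp")):
    with te.fp8_autocast(
        enabled=True,
        mesh_resource=te.sharding.MeshResource(dp_resource="dp", tp_resource="tp"),
    ):
        pass  # build and train the model under the mesh here
```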
@@ -173,7 +173,7 @@ class TestDistributedLayernormMLP:
         )

         # Single GPU
-        with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+        with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
             single_jitter = jax.jit(
                 value_and_grad_func,
                 static_argnums=range(len(inputs), len(static_inputs) + len(inputs)),
@@ -330,7 +330,7 @@ class TestDistributedLayernormMLP:
         with use_jax_gemm(enabled=with_jax_gemm):
             # Single GPUs
-            with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+            with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
                 ln_mlp_single = LayerNormMLP(
                     layernorm_type=layernorm_type,
                     intermediate_dim=INTERMEDIATE,
@@ -28,6 +28,7 @@ from transformer_engine.jax.quantize import (
     is_fp8_available,
     update_collections,
 )
+from transformer_engine.jax.sharding import MeshResource, global_shard_guard


 @pytest.fixture(autouse=True, scope="function")
@@ -490,19 +491,28 @@ class BaseTester:
     def test_forward(self, data_shape, dtype, attrs):
         """Test normal datatype forward"""
         QuantizeConfig.finalize()  # Ensure FP8 disabled.
-        self.runner(attrs).test_forward(data_shape, dtype)
+        with global_shard_guard(
+            MeshResource()
+        ):  # Empty MeshResource is used as we are running on a single device
+            self.runner(attrs).test_forward(data_shape, dtype)

     def test_backward(self, data_shape, dtype, attrs):
         """Test normal datatype backward"""
         QuantizeConfig.finalize()  # Ensure FP8 disabled.
-        self.runner(attrs).test_backward(data_shape, dtype)
+        with global_shard_guard(
+            MeshResource()
+        ):  # Empty MeshResource is used as we are running on a single device
+            self.runner(attrs).test_backward(data_shape, dtype)

     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
     @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
     def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
         """Test forward with fp8 enabled"""
         QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
-        self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
+        with global_shard_guard(
+            MeshResource()
+        ):  # Empty MeshResource is used as we are running on a single device
+            self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
         QuantizeConfig.finalize()

     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@@ -510,7 +520,10 @@
     def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
         """Test backward with fp8 enabled"""
         QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
-        self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
+        with global_shard_guard(
+            MeshResource()
+        ):  # Empty MeshResource is used as we are running on a single device
+            self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
         QuantizeConfig.finalize()
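The same empty-MeshResource guard is repeated at each test call site above. An equivalent way to express it would be a single autouse pytest fixture; this is only a sketch of that alternative (the fixture name is hypothetical, and it is not what the diff does):

```python
import pytest
from transformer_engine.jax.sharding import MeshResource, global_shard_guard

@pytest.fixture(autouse=True)
def _single_device_mesh_resource():
    # Empty MeshResource: no mesh axes are named, matching single-device runs.
    with global_shard_guard(MeshResource()):
        yield
```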
@@ -34,6 +34,7 @@ from ..quantize import (
     is_fp8_gemm_with_all_layouts_supported,
     apply_padding_to_scale_inv,
 )
+from ..sharding import global_mesh_resource
 from .misc import get_padded_spec
@@ -490,7 +491,8 @@ class GemmPrimitive(BasePrimitive):
         # Non-contracting dims of RHS always needs to be gathered along the FSDP axis
         rhs_non_cspecs = tuple(
-            None if spec is not None and "fsdp" in spec else spec for spec in rhs_non_cspecs
+            None if spec is not None and spec == global_mesh_resource().fsdp_resource else spec
+            for spec in rhs_non_cspecs
         )
         # Non-contracting dims of LHS to be gathered along the SP axis.
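The replaced line changes the matching rule: previously any partition-spec entry whose axis name merely contained the substring "fsdp" was nulled out (gathered); now only the axis actually registered as fsdp_resource in the global MeshResource is. A self-contained sketch of the two rules (the helper names are hypothetical, written only to contrast the behavior):

```python
def drop_fsdp_old(specs):
    # Old rule: substring match, so any axis name containing "fsdp" is gathered.
    return tuple(None if s is not None and "fsdp" in s else s for s in specs)

def drop_fsdp_new(specs, fsdp_resource):
    # New rule: exact match against the axis registered in the global MeshResource.
    return tuple(None if s is not None and s == fsdp_resource else s for s in specs)

print(drop_fsdp_old(("data_fsdp", "tp")))         # -> (None, 'tp')
print(drop_fsdp_new(("data_fsdp", "tp"), "zp"))   # -> ('data_fsdp', 'tp')
print(drop_fsdp_new(("zp", "tp"), "zp"))          # -> (None, 'tp'): a non-"fsdp"-named FSDP axis is now handled
```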
@@ -404,9 +404,6 @@
     if fp8_recipe is None:
         fp8_recipe = recipe.DelayedScaling()

-    if mesh_resource is None:
-        mesh_resource = MeshResource()
-
     Config = DelayedScalingQuantizeConfig
     if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
         Config = BlockScalingQuantizeConfig
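With this fallback removed, fp8_autocast no longer installs an empty MeshResource on the caller's behalf, which is why the examples and tests above now pass one explicitly. A minimal single-device call site after this change; whether a None argument falls back to a globally guarded resource depends on code outside this hunk:

```python
import transformer_engine.jax as te
from transformer_engine.jax.sharding import MeshResource

# Single-device usage: the empty MeshResource is now explicit rather than implied.
with te.fp8_autocast(enabled=True, mesh_resource=MeshResource()):
    pass  # model construction and training step go here
```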
@@ -286,7 +286,7 @@ class MeshResource:
     cp_resource: str = None


-_GLOBAL_MESH_RESOURCE = MeshResource()
+_GLOBAL_MESH_RESOURCE = None


 @contextmanager
@@ -314,6 +314,11 @@ def global_mesh_resource() -> MeshResource:
     Returns:
         The current MeshResource instance
     """
+    assert _GLOBAL_MESH_RESOURCE is not None, (
+        "Global mesh resource is not set. Please set the MeshResource via a global_shard_guard"
+        " context. If you are not using multiple GPUs, you can use an empty MeshResource by"
+        " wrapping your program in 'with global_shard_guard(MeshResource()):'"
+    )
     return _GLOBAL_MESH_RESOURCE