"...targets/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "05f2ee1c5c2884f2991289db272c29ef5800062b"
Unverified Commit bc99a88d authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] Error checking for mesh resource and update GemmPrimitive to use...


[JAX] Error checking for mesh resource and update GemmPrimitive to use global_mesh_resource().fsdp_resource (#2088)

* Enforce global MeshResource is set
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Use global_mesh_resource().fsdp_resource in gemm primitive
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Update tests
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Update gemm.py
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Update test_layer.py
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
parent 51f19fdc
...@@ -219,7 +219,9 @@ def train_and_evaluate(args): ...@@ -219,7 +219,9 @@ def train_and_evaluate(args):
else: else:
fp8_recipe = None fp8_recipe = None
with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe): with te.fp8_autocast(
enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
):
encoder = Net(num_embed) encoder = Net(num_embed)
# We use nn.Embed, thus inputs need to be in int # We use nn.Embed, thus inputs need to be in int
inputs = jnp.zeros(input_shape, dtype=jnp.int32) inputs = jnp.zeros(input_shape, dtype=jnp.int32)
......
...@@ -193,7 +193,9 @@ def train_and_evaluate(args): ...@@ -193,7 +193,9 @@ def train_and_evaluate(args):
else: else:
fp8_recipe = None fp8_recipe = None
with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe): with te.fp8_autocast(
enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
):
cnn = Net(args.use_te) cnn = Net(args.use_te)
var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16)) var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16))
tx = optax.sgd(args.lr, args.momentum) tx = optax.sgd(args.lr, args.momentum)
......
...@@ -173,7 +173,7 @@ class TestDistributedLayernormMLP: ...@@ -173,7 +173,7 @@ class TestDistributedLayernormMLP:
) )
# Single GPU # Single GPU
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
single_jitter = jax.jit( single_jitter = jax.jit(
value_and_grad_func, value_and_grad_func,
static_argnums=range(len(inputs), len(static_inputs) + len(inputs)), static_argnums=range(len(inputs), len(static_inputs) + len(inputs)),
...@@ -330,7 +330,7 @@ class TestDistributedLayernormMLP: ...@@ -330,7 +330,7 @@ class TestDistributedLayernormMLP:
with use_jax_gemm(enabled=with_jax_gemm): with use_jax_gemm(enabled=with_jax_gemm):
# Single GPUs # Single GPUs
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
ln_mlp_single = LayerNormMLP( ln_mlp_single = LayerNormMLP(
layernorm_type=layernorm_type, layernorm_type=layernorm_type,
intermediate_dim=INTERMEDIATE, intermediate_dim=INTERMEDIATE,
......
...@@ -28,6 +28,7 @@ from transformer_engine.jax.quantize import ( ...@@ -28,6 +28,7 @@ from transformer_engine.jax.quantize import (
is_fp8_available, is_fp8_available,
update_collections, update_collections,
) )
from transformer_engine.jax.sharding import MeshResource, global_shard_guard
@pytest.fixture(autouse=True, scope="function") @pytest.fixture(autouse=True, scope="function")
...@@ -490,19 +491,28 @@ class BaseTester: ...@@ -490,19 +491,28 @@ class BaseTester:
def test_forward(self, data_shape, dtype, attrs): def test_forward(self, data_shape, dtype, attrs):
"""Test normal datatype forward""" """Test normal datatype forward"""
QuantizeConfig.finalize() # Ensure FP8 disabled. QuantizeConfig.finalize() # Ensure FP8 disabled.
self.runner(attrs).test_forward(data_shape, dtype) with global_shard_guard(
MeshResource()
): # Empty MeshResource is used as we are running on a single device
self.runner(attrs).test_forward(data_shape, dtype)
def test_backward(self, data_shape, dtype, attrs): def test_backward(self, data_shape, dtype, attrs):
"""Test normal datatype backward""" """Test normal datatype backward"""
QuantizeConfig.finalize() # Ensure FP8 disabled. QuantizeConfig.finalize() # Ensure FP8 disabled.
self.runner(attrs).test_backward(data_shape, dtype) with global_shard_guard(
MeshResource()
): # Empty MeshResource is used as we are running on a single device
self.runner(attrs).test_backward(data_shape, dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
"""Test forward with fp8 enabled""" """Test forward with fp8 enabled"""
QuantizeConfig.initialize(fp8_recipe=fp8_recipe) QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3) with global_shard_guard(
MeshResource()
): # Empty MeshResource is used as we are running on a single device
self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
QuantizeConfig.finalize() QuantizeConfig.finalize()
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
...@@ -510,7 +520,10 @@ class BaseTester: ...@@ -510,7 +520,10 @@ class BaseTester:
def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
"""Test backward with fp8 enabled""" """Test backward with fp8 enabled"""
QuantizeConfig.initialize(fp8_recipe=fp8_recipe) QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3) with global_shard_guard(
MeshResource()
): # Empty MeshResource is used as we are running on a single device
self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
QuantizeConfig.finalize() QuantizeConfig.finalize()
......
...@@ -34,6 +34,7 @@ from ..quantize import ( ...@@ -34,6 +34,7 @@ from ..quantize import (
is_fp8_gemm_with_all_layouts_supported, is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv, apply_padding_to_scale_inv,
) )
from ..sharding import global_mesh_resource
from .misc import get_padded_spec from .misc import get_padded_spec
...@@ -490,7 +491,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -490,7 +491,8 @@ class GemmPrimitive(BasePrimitive):
# Non-contracting dims of RHS always needs to be gathered along the FSDP axis # Non-contracting dims of RHS always needs to be gathered along the FSDP axis
rhs_non_cspecs = tuple( rhs_non_cspecs = tuple(
None if spec is not None and "fsdp" in spec else spec for spec in rhs_non_cspecs None if spec is not None and spec == global_mesh_resource().fsdp_resource else spec
for spec in rhs_non_cspecs
) )
# Non-contracting dims of LHS to be gathered along the SP axis. # Non-contracting dims of LHS to be gathered along the SP axis.
......
...@@ -404,9 +404,6 @@ def fp8_autocast( ...@@ -404,9 +404,6 @@ def fp8_autocast(
if fp8_recipe is None: if fp8_recipe is None:
fp8_recipe = recipe.DelayedScaling() fp8_recipe = recipe.DelayedScaling()
if mesh_resource is None:
mesh_resource = MeshResource()
Config = DelayedScalingQuantizeConfig Config = DelayedScalingQuantizeConfig
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
Config = BlockScalingQuantizeConfig Config = BlockScalingQuantizeConfig
......
...@@ -286,7 +286,7 @@ class MeshResource: ...@@ -286,7 +286,7 @@ class MeshResource:
cp_resource: str = None cp_resource: str = None
_GLOBAL_MESH_RESOURCE = MeshResource() _GLOBAL_MESH_RESOURCE = None
@contextmanager @contextmanager
...@@ -314,6 +314,11 @@ def global_mesh_resource() -> MeshResource: ...@@ -314,6 +314,11 @@ def global_mesh_resource() -> MeshResource:
Returns: Returns:
The current MeshResource instance The current MeshResource instance
""" """
assert _GLOBAL_MESH_RESOURCE is not None, (
"Global mesh resource is not set. Please set the MeshResource via a global_shard_guard"
" context. If you are not using multiple GPUs, you can use an empty MeshResource by"
" wrapping your program in 'with global_shard_guard(MeshResource()):'"
)
return _GLOBAL_MESH_RESOURCE return _GLOBAL_MESH_RESOURCE
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment