Commit c638ac7e authored by Phuong Nguyen's avatar Phuong Nguyen Committed by Kshitij Janardan Lakhani
Browse files

[JAX] Add Shardy warning in GEMM custom call (#2101)



* added shardy warning
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>


---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent e94041a5
...@@ -8,6 +8,7 @@ import operator ...@@ -8,6 +8,7 @@ import operator
from collections.abc import Iterable from collections.abc import Iterable
from typing import Tuple, Sequence, Union from typing import Tuple, Sequence, Union
from functools import partial, reduce from functools import partial, reduce
import warnings
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
...@@ -658,6 +659,12 @@ class GemmPrimitive(BasePrimitive): ...@@ -658,6 +659,12 @@ class GemmPrimitive(BasePrimitive):
prefix = "GemmPrimitive_" prefix = "GemmPrimitive_"
warnings.warn(
"Known issues with TE GemmPrimitives when Shardy propagation is enabled. For now,"
" please turn off Shardy by exporting the environment variable"
" 'JAX_USE_SHARDY_PARTITIONER=0' if you experience any problems."
)
def _generate_operand_rules(name, ndim, cdims): def _generate_operand_rules(name, ndim, cdims):
specs = [] specs = []
ldims = tuple(i for i in range(ndim) if i not in cdims) ldims = tuple(i for i in range(ndim) if i not in cdims)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment