Unverified Commit 2e23ad71 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Add Shardy warning in GEMM custom call (#2101)



* added shardy warning
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>


---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 78e097f1
...@@ -8,6 +8,7 @@ import operator ...@@ -8,6 +8,7 @@ import operator
from collections.abc import Iterable from collections.abc import Iterable
from typing import Tuple, Sequence, Union from typing import Tuple, Sequence, Union
from functools import partial, reduce from functools import partial, reduce
import warnings
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
...@@ -658,6 +659,12 @@ class GemmPrimitive(BasePrimitive): ...@@ -658,6 +659,12 @@ class GemmPrimitive(BasePrimitive):
prefix = "GemmPrimitive_" prefix = "GemmPrimitive_"
warnings.warn(
"Known issues with TE GemmPrimitives when Shardy propagation is enabled. For now,"
" please turn off Shardy by exporting the environment variable"
" 'JAX_USE_SHARDY_PARTITIONER=0' if you experience any problems."
)
def _generate_operand_rules(name, ndim, cdims): def _generate_operand_rules(name, ndim, cdims):
specs = [] specs = []
ldims = tuple(i for i in range(ndim) if i not in cdims) ldims = tuple(i for i in range(ndim) if i not in cdims)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment