Unverified Commit 2e23ad71 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Add Shardy warning in GEMM custom call (#2101)



* added shardy warning
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>


---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 78e097f1
......@@ -8,6 +8,7 @@ import operator
from collections.abc import Iterable
from typing import Tuple, Sequence, Union
from functools import partial, reduce
import warnings
import jax
import jax.numpy as jnp
......@@ -658,6 +659,12 @@ class GemmPrimitive(BasePrimitive):
prefix = "GemmPrimitive_"
warnings.warn(
"Known issues with TE GemmPrimitives when Shardy propagation is enabled. For now,"
" please turn off Shardy by exporting the environment variable"
" 'JAX_USE_SHARDY_PARTITIONER=0' if you experience any problems."
)
def _generate_operand_rules(name, ndim, cdims):
specs = []
ldims = tuple(i for i in range(ndim) if i not in cdims)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment