Unverified Commit dd7ab715 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Fix error with triton 3.5 (#2286)



* Update permutation.py
Signed-off-by: default avatarfzyzcjy <5236035+fzyzcjy@users.noreply.github.com>

* Update permutation.py
Signed-off-by: default avatarfzyzcjy <5236035+fzyzcjy@users.noreply.github.com>

* Update transformer_engine/pytorch/triton/permutation.py
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update transformer_engine/pytorch/triton/permutation.py
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarfzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent fd234d80
......@@ -12,11 +12,16 @@ import triton.language as tl
from triton.language import core
from triton.language.standard import _log2
from packaging import version
# The following three argsort related kernels are adapted from
# the issue https://github.com/triton-lang/triton/issues/3698
get_int_dtype = core.get_int_dtype
if version.parse(triton.__version__) >= version.parse("3.5.0"):
get_int_dtype = triton.constexpr_function(get_int_dtype)
@triton.jit
def _compare_and_swap(x, indices, flip, i: tl.constexpr, n_dims: tl.constexpr):
......@@ -37,7 +42,7 @@ def _compare_and_swap(x, indices, flip, i: tl.constexpr, n_dims: tl.constexpr):
l_indice = tl.reshape(tl.broadcast_to(tl.sum(z * (1 - mask), 1)[:, None, :], shape), x.shape)
r_indice = tl.reshape(tl.broadcast_to(tl.sum(z * mask, 1)[:, None, :], shape), x.shape)
idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
idtype = get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
il_value = l_value.to(idtype, bitcast=True)
ir_value = r_value.to(idtype, bitcast=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment