Unverified Commit 8bf37f0e authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] Fix cb.CUDAOptions usage for Triton 3.6.0 (#2610)



* Fix cb.CUDAOptions usage for Triton 3.6.0
Signed-off-by: default avatarjberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update utils.py
Signed-off-by: default avatarjberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>

* Update utils.py
Signed-off-by: default avatarjberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>

* Update utils.py
Signed-off-by: default avatarjberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>

---------
Signed-off-by: default avatarjberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 605786f4
......@@ -36,6 +36,8 @@ import warnings
from typing import Any, Callable, Mapping
import zlib
from packaging import version
from jax import core
import jax
import jax.numpy as jnp
......@@ -274,13 +276,16 @@ def compile_triton(
return _TRITON_KERNEL_CACHE[cache_key]
# Compile kernel
cuda_option_kwargs = {}
if version.parse(_TRITON_VERSION) < version.parse("3.6.0"):
cuda_option_kwargs["cluster_dims"] = (1, 1, 1)
options = cb.CUDAOptions(
num_warps=num_warps,
num_stages=num_stages,
num_ctas=num_ctas,
cluster_dims=(1, 1, 1),
debug=False,
enable_fp_fusion=enable_fp_fusion,
**cuda_option_kwargs,
)
# Mark constants as constexpr in signature
......@@ -303,8 +308,6 @@ def compile_triton(
# Create kernel object for JAX
# From jax/jaxlib/gpu/triton_kernels.cc:
from packaging import version
if version.parse(jax.__version__) >= version.parse("0.8.2"):
kernel = gpu_triton.TritonKernel(
compiled.name, # arg0: kernel_name (str)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment