Unverified Commit 8bf37f0e authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] Fix cb.CUDAOptions usage for Triton 3.6.0 (#2610)



* Fix cb.CUDAOptions usage for Triton 3.6.0
Signed-off-by: default avatarjberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update utils.py
Signed-off-by: default avatarjberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>

* Update utils.py
Signed-off-by: default avatarjberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>

* Update utils.py
Signed-off-by: default avatarjberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>

---------
Signed-off-by: default avatarjberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 605786f4
...@@ -36,6 +36,8 @@ import warnings ...@@ -36,6 +36,8 @@ import warnings
from typing import Any, Callable, Mapping from typing import Any, Callable, Mapping
import zlib import zlib
from packaging import version
from jax import core from jax import core
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
...@@ -274,13 +276,16 @@ def compile_triton( ...@@ -274,13 +276,16 @@ def compile_triton(
return _TRITON_KERNEL_CACHE[cache_key] return _TRITON_KERNEL_CACHE[cache_key]
# Compile kernel # Compile kernel
cuda_option_kwargs = {}
if version.parse(_TRITON_VERSION) < version.parse("3.6.0"):
cuda_option_kwargs["cluster_dims"] = (1, 1, 1)
options = cb.CUDAOptions( options = cb.CUDAOptions(
num_warps=num_warps, num_warps=num_warps,
num_stages=num_stages, num_stages=num_stages,
num_ctas=num_ctas, num_ctas=num_ctas,
cluster_dims=(1, 1, 1),
debug=False, debug=False,
enable_fp_fusion=enable_fp_fusion, enable_fp_fusion=enable_fp_fusion,
**cuda_option_kwargs,
) )
# Mark constants as constexpr in signature # Mark constants as constexpr in signature
...@@ -303,8 +308,6 @@ def compile_triton( ...@@ -303,8 +308,6 @@ def compile_triton(
# Create kernel object for JAX # Create kernel object for JAX
# From jax/jaxlib/gpu/triton_kernels.cc: # From jax/jaxlib/gpu/triton_kernels.cc:
from packaging import version
if version.parse(jax.__version__) >= version.parse("0.8.2"): if version.parse(jax.__version__) >= version.parse("0.8.2"):
kernel = gpu_triton.TritonKernel( kernel = gpu_triton.TritonKernel(
compiled.name, # arg0: kernel_name (str) compiled.name, # arg0: kernel_name (str)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment