# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Triton utilities for JAX primitives.

This module provides utility functions for integrating Triton kernels into
JAX primitives. Triton is only imported when this module is used.

Triton Package Compatibility:
    There are two Triton packages that can be used:

    1. 'triton' (from OpenAI/PyPI): Standard package, works with JAX out of the box.
       Install with: pip install triton

    2. 'pytorch-triton' (from PyTorch's index): Bundled with PyTorch, includes
       PyTorch-specific patches. Version format: "3.0.0+<commit_sha>"

       IMPORTANT: The 'pytorch-triton' package on PyPI (version 0.0.1) is a
       placeholder that will NOT work. The real pytorch-triton is only available
       from PyTorch's package index and is auto-installed with PyTorch:
           pip install torch --index-url https://download.pytorch.org/whl/cu121

       pytorch-triton has been tested to work with JAX Triton kernels.

Environment Variables:
    NVTE_USE_PYTORCH_TRITON: If set to "1", explicitly acknowledges the use of
        pytorch-triton for JAX Triton kernels and suppresses the associated
        warning. This is useful when both JAX and PyTorch are installed in the
        same environment. Default is "0".
"""

import hashlib
import os
import warnings
import zlib
from typing import Any, Callable, Mapping

from packaging import version

import jax
import jax.numpy as jnp
from jax import core


# Placeholder package version on PyPI that should never be used
_PYTORCH_TRITON_PLACEHOLDER_VERSION = "0.0.1"


def _detect_triton_package():
    """Detect which Triton package is installed and validate compatibility.

    Returns:
        tuple: (triton_version: str or None, is_pytorch_triton: bool, is_placeholder: bool)

    The function detects:
    - None: Triton not installed
    - Standard triton from OpenAI (versions like "3.1.0")
    - Real pytorch-triton from PyTorch's index (versions like "3.0.0+45fff310c8")
    - Placeholder pytorch-triton from PyPI (version "0.0.1" - broken, raises RuntimeError)
    """
    try:
        import triton

        triton_version = getattr(triton, "__version__", "unknown")
    except ImportError:
        return None, False, False
    except RuntimeError as e:
        # The placeholder pytorch-triton package from PyPI raises:
        # RuntimeError: "Should never be installed"
        if "Should never be installed" in str(e):
            return _PYTORCH_TRITON_PLACEHOLDER_VERSION, False, True
        raise

    # Check for placeholder package (version 0.0.1 from PyPI)
    is_placeholder = triton_version == _PYTORCH_TRITON_PLACEHOLDER_VERSION

    # Real pytorch-triton versions have a commit SHA suffix like "3.0.0+45fff310c8"
    is_pytorch_triton = "+" in triton_version and len(triton_version.split("+")[-1]) >= 8
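    # e.g. "3.0.0+45fff310c8" -> pytorch-triton; "3.1.0" -> standard triton;
    # a short local tag such as "+cu121" does not match the commit-SHA heuristic
    # and is treated as the standard package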

    return triton_version, is_pytorch_triton, is_placeholder


def _check_triton_compatibility():
    """Check Triton package compatibility and emit warnings if necessary.

    This function handles the case where both JAX and PyTorch may be installed,
    each expecting different Triton packages:
    - JAX typically uses the standard 'triton' package from OpenAI
    - PyTorch uses 'pytorch-triton' which is versioned with commit SHAs

    The NVTE_USE_PYTORCH_TRITON environment variable can be used to explicitly
    acknowledge using pytorch-triton with JAX (suppresses the warning).

    Returns:
        tuple: (triton_version: str, is_pytorch_triton: bool)

    Raises:
        ImportError: If Triton is not installed or the placeholder package is detected.
    """
    triton_version, is_pytorch_triton, is_placeholder = _detect_triton_package()

    # Handle placeholder package from PyPI
    if is_placeholder:
        raise ImportError(
            "Detected the placeholder 'pytorch-triton' package (version 0.0.1) from PyPI.\n"
            "This is NOT a functional Triton installation.\n\n"
            "The placeholder package exists to prevent namespace conflicts. To fix this:\n\n"
            "Option 1 - Use standard Triton (recommended for JAX-only environments):\n"
            "    pip uninstall pytorch-triton triton\n"
            "    pip install triton\n\n"
            "Option 2 - Use real pytorch-triton (for mixed JAX+PyTorch environments):\n"
            "    pip uninstall pytorch-triton triton\n"
            "    pip install torch --index-url https://download.pytorch.org/whl/cu121\n"
            "    # pytorch-triton is automatically installed as a torch dependency\n\n"
            "Note: Do NOT run 'pip install pytorch-triton' directly - this installs\n"
            "the broken placeholder. The real pytorch-triton only comes from PyTorch's index."
        )

    if triton_version is None:
        raise ImportError(
            "Triton is required for transformer_engine.jax.triton_extensions.\n\n"
            "Option 1 - Install standard Triton (recommended for JAX-only):\n"
            "    pip install triton\n\n"
            "Option 2 - Install PyTorch with pytorch-triton (for mixed environments):\n"
            "    pip install torch --index-url https://download.pytorch.org/whl/cu121\n\n"
            "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead."
        )

    use_pytorch_triton_explicit = bool(int(os.environ.get("NVTE_USE_PYTORCH_TRITON", "0")))

    # Setting NVTE_USE_PYTORCH_TRITON=1 silently acknowledges pytorch-triton.
    if is_pytorch_triton and not use_pytorch_triton_explicit:
        # pytorch-triton detected but the user didn't explicitly opt in
        warnings.warn(
            f"Detected pytorch-triton package (version {triton_version}) instead of the"
            " standard 'triton' package from OpenAI. This typically happens when PyTorch is"
            " installed alongside JAX.\n\npytorch-triton is compatible with JAX Triton"
            " kernels. To suppress this warning, set:\n    export"
            " NVTE_USE_PYTORCH_TRITON=1\n\nAlternatively, for a JAX-only environment:\n  - Use"
            " separate virtual environments for JAX and PyTorch, or\n  - Use"
            " transformer_engine.jax.cpp_extensions instead (CUDA-based, no Triton needed)",
            category=UserWarning,
            stacklevel=3,
        )

    return triton_version, is_pytorch_triton


# Perform compatibility check and get triton info
_TRITON_VERSION, _IS_PYTORCH_TRITON = _check_triton_compatibility()

try:
    from jax._src.lib import gpu_triton
    from triton.compiler import compiler as tc
    from triton.backends.nvidia import compiler as cb
    from triton.runtime import autotuner
except ImportError as e:
    raise ImportError(
        "Triton is required for transformer_engine.jax.triton_extensions. "
        "Install with: pip install triton\n"
        "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead."
    ) from e


__all__ = ["triton_call_lowering", "get_triton_info"]

# Triton kernel cache (module-level, shared across all kernels)
_TRITON_KERNEL_CACHE = {}


172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
def get_triton_info():
    """Get information about the installed Triton package.

    Returns:
        dict: Dictionary containing:
            - version (str): Triton version string (e.g., "3.1.0" or "3.0.0+45fff310c8")
            - is_pytorch_triton (bool): True if using real pytorch-triton from PyTorch's index
            - is_openai_triton (bool): True if using standard triton from OpenAI/PyPI
            - env_acknowledged (bool): True if NVTE_USE_PYTORCH_TRITON=1 is set
            - source (str): "pytorch" or "openai" indicating the package source

    Example:
        from transformer_engine.jax.triton_extensions import get_triton_info
        info = get_triton_info()
        print(f"Triton version: {info['version']} (from {info['source']})")
        if info['is_pytorch_triton']:
            print("Using pytorch-triton - compatible with both PyTorch and JAX")
    """
    env_acknowledged = bool(int(os.environ.get("NVTE_USE_PYTORCH_TRITON", "0")))

    return {
        "version": _TRITON_VERSION,
        "is_pytorch_triton": _IS_PYTORCH_TRITON,
        "is_openai_triton": not _IS_PYTORCH_TRITON,
        "env_acknowledged": env_acknowledged and _IS_PYTORCH_TRITON,
        "source": "pytorch" if _IS_PYTORCH_TRITON else "openai",
    }


def get_triton_dtype(aval):
    """Convert a JAX abstract value to a Triton pointer-type string.

    Only tensor (pointer) arguments go through this function; scalar arguments
    are passed to the kernel as compile-time constants instead (see
    triton_call_lowering).

    Args:
        aval: JAX ShapedArray

    Returns:
        Triton pointer type string (e.g., "*fp32" for a float32 array)
    """
    dtype_map = {
        jnp.dtype("bfloat16"): "bf16",
        jnp.dtype("float64"): "fp64",
        jnp.dtype("float32"): "fp32",
        jnp.dtype("float16"): "fp16",
        jnp.dtype("float8_e4m3fn"): "fp8e4nv",
        jnp.dtype("float8_e5m2"): "fp8e5",
        jnp.dtype("int64"): "i64",
        jnp.dtype("int32"): "i32",
        jnp.dtype("int16"): "i16",
        jnp.dtype("int8"): "i8",
        jnp.dtype("uint64"): "u64",
        jnp.dtype("uint32"): "u32",
        jnp.dtype("uint16"): "u16",
        jnp.dtype("uint8"): "u8",
        jnp.dtype("bool"): "i1",
    }

    assert isinstance(aval, core.ShapedArray), "aval must be a JAX ShapedArray"
    return f"*{dtype_map[aval.dtype]}"


def compile_triton(
    kernel_fn: Callable,
    signature: Mapping[str, str],
    constants: Mapping[str, Any],
    num_warps: int,
    num_stages: int,
    num_ctas: int,
    compute_capability: int,
    enable_fp_fusion: bool = False,
):
    """Compile a Triton kernel to PTX.

    Kernels are cached to avoid recompilation.

    Args:
        kernel_fn: Triton kernel function (decorated with @triton.jit)
        signature: Dict mapping arg names to types (e.g., {"x_ptr": "*fp32", "n": "i32"})
        constants: Dict of compile-time constants
        num_warps: Number of warps per block
        num_stages: Number of pipeline stages
        num_ctas: Number of CTAs (cooperative thread arrays)
        compute_capability: CUDA compute capability
        enable_fp_fusion: Enable FP fusion optimizations (default False for accuracy)

    Returns:
        TritonKernel object for JAX
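
    Example (a sketch; add_one_kernel stands for any @triton.jit kernel whose
    pointer args are x_ptr/out_ptr and whose scalars are passed via constants):
        kernel = compile_triton(
            add_one_kernel,
            signature={"x_ptr": "*fp32", "out_ptr": "*fp32"},
            constants={"n_elements": 1024, "BLOCK_SIZE": 128},
            num_warps=4,
            num_stages=1,
            num_ctas=1,
            compute_capability=90,
        )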
    """
    # Create cache key
    cache_key = hashlib.md5(
        str(
            (
                kernel_fn.__name__,
                tuple(sorted(signature.items())),
                tuple(sorted(constants.items())),
                num_warps,
                num_stages,
                num_ctas,
                enable_fp_fusion,
                compute_capability,
            )
        ).encode()
    ).hexdigest()

    if cache_key in _TRITON_KERNEL_CACHE:
        return _TRITON_KERNEL_CACHE[cache_key]

    # Compile kernel
    cuda_option_kwargs = {}
    if version.parse(_TRITON_VERSION) < version.parse("3.6.0"):
        cuda_option_kwargs["cluster_dims"] = (1, 1, 1)
    options = cb.CUDAOptions(
        num_warps=num_warps,
        num_stages=num_stages,
        num_ctas=num_ctas,
        debug=False,
        enable_fp_fusion=enable_fp_fusion,
        **cuda_option_kwargs,
    )

    # Mark constants as constexpr in signature
    signature_with_constexpr = dict(signature)
    for const_name in constants.keys():
        if const_name in signature_with_constexpr:
            signature_with_constexpr[const_name] = "constexpr"

    src = tc.ASTSource(
        fn=kernel_fn,
        constexprs=constants,
        signature=signature_with_constexpr,
    )

    compiled = tc.compile(
        src,
        target=tc.GPUTarget("cuda", compute_capability, 32),
        options=options.__dict__,
    )

    # Create kernel object for JAX
    # From jax/jaxlib/gpu/triton_kernels.cc:
    if version.parse(jax.__version__) >= version.parse("0.8.2"):
        kernel = gpu_triton.TritonKernel(
            compiled.name,  # arg0: kernel_name (str)
            num_warps,  # arg1: num_warps (int)
            num_ctas,  # arg2: num_ctas (int)
            compiled.metadata.shared,  # arg3: shared_mem_bytes (int)
            compiled.asm["ptx"],  # arg4: ptx (str)
            "",  # arg5: ttir (str) - empty
            compute_capability,  # arg6: compute_capability (int)
        )
    else:
        kernel = gpu_triton.TritonKernel(
            compiled.name,
            num_warps,
            compiled.metadata.shared,
            compiled.asm["ptx"],
            "",  # ttir - empty
            compute_capability,
            1,  # cluster dim x
            1,  # cluster dim y
            1,  # cluster dim z
        )

    _TRITON_KERNEL_CACHE[cache_key] = kernel
    return kernel


def triton_call_lowering(
    ctx,
    kernel_fn: Callable,
    *array_args,
    grid,
    input_output_aliases: Mapping[int, int] = None,
    constexprs: Mapping[str, Any] = None,
):
    """Helper for MLIR lowering that calls a Triton kernel.

    Use this in your primitive's lowering method to call Triton kernels.

    Args:
        ctx: MLIR lowering context
        kernel_fn: Triton kernel function
        *array_args: Input arrays (from ctx)
        grid: Grid dimensions (int or tuple)
        input_output_aliases: Mapping from input operand index to the output
                              index it aliases (forwarded to jax.ffi.ffi_lowering
                              as operand_output_aliases)
        constexprs: Compile-time constants for the kernel. This includes both
                    tl.constexpr arguments AND scalar runtime arguments (like
                    num_tokens, strides) that are known at JAX trace time.

    Returns:
        MLIR lowering result

    Example:
        @staticmethod
        def lowering(ctx, x, *, block_size):
            from ..triton_extensions import triton_call_lowering
            n = ctx.avals_in[0].size
            return triton_call_lowering(
                ctx, my_kernel, x,
                grid=(triton.cdiv(n, block_size),),
                constexprs={
                    "n_elements": n,  # scalar arg (not tl.constexpr in kernel)
                    "BLOCK_SIZE": block_size,  # tl.constexpr arg
                },
            )
    """
    # Get compute capability using gpu_triton
    compute_capability = gpu_triton.get_compute_capability(0)  # device 0

    # Build signature dict: map arg names to types
    # Get arg names from kernel function
    if isinstance(kernel_fn, autotuner.Autotuner):
        arg_names = kernel_fn.fn.arg_names
    else:
        arg_names = kernel_fn.arg_names

    # Build signature for tensor arguments only (inputs + outputs)
    # Scalar arguments should be passed via constexprs and will be
    # specialized into the kernel at compile time
    all_avals = list(ctx.avals_in) + list(ctx.avals_out)
    constexpr_names = set(constexprs.keys()) if constexprs else set()
    tensor_arg_names = [n for n in arg_names if n not in constexpr_names]
    signature = {n: get_triton_dtype(a) for n, a in zip(tensor_arg_names, all_avals)}
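    # e.g. for def kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr)
    # with constexprs={"n_elements": n, "BLOCK_SIZE": 128}, only x_ptr and
    # out_ptr remain tensor args, so signature == {"x_ptr": "*fp32", "out_ptr": "*fp32"}
    # (types taken from ctx.avals_in followed by ctx.avals_out, in order)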

    # Normalize grid to 3D
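    # e.g. grid=256 -> (256, 1, 1); grid=(64, 8) -> (64, 8, 1)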
    if isinstance(grid, int):
        grid_tuple = (grid, 1, 1)
    elif len(grid) == 1:
        grid_tuple = (grid[0], 1, 1)
    elif len(grid) == 2:
        grid_tuple = (grid[0], grid[1], 1)
    else:
        grid_tuple = grid[:3]

    # Default values for the kernel
    # TODO(Phuong): consider if it is beneficial to expose num_warps, num_stages, num_ctas
    actual_kernel_fn = kernel_fn
    num_warps = 32
    num_stages = 1
    num_ctas = 1
    kernel_constexprs = constexprs if constexprs is not None else {}

    # Handle autotuned kernels - compile all configs
    is_autotuned = isinstance(kernel_fn, autotuner.Autotuner)
    if is_autotuned:
        # Compile all configs for runtime selection
        kernel_calls = []
        actual_kernel_fn = kernel_fn.fn

        for config in kernel_fn.configs:
            # Extract parameters from config
            config_num_warps = config.num_warps if config.num_warps is not None else num_warps
            config_num_stages = config.num_stages if config.num_stages is not None else num_stages
            config_num_ctas = config.num_ctas if config.num_ctas is not None else num_ctas

            # Merge config kwargs with user constexprs
            config_constexprs = {**config.kwargs, **(constexprs if constexprs else {})}

            # Compile this config
            config_kernel = compile_triton(
                actual_kernel_fn,
                signature,
                config_constexprs,
                config_num_warps,
                config_num_stages,
                config_num_ctas,
                compute_capability,
                enable_fp_fusion=False,
            )

            # Create kernel call for this config
            config_params = []
            for _ in list(ctx.avals_in) + list(ctx.avals_out):
                config_params.append(gpu_triton.create_array_parameter(0, 16))

            config_call = gpu_triton.TritonKernelCall(
                config_kernel,
                grid_tuple[0],
                grid_tuple[1],
                grid_tuple[2],
                config_params,
            )

            kernel_calls.append((config_call, str(config)))

        # IMPORTANT: We pass an empty tuple for input_output_aliases_with_sizes.
        #
        # Background:
        # 1. jax.ffi.ffi_lowering(operand_output_aliases=...) is a HINT to XLA that an
        #    output can reuse an input's buffer. XLA may or may not honor this.
        # 2. TritonAutotunedKernelCall's input_output_aliases_with_sizes triggers
        #    save/restore logic during autotuning (see jaxlib/gpu/triton_kernels.cc:630-701).
        #
        # The problem: the save phase (triton_kernels.cc:632) only saves when
        # buffers[input_idx] == buffers[output_idx], but the restore phase
        # (triton_kernels.cc:697-700) unconditionally iterates over all aliases and
        # accesses input_copies[input_idx]. If XLA didn't actually alias the buffers,
        # input_copies[input_idx] doesn't exist; the lookup creates an empty vector
        # whose .data() returns nullptr, causing CUDA_ERROR_INVALID_VALUE during the
        # restore memcpy.
        #
        # WAR: Don't pass aliases to TritonAutotunedKernelCall.
        kernel_call = gpu_triton.TritonAutotunedKernelCall(
            f"{actual_kernel_fn.__name__}_autotuned",
            kernel_calls,
            (),  # Empty to avoid buggy save/restore in jaxlib/gpu/triton_kernels.cc
        )

    else:
        # Regular kernel: compile single config
        kernel = compile_triton(
            actual_kernel_fn,
            signature,
            kernel_constexprs,
            num_warps,
            num_stages,
            num_ctas,
            compute_capability,
            enable_fp_fusion=False,
        )

        kernel_params = []
        for _ in list(ctx.avals_in) + list(ctx.avals_out):
            kernel_params.append(gpu_triton.create_array_parameter(0, 16))

        kernel_call = gpu_triton.TritonKernelCall(
            kernel,
            grid_tuple[0],
            grid_tuple[1],
            grid_tuple[2],
            kernel_params,
        )

    serialized_metadata = b""
    call_proto = kernel_call.to_proto(actual_kernel_fn.__name__, serialized_metadata)

    if input_output_aliases:
        ffi_operand_output_aliases = input_output_aliases
    else:
        ffi_operand_output_aliases = None

    # Use JAX FFI lowering with compressed protobuf
    rule = jax.ffi.ffi_lowering(
        "triton_kernel_call",  # Custom call target registered in gpu_triton.py
        api_version=2,
        backend_config=zlib.compress(call_proto),
        operand_output_aliases=ffi_operand_output_aliases,
    )

    return rule(ctx, *array_args)