"""
This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
See the original Unsloth repository at https://github.com/unslothai/unsloth.

The following line
https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23
is based on code from Unsloth, located at:
https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43

Modifications made by Yanning Chen, 2024.
"""

import functools
import importlib
import importlib.metadata
import operator

from typing import Callable
from typing import Tuple

import torch
import triton
import triton.language as tl

from packaging.version import Version

from liger_kernel.utils import infer_device


def is_hip() -> bool:
    """Return True when ``torch.version.hip`` is set, i.e. PyTorch was built with HIP (ROCm)."""
    return torch.version.hip is not None


def ensure_contiguous(fn):
    """Decorator that makes every tensor argument of ``fn`` contiguous.

    The first argument (``ctx``) is passed through unchanged. Every other
    positional or keyword argument that is a ``torch.Tensor`` is replaced by
    ``x.contiguous()``; non-tensor arguments are forwarded as-is.
    """

    @functools.wraps(fn)
    def wrapper(ctx, *args, **kwargs):
        def maybe_to_contiguous(x):
            return x.contiguous() if isinstance(x, torch.Tensor) else x

        args = [maybe_to_contiguous(arg) for arg in args]
        kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()}
        return fn(ctx, *args, **kwargs)

    return wrapper


def calculate_settings(n):
    """Choose a Triton ``BLOCK_SIZE`` and ``num_warps`` for a dimension of size ``n``.

    ``BLOCK_SIZE`` is ``n`` rounded up to the next power of two. ``num_warps``
    grows with the block size: 4, 8 (>= 2048), 16 (>= 8192) and 32 (>= 32768),
    except that HIP builds are capped at 16 for the largest bucket.

    Returns:
        A ``(BLOCK_SIZE, num_warps)`` tuple.

    Raises:
        RuntimeError: if the rounded block size exceeds 65536.
    """
    # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43

    MAX_FUSED_SIZE = 65536
    BLOCK_SIZE = triton.next_power_of_2(n)
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        raise RuntimeError(
            f"Cannot launch Triton kernel since n = {n} exceeds the recommended Triton blocksize = {MAX_FUSED_SIZE}."
        )

    num_warps = 4
    if BLOCK_SIZE >= 32768:
        num_warps = 32 if not is_hip() else 16
    elif BLOCK_SIZE >= 8192:
        num_warps = 16
    elif BLOCK_SIZE >= 2048:
        num_warps = 8
    return BLOCK_SIZE, num_warps


def compare_version(package: str, operator: Callable, target: str) -> bool:
    """Compare the installed version of ``package`` against ``target``.

    The installed version is taken from the imported module's ``__version__``
    when the module can be imported and exposes one. Otherwise it is looked up
    in the installed distribution metadata. The metadata fallback is what makes
    distribution names that are not valid module names (e.g.
    ``"pytorch-triton-xpu"``) resolvable; ``importlib.import_module`` can never
    import such a name.

    Args:
        package: Import name or distribution name of the package.
        operator: Binary comparison applied as
            ``operator(installed_version, Version(target))``, e.g.
            ``operator.ge``. (This parameter shadows the ``operator`` module
            inside this function only.)
        target: Version string to compare against.

    Returns:
        The comparison result, or ``False`` when the package is neither
        importable nor installed as a distribution.
    """
    installed = None
    try:
        module = importlib.import_module(package)
    except ImportError:
        module = None
    if module is not None:
        installed = getattr(module, "__version__", None)

    if installed is None:
        # Fall back to distribution metadata (handles hyphenated distribution
        # names and modules that do not define __version__).
        try:
            installed = importlib.metadata.version(package)
        except importlib.metadata.PackageNotFoundError:
            return False

    return operator(Version(str(installed)), Version(target))


def get_amp_custom_fwd_bwd() -> Tuple[Callable, Callable]:
    """Return the ``(custom_fwd, custom_bwd)`` AMP decorators for the current device.

    - torch >= 2.4.0: the device-generic ``torch.amp`` decorators, with
      ``device_type`` bound to ``infer_device()``.
    - otherwise, if ``torch.npu.amp`` exists: the NPU decorators.
    - otherwise: the legacy ``torch.cuda.amp`` decorators.
    """
    device = infer_device()
    if compare_version("torch", operator.ge, "2.4.0"):
        return (
            functools.partial(torch.amp.custom_fwd, device_type=device),
            functools.partial(torch.amp.custom_bwd, device_type=device),
        )
    if hasattr(torch, "npu") and getattr(torch.npu, "amp", None) is not None:
        return torch.npu.amp.custom_fwd, torch.npu.amp.custom_bwd
    return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd


amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()


# Mapping from torch floating-point dtypes to the corresponding Triton dtypes.
torch_to_triton_dtype = {
    torch.float32: tl.float32,
    torch.float16: tl.float16,
    torch.bfloat16: tl.bfloat16,
}


@triton.jit
def element_mul_kernel(
    X_ptr,
    X_stride,
    grad_output_ptr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr.
    The multiplication is performed in-place on the tensor pointed by X_ptr.

    Each program handles the row starting at ``program_id * X_stride`` and
    walks its ``n_cols`` columns in chunks of ``BLOCK_SIZE``. (Presumably
    launched with one program per row — confirm at call sites.)

    Parameters:
    X_ptr: Pointer to the input tensor.
    X_stride (int): The stride of the input tensor.
    grad_output_ptr: Pointer to the gradient output value.
    n_cols (int): The number of columns in the input tensor.
    BLOCK_SIZE (int): The block size for Triton operations.
    """

    # Get the program ID and convert it to int64 to avoid overflow
    program_id = tl.program_id(0).to(tl.int64)

    # Locate the start index
    X_ptr += program_id * X_stride

    # Load the gradient output value (a single scalar shared by the whole row)
    grad_output = tl.load(grad_output_ptr)

    # Perform the element-wise multiplication
    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        # Guard the tail chunk when n_cols is not a multiple of BLOCK_SIZE.
        mask = X_offsets < n_cols
        X_block = tl.load(X_ptr + X_offsets, mask=mask)
        tl.store(X_ptr + X_offsets, X_block * grad_output, mask=mask)


def get_npu_core_count(default: int = 20) -> int:
    """Return NPU vector core count.

    Reads ``num_vectorcore`` from the active Triton driver's properties for
    device 0. Fallback to `default` if Triton runtime or NPU device is
    unavailable (any exception is deliberately treated as "unavailable"), or if
    the property is missing.
    """
    try:
        utils = triton.runtime.driver.active.utils
        props = utils.get_device_properties(0)
        return int(props.get("num_vectorcore", default))
    except Exception:
        return default


def set_large_grf_mode(kernel_args: dict):
    """Set large GRF mode for XPU devices.

    Mutates ``kernel_args`` in place, setting ``kernel_args["grf_mode"]`` to
    ``"256"`` for Triton >= 3.6.0 and to ``"large"`` for older versions.
    """
    # On XPU triton installed along with pytorch-xpu will be called `pytorch-triton-xpu`,
    # triton XPU installed from source will be called `triton`.
    if compare_version("pytorch-triton-xpu", operator.ge, "3.6.0") or compare_version("triton", operator.ge, "3.6.0"):
        kernel_args["grf_mode"] = "256"
    else:
        # API was changed in https://github.com/intel/intel-xpu-backend-for-triton/pull/5430
        kernel_args["grf_mode"] = "large"