Unverified Commit 9241f4fd authored by Lifu Huang's avatar Lifu Huang Committed by GitHub
Browse files

Move cached kernel to srt.utils (#10776)

parent 063c3791
......@@ -5,7 +5,7 @@ import triton
import triton.language as tl
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.utils import cached_triton_kernel
from sglang.srt.utils import cached_triton_kernel
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
......
......@@ -3,7 +3,7 @@ import triton
import triton.language as tl
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.utils import cached_triton_kernel
from sglang.srt.utils import cached_triton_kernel
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
......
......@@ -22,6 +22,7 @@ import ctypes
import dataclasses
import functools
import importlib
import inspect
import io
import ipaddress
import itertools
......@@ -3224,3 +3225,120 @@ def get_extend_input_len_swa_limit(
# and we can only free out-of-sliding-window kv indices after each prefill.
# 3. page_size is because we want to have 1 token extra for generated tokens.
return page_size + 2 * max(sliding_window_size, chunked_prefill_size)
class CachedKernel:
    """Callable wrapper enabling ``kernel[grid](...)`` with user-keyed caching.

    Compiled Triton kernels are memoized under keys produced by a
    caller-supplied key function, so repeated launches with the same key
    skip Triton's (comparatively heavy) specialization dispatch.
    """

    def __init__(self, fn, key_fn=None):
        # Only plain @triton.jit functions are supported.
        assert isinstance(fn, triton.runtime.jit.JITFunction)
        self.fn = fn
        # key_fn maps (args, kwargs) -> hashable cache key(s).
        self.key_fn = key_fn
        self.kernel_cache = {}

        wrapped = fn.fn
        self.signature = inspect.signature(wrapped)
        self.param_names = tuple(self.signature.parameters)
        self.num_args = len(self.param_names)

        # Defaults would make the positional reconstruction in _build_args
        # ambiguous, so they are rejected outright.
        for name, param in self.signature.parameters.items():
            assert (
                param.default is inspect.Parameter.empty
            ), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels."

        functools.update_wrapper(self, wrapped)

    def __getitem__(self, grid):
        """Return a launcher bound to *grid* that caches by ``self.key_fn``."""
        assert (
            isinstance(grid, tuple) and len(grid) <= 3
        ), "Grid must be a tuple with at most 3 dimensions."
        # Normalize the grid to exactly three dimensions up front.
        grid = (grid + (1, 1, 1))[:3]

        def launcher(*args, **kwargs):
            key = self.key_fn(args, kwargs)
            kernel = self.kernel_cache.get(key)
            if kernel is None:
                # Miss: go through Triton's normal compile-and-launch path,
                # then remember the compiled kernel for this key.
                kernel = self.fn[grid](*args, **kwargs)
                self.kernel_cache[key] = kernel
            else:
                # Hit: launch the compiled kernel directly with a fully
                # positional argument list, bypassing Triton's dispatch.
                kernel[grid](*self._build_args(args, kwargs))
            return kernel

        return launcher

    def _build_args(self, args, kwargs):
        """Assemble the complete positional argument list for a cached launch."""
        full = list(args)
        for name in self.param_names[len(args) : self.num_args]:
            if name not in kwargs:
                raise ValueError(f"Missing argument: {name}")
            full.append(kwargs[name])
        return full

    def _clear_cache(self):
        """Drop all cached kernels (used by tests)."""
        self.kernel_cache.clear()
def cached_triton_kernel(key_fn=None):
    """Decorator adding key-based caching to a ``@triton.jit`` kernel.

    Triton's built-in launch path re-derives the specialization on every
    call; when the dispatch is simple, caching the compiled kernel under a
    user-defined key avoids that overhead entirely.

    Usage:
        @cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024))
        @triton.jit
        def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
            ...

        # Invoked exactly like a normal Triton kernel:
        my_kernel[grid](x, y, BLOCK_SIZE=1024)

    Args:
        key_fn: Callable taking (args, kwargs) and returning the cache
            key — a single hashable value or a tuple of them.

    Returns:
        A decorator wrapping the kernel in a :class:`CachedKernel`.

    Note:
        Kernels with default parameter values are rejected with an
        assertion error.
    """
    return lambda jit_fn: CachedKernel(jit_fn, key_fn)
"""Common utilities"""
import functools
import importlib
import inspect
import json
import logging
import os
......@@ -24,7 +22,6 @@ from typing import Any, Callable, List, Optional, Tuple, Type, Union
import numpy as np
import pybase64
import requests
import triton
from IPython.display import HTML, display
from pydantic import BaseModel
from tqdm import tqdm
......@@ -552,120 +549,3 @@ def resolve_obj_by_qualname(qualname: str) -> Any:
module_name, obj_name = qualname.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, obj_name)
class CachedKernel:
    """Wrapper giving a Triton kernel the usual ``kernel[grid](...)`` call
    syntax while memoizing compiled kernels.

    The cache key for each launch is computed by a user-provided key
    function, letting callers replace Triton's per-call specialization
    dispatch with a much cheaper dictionary lookup.
    """

    def __init__(self, fn, key_fn=None):
        # Only bare @triton.jit functions are accepted.
        assert isinstance(fn, triton.runtime.jit.JITFunction)
        raw_fn = fn.fn
        sig = inspect.signature(raw_fn)

        # Default values would break _build_args' positional reconstruction.
        for name, param in sig.parameters.items():
            assert (
                param.default is inspect.Parameter.empty
            ), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels."

        self.fn = fn
        self.signature = sig
        self.param_names = tuple(sig.parameters.keys())
        self.num_args = len(self.param_names)
        # key_fn: (args, kwargs) -> hashable cache key(s).
        self.key_fn = key_fn
        self.kernel_cache = {}
        functools.update_wrapper(self, raw_fn)

    def __getitem__(self, grid):
        """Bind *grid* and return a launcher that caches via the key function."""
        assert (
            isinstance(grid, tuple) and len(grid) <= 3
        ), "Grid must be a tuple with at most 3 dimensions."
        # Pad the grid to three dimensions exactly once, outside the launcher.
        if len(grid) < 3:
            grid = grid + (1,) * (3 - len(grid))

        def launcher(*args, **kwargs):
            key = self.key_fn(args, kwargs)
            if key in self.kernel_cache:
                # Cached: launch the compiled kernel with purely positional
                # arguments, sidestepping Triton's dispatch machinery.
                kernel = self.kernel_cache[key]
                kernel[grid](*self._build_args(args, kwargs))
            else:
                # First launch for this key: compile via Triton and store
                # the resulting compiled kernel.
                kernel = self.fn[grid](*args, **kwargs)
                self.kernel_cache[key] = kernel
            return kernel

        return launcher

    def _build_args(self, args, kwargs):
        """Merge positional and keyword arguments into one positional list."""
        out = list(args)
        empty = inspect.Parameter.empty
        for name in self.param_names[len(args) : self.num_args]:
            value = kwargs.get(name, empty)
            if value is empty:
                raise ValueError(f"Missing argument: {name}")
            out.append(value)
        return out

    def _clear_cache(self):
        """Empty the kernel cache (intended for tests)."""
        self.kernel_cache.clear()
def cached_triton_kernel(key_fn=None):
    """Enable key-based compiled-kernel caching for a Triton kernel.

    This bypasses Triton's built-in caching/dispatch, letting users supply
    their own caching strategy keyed on kernel parameters — useful when the
    specialization dispatch is simple and Triton's launch overhead dominates.

    Usage:
        @cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024))
        @triton.jit
        def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
            ...

        # Invoke normally:
        my_kernel[grid](x, y, BLOCK_SIZE=1024)

    Args:
        key_fn: A function taking (args, kwargs) and returning the cache
            key(s); either a single value or a tuple of values.

    Returns:
        A decorator producing a :class:`CachedKernel` around the kernel.

    Note:
        Kernels with default parameter values raise an assertion error.
    """

    def _wrap(jit_fn):
        return CachedKernel(jit_fn, key_fn)

    return _wrap
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment