Commit 319bc6b1 authored by Lei Wang, committed by LeiWang1999

[AMD][Enhancement] Add support for Vectorized FP8 DataPacking (#542)

* [Enhancement] Add support for new FP8 types in HIP code generation

* Updated the `PrintConst` function in `codegen_hip.cc` to handle the `float8_e4m3fnuz` type.
* Introduced new functions in `hip_fp8.h` for creating FP8 types, including `make_fp8_e4_4_t` and `make_fp8_e4_8_t`, enhancing type handling for FP8 data structures.
* Improved overall compatibility and performance for FP8 data types in HIP.

* workaround for competition

* enhance autotune

* autotune cache fix

* Implement validation for unused keys in AutoTuner configuration

* Added a check in the AutoTuner class to raise a ValueError if there are unused keys in the configuration, enhancing error handling and ensuring configuration integrity.

* lint fix

* revert changes of threads

* Update pipelining in `example_mla_decode.py` to improve performance

* Changed the number of stages in the pipelined loop from 0 to 2, enhancing the efficiency of the attention mechanism in the decoding process.

* Enhance Cython kernel validation by adding tensor attribute checks

* Updated the `CythonKernelWrapper` to include dedicated methods for validating tensor device, dtype, and static shape.
* Modified the `forward` method to utilize these new validation methods, improving error handling and ensuring input integrity.
* Updated the `lambda_forward` function in `CythonKernelAdapter` to reflect changes in validation parameters.
parent 3ca3a8af
@@ -1067,7 +1067,7 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode *op,
     for (int i = 0; i < lanes / 2; ++i) {
       if (i != 0)
         os << ", ";
-      os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
+      os << "__pack_bfloat162(" << v << ", " << v << ")";
     }
     os << ')';
     return;
@@ -1151,6 +1151,10 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
     os << "bfloat16_t";
     os << '(' << std::scientific << op->value << 'f' << ')';
     return;
+  } else if (op->dtype.is_float8_e4m3fnuz()) {
+    os << "fp8_e4_t";
+    os << '(' << std::scientific << op->value << 'f' << ')';
+    return;
   }
   // Type code is kFloat
   switch (op->dtype.bits()) {
...
@@ -97,3 +97,15 @@ TL_DEVICE unsigned __pack_half2(const half_t x, const half_t y) {
   unsigned v1 = *((unsigned short *)&y);
   return (v1 << 16) | v0;
 }
+
+// Pack two bfloat16_t values.
+TL_DEVICE unsigned __pack_bfloat162(const bfloat16_t x, const bfloat16_t y) {
+  unsigned v0 = *((unsigned short *)&x);
+  unsigned v1 = *((unsigned short *)&y);
+  return (v1 << 16) | v0;
+}
+
+template <typename T1, typename T2>
+TL_DEVICE void AtomicAdd(T1 *address, T2 val) {
+  atomicAdd(reinterpret_cast<T1 *>(address), static_cast<T1>(val));
+}
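
For reference, the packing that `__pack_half2` and the new `__pack_bfloat162` perform is plain bit concatenation: the second operand's 16-bit pattern lands in the high half of the 32-bit word. A minimal host-side sketch in Python (illustrative only, not part of this commit):

```python
import struct

def pack_bfloat162(x_bits: int, y_bits: int) -> int:
    """Pack two 16-bit patterns into one 32-bit word, y in the high half."""
    assert 0 <= x_bits < 1 << 16 and 0 <= y_bits < 1 << 16
    return (y_bits << 16) | x_bits

# bfloat16 keeps the top 16 bits of a float32, so the bit pattern of 1.0
# (0x3F800000 as float32) becomes 0x3F80 as bfloat16.
one_bf16 = struct.unpack("<I", struct.pack("<f", 1.0))[0] >> 16
assert pack_bfloat162(one_bf16, one_bf16) == 0x3F803F80
```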
@@ -13,3 +13,33 @@ struct __align__(16) fp8_e4_16_t {
   fp8_e4_8_t x;
   fp8_e4_8_t y;
 };
+
+__device__ fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z,
+                                       fp8_e4_t w) {
+  // reinterpret the 4 fp8_e4_t values to signed char value and shift
+  signed char x_char = *reinterpret_cast<signed char *>(&x);
+  signed char y_char = *reinterpret_cast<signed char *>(&y);
+  signed char z_char = *reinterpret_cast<signed char *>(&z);
+  signed char w_char = *reinterpret_cast<signed char *>(&w);
+  int res = (w_char << 24) | (z_char << 16) | (y_char << 8) | x_char;
+  return *reinterpret_cast<fp8_e4_4_t *>(&res);
+}
+
+__device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z,
+                                       fp8_e4_t w, fp8_e4_t v, fp8_e4_t u,
+                                       fp8_e4_t t, fp8_e4_t s) {
+  signed char x_char = *reinterpret_cast<signed char *>(&x);
+  signed char y_char = *reinterpret_cast<signed char *>(&y);
+  signed char z_char = *reinterpret_cast<signed char *>(&z);
+  signed char w_char = *reinterpret_cast<signed char *>(&w);
+  signed char v_char = *reinterpret_cast<signed char *>(&v);
+  signed char u_char = *reinterpret_cast<signed char *>(&u);
+  signed char t_char = *reinterpret_cast<signed char *>(&t);
+  signed char s_char = *reinterpret_cast<signed char *>(&s);
+  int a = (w_char << 24) | (z_char << 16) | (y_char << 8) | x_char;
+  int b = (s_char << 24) | (t_char << 16) | (u_char << 8) | v_char;
+  fp8_e4_8_t res;
+  res.x = *reinterpret_cast<fp8_e4_4_t *>(&a);
+  res.y = *reinterpret_cast<fp8_e4_4_t *>(&b);
+  return res;
+}
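
The byte-level packing these helpers perform can be checked on the host. Below is a small Python sketch (illustrative, not part of this commit) that packs four 8-bit FP8 bit patterns into one 32-bit word with the first element in the lowest byte; each byte is masked so that patterns with the sign bit set do not sign-extend into the higher bytes:

```python
def make_fp8x4(x: int, y: int, z: int, w: int) -> int:
    """Pack four 8-bit patterns into a 32-bit word, x in the lowest byte."""
    b = [v & 0xFF for v in (x, y, z, w)]  # mask to avoid sign-extension
    return (b[3] << 24) | (b[2] << 16) | (b[1] << 8) | b[0]

def make_fp8x8(vals):
    """Pack eight 8-bit patterns into two 32-bit words (low half, high half)."""
    assert len(vals) == 8
    return make_fp8x4(*vals[:4]), make_fp8x4(*vals[4:])

# 0x84 has the top bit set, so masking matters for the packed result.
assert make_fp8x4(0x01, 0x02, 0x03, 0x84) == 0x84030201
assert make_fp8x8([1, 2, 3, 4, 5, 6, 7, 8]) == (0x04030201, 0x08070605)
```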
@@ -98,6 +98,7 @@ class AutoTuner:
    compile_args = CompileArgs()
    profile_args = ProfileArgs()
+   _kernel_parameters: Optional[Tuple[str, ...]] = None
    _lock = threading.Lock()  # For thread safety
    _memory_cache = {}  # In-memory cache dictionary
    cache_dir: Path = Path(TILELANG_CACHE_DIR)
@@ -165,7 +166,7 @@
                         max_mismatched_ratio: float = 0.01,
                         skip_check: bool = False,
                         manual_check_prog: Callable = None,
-                        cache_input_tensors: bool = True):
+                        cache_input_tensors: bool = False):
        """Set profiling arguments for the auto-tuner.

        Args:
@@ -207,12 +208,26 @@
        return self

-   def generate_cache_key(self) -> Optional[AutotuneResult]:
+   def set_kernel_parameters(self, parameters: Tuple[str, ...]):
+       # for cache key generation
+       self._kernel_parameters = parameters
+
+   def generate_cache_key(self, parameters: Dict[str, Any]) -> Optional[AutotuneResult]:
        """Generate a cache key for the auto-tuning process.
        """
+       # extract parameters from the function signature
+       op_parameters = []
+       for _, default_value in parameters.items():
+           if default_value.default is not inspect.Parameter.empty:
+               op_parameters.append(default_value.default)
+       if self._kernel_parameters is not None:
+           op_parameters += self._kernel_parameters
+
        func_source = inspect.getsource(self.fn)
        key_data = {
            "version": __version__,
+           "op_parameters": tuple(op_parameters),
            "func_source": func_source,
            "configs": self.configs,
            "compile_args": hash(self.compile_args),
@@ -242,12 +257,16 @@ class AutoTuner:
        """
        _init_logger_handlers()

-       key = self.generate_cache_key()
+       sig = inspect.signature(self.fn)
+       parameters = sig.parameters
+       key = self.generate_cache_key(parameters)
+
        with self._lock:
            if is_cache_enabled():
                # First check in-memory cache
                if key in self._memory_cache:
-                   self.logger.warning("Found kernel in memory cache. For better performance," \
+                   logger.warning("Found kernel in memory cache. For better performance," \
                        " consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel.")
                    return self._memory_cache[key]
@@ -258,8 +277,6 @@
                    self._memory_cache[key] = result
                    return result

-       sig = inspect.signature(self.fn)
-       parameters = sig.parameters
        best_latency: float = 1e8
        best_config: Optional[Dict[str, Any]] = None
        best_kernel: Optional[tilelang.JITKernel] = None
@@ -343,9 +360,13 @@
        config_args = []
        for config in self.configs:
            new_kwargs = {}
+           keys = config.keys()
            for name, _ in parameters.items():
                if name in config:
                    new_kwargs[name] = config[name]
+           unused_keys = set(keys) - set(new_kwargs.keys())
+           if len(unused_keys) > 0:
+               raise ValueError(f"Unused keys in config: {unused_keys}")
            config_args.append(new_kwargs)

        num_workers = max(1, int(get_available_cpu_count() * 0.9))
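
This check turns a silently ignored tuning knob into a hard error: any config key that does not match a parameter of the tuned function now raises. A minimal standalone illustration of the same logic (names are hypothetical):

```python
import inspect

def filter_config(fn, config):
    """Keep only keys that match fn's parameters; reject leftovers."""
    parameters = inspect.signature(fn).parameters
    new_kwargs = {name: config[name] for name in parameters if name in config}
    unused_keys = set(config.keys()) - set(new_kwargs.keys())
    if unused_keys:
        raise ValueError(f"Unused keys in config: {unused_keys}")
    return new_kwargs

def kernel(block_M=128, block_N=128):  # hypothetical tunable function
    pass

filter_config(kernel, {"block_M": 64, "block_N": 64})      # ok
try:
    filter_config(kernel, {"block_M": 64, "blockN": 64})   # typo in key
except ValueError as e:
    print(e)  # Unused keys in config: {'blockN'}
```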
@@ -461,8 +482,30 @@ class _AutoTunerImplementation:
    rep: int = 100
    timeout: int = 100
    configs: Any = None
+   supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto
+   ref_prog: Callable = None
+   supply_prog: Callable = None
+   rtol: float = 1e-2
+   atol: float = 1e-2
+   max_mismatched_ratio: float = 0.01
+   skip_check: bool = False
+   manual_check_prog: Callable = None
+   cache_input_tensors: bool = False

-   def __init__(self, configs: Any, warmup: int = 25, rep: int = 100, timeout: int = 100) -> None:
+   def __init__(self,
+                configs: Any,
+                warmup: int = 25,
+                rep: int = 100,
+                timeout: int = 100,
+                supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
+                ref_prog: Callable = None,
+                supply_prog: Callable = None,
+                rtol: float = 1e-2,
+                atol: float = 1e-2,
+                max_mismatched_ratio: float = 0.01,
+                skip_check: bool = False,
+                manual_check_prog: Callable = None,
+                cache_input_tensors: bool = False) -> None:
        """Initialize the AutoTunerImplementation.

        Args:
@@ -507,8 +550,21 @@
            def jit_compile(**config_arg):
                return fn(*args, **kwargs, __tune_params=config_arg)

-           autotuner = AutoTuner(fn, configs=configs)
+           autotuner = AutoTuner(
+               fn, configs=configs).set_profile_args(
+                   supply_type=self.supply_type,
+                   ref_prog=self.ref_prog,
+                   supply_prog=self.supply_prog,
+                   rtol=self.rtol,
+                   atol=self.atol,
+                   max_mismatched_ratio=self.max_mismatched_ratio,
+                   skip_check=self.skip_check,
+                   manual_check_prog=self.manual_check_prog,
+                   cache_input_tensors=self.cache_input_tensors,
+               )
            autotuner.jit_compile = jit_compile
+           autotuner.set_kernel_parameters(key)
            autotuner.run = partial(autotuner.run, warmup, rep, timeout)
            artifact = autotuner.run()
@@ -520,12 +576,24 @@
def autotune(  # This is the new public interface
        func: Union[Callable[_P, _RProg], PrimFunc, None] = None,
        *,  # Indicates subsequent arguments are keyword-only
        configs: Any,
-       warmup: int = 25,
-       rep: int = 100,
-       timeout: int = 100):
+       # profile arguments
+       warmup: int = 25,
+       rep: int = 100,
+       timeout: int = 100,
+       # compile arguments
+       supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
+       ref_prog: Callable = None,
+       supply_prog: Callable = None,
+       rtol: float = 1e-2,
+       atol: float = 1e-2,
+       max_mismatched_ratio: float = 0.01,
+       skip_check: bool = False,
+       manual_check_prog: Callable = None,
+       cache_input_tensors: bool = False,
+):
    """
    Just-In-Time (JIT) compiler decorator for TileLang functions.
@@ -569,5 +637,18 @@ def autotune(  # This is the new public interface
    # Create a _AutoTunerImplementation instance with the provided/defaulted arguments.
    # This instance is a decorator that will be applied to the function later.
    configured_decorator = _AutoTunerImplementation(
-       configs=configs, warmup=warmup, rep=rep, timeout=timeout)
+       configs=configs,
+       warmup=warmup,
+       rep=rep,
+       timeout=timeout,
+       supply_type=supply_type,
+       ref_prog=ref_prog,
+       supply_prog=supply_prog,
+       rtol=rtol,
+       atol=atol,
+       max_mismatched_ratio=max_mismatched_ratio,
+       skip_check=skip_check,
+       manual_check_prog=manual_check_prog,
+       cache_input_tensors=cache_input_tensors,
+   )
    return configured_decorator
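
The new keyword-only arguments simply ride along: `autotune` stores them on the `_AutoTunerImplementation` instance, which replays them into `AutoTuner.set_profile_args` when the decorated function is tuned. A stripped-down sketch of that forwarding pattern (hypothetical names, no tilelang dependency):

```python
class _TunerImpl:
    """Decorator object that remembers tuning and profiling settings."""

    def __init__(self, configs, warmup=25, rep=100, timeout=100, **profile_args):
        self.configs = configs
        self.warmup, self.rep, self.timeout = warmup, rep, timeout
        self.profile_args = profile_args  # e.g. ref_prog, rtol, atol, skip_check

    def __call__(self, fn):
        def tuned(*args, **kwargs):
            # The real implementation builds an AutoTuner here and calls
            # set_profile_args(**self.profile_args) before sweeping self.configs.
            return {"fn": fn.__name__, "n_configs": len(self.configs), **self.profile_args}
        return tuned

def autotune_like(func=None, *, configs, **kwargs):
    """Usable both as @autotune_like(configs=...) and autotune_like(fn, configs=...)."""
    decorator = _TunerImpl(configs, **kwargs)
    return decorator if func is None else decorator(func)

@autotune_like(configs=[{"block_M": 64}, {"block_M": 128}], rtol=1e-2, skip_check=False)
def my_kernel(block_M=64):
    return block_M

print(my_kernel())  # {'fn': 'my_kernel', 'n_configs': 2, 'rtol': 0.01, 'skip_check': False}
```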
@@ -355,6 +355,8 @@ class MatrixCoreIntrinEmitter(object):
        BLOCK_M = block_row_warps * warp_rows
        BLOCK_N = block_col_warps * warp_cols
        M_DIM, N_DIM = self.M_DIM, self.N_DIM
+       C_buf_dims = len(C_buf.shape)
+       assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D"

        # STS
        # MMA Store must be in simulated instead of TVM Intrins
@@ -366,9 +368,15 @@ class MatrixCoreIntrinEmitter(object):
            for i, j in T.grid(warp_rows, warp_cols):
                for local_id in T.vectorized(local_size_out):
                    row, col = T.meta_var(mfma_store_index_map(tx, local_id))
-                   C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row,
-                         col] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out +
-                                            local_id]
+                   if C_buf_dims == 2:
+                       C_buf[(warp_m * warp_rows + i) * M_DIM + row,
+                             (warp_n * warp_cols + j) * N_DIM +
+                             col] = C_local_buf[i * (warp_cols * local_size_out) +
+                                                j * local_size_out + local_id]
+                   else:
+                       C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row,
+                             col] = C_local_buf[i * warp_cols * local_size_out +
+                                                j * local_size_out + local_id]

    @T.macro
    def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
...
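
The new 2D branch flattens the same logical coordinates that the 4D layout keeps separate: tile index (wm, wn) and intra-tile offset (row, col) collapse to `(wm * M_DIM + row, wn * N_DIM + col)`. A quick check of that mapping (pure Python, with an illustrative tile size):

```python
M_DIM = N_DIM = 16  # MFMA tile size assumed for illustration

def flatten_index(wm, wn, row, col):
    """Map (tile_m, tile_n, row, col) of a 4D buffer to 2D coordinates."""
    return wm * M_DIM + row, wn * N_DIM + col

# The element stored at tile (1, 2), offset (3, 5) in the 4D layout
# lands at row 19, column 37 of the flattened 2D buffer.
assert flatten_index(1, 2, 3, 5) == (1 * 16 + 3, 2 * 16 + 5) == (19, 37)
```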
@@ -430,8 +430,17 @@ class CythonKernelAdapter(BaseKernelAdapter):
    def _convert_torch_func(self) -> Callable:
        """Returns a PyTorch-compatible function wrapper for the kernel."""

-       def lambda_forward(*args, stream: int = -1):
-           return self.cython_wrapper.forward([*args], stream=stream)
+       def lambda_forward(*args, stream: int = -1, skip_tensor_validation: bool = False):
+           """
+           Args:
+               args: List of input tensors
+               stream: CUDA stream ID; defaults to -1, in which case the current stream is used
+               skip_tensor_validation: Whether to skip validation of tensor attributes
+                   (shape, dtype, device, etc.)
+           """
+           return self.cython_wrapper.forward([*args],
+                                              stream=stream,
+                                              skip_tensor_validation=skip_tensor_validation)

        return lambda_forward
...
@@ -23,6 +23,7 @@ cdef class CythonKernelWrapper:
        list param_dtypes  # Cache for parameter dtypes
        list param_shapes  # Cache for parameter shapes as native Python lists
        object get_current_device
+
    def __cinit__(self, result_idx, params, lib):
        # Initialize wrapper with kernel configuration
        self.result_idx = result_idx
@@ -64,7 +65,30 @@ cdef class CythonKernelWrapper:
        self.buffer_device_map = buffer_device_map
        return self

-   cpdef forward(self, list inputs, int64_t stream = -1):
+   cpdef void _check_buffer_device(self, list tensor_list):
+       if isinstance(tensor_list[0], torch.Tensor):
+           tensor_list_device_type = tensor_list[0].device.type
+           for param, (buffer_idx, device) in self.buffer_device_map.items():
+               if isinstance(tensor_list[buffer_idx], torch.Tensor):
+                   tensor_device = tensor_list[buffer_idx].device
+                   if (tensor_list_device_type != device.type or
+                       (tensor_device.index is not None and device.index is not None and tensor_device.index != device.index)):
+                       raise ValueError(f"Buffer device mismatch for parameter {param}: expected {device}, got {tensor_device}")
+
+   cpdef void _check_buffer_dtype(self, list tensor_list):
+       for param, (buffer_idx, torch_dtype) in self.buffer_dtype_map.items():
+           if isinstance(tensor_list[buffer_idx], torch.Tensor):
+               if tensor_list[buffer_idx].dtype != torch_dtype:
+                   raise ValueError(f"Buffer dtype mismatch for parameter {param}: expected {torch_dtype}, got {tensor_list[buffer_idx].dtype}")
+
+   cpdef void _check_static_shape(self, list tensor_list):
+       for param, (buffer_idx, shape_list) in self.static_shape_map.items():
+           if isinstance(tensor_list[buffer_idx], torch.Tensor):
+               for shape_idx, shape in shape_list:
+                   if tensor_list[buffer_idx].shape[shape_idx] != shape:
+                       raise ValueError(f"Static shape mismatch for parameter {param}: expected {shape} at index {shape_idx}, got {tensor_list[buffer_idx].shape}")
+
+   cpdef forward(self, list inputs, int64_t stream = -1, bint skip_tensor_validation = False):
        # Validate input dimensions and prepare for kernel execution
        cdef int total_params = len(self.params)
        cdef int total_inputs = len(inputs)
@@ -133,29 +157,10 @@ cdef class CythonKernelWrapper:
                raise ValueError(f"Unsupported tensor type: {type(tensor)}")

        # Check buffer device
-       # cdef str tensor_list_device_type = tensor_list[0].device.type
-       if isinstance(tensor_list[0], torch.Tensor):
-           tensor_list_device_type = tensor_list[0].device.type
-           for param, (buffer_idx, device) in self.buffer_device_map.items():
-               if isinstance(tensor_list[buffer_idx], torch.Tensor):
-                   tensor_device = tensor_list[buffer_idx].device
-                   # Compare device types and indices separately to handle both string and torch.device objects
-                   if (tensor_list_device_type != device.type or
-                       (tensor_device.index is not None and device.index is not None and tensor_device.index != device.index)):
-                       raise ValueError(f"Buffer device mismatch for parameter {param}: expected {device}, got {tensor_device}")
-
-       # Check buffer dtype map
-       for param, (buffer_idx, torch_dtype) in self.buffer_dtype_map.items():
-           if isinstance(tensor_list[buffer_idx], torch.Tensor):
-               if tensor_list[buffer_idx].dtype != torch_dtype:
-                   raise ValueError(f"Buffer dtype mismatch for parameter {param}: expected {torch_dtype}, got {tensor_list[buffer_idx].dtype}")
-
-       # Check static shape map
-       for param, (buffer_idx, shape_list) in self.static_shape_map.items():
-           if isinstance(tensor_list[buffer_idx], torch.Tensor):
-               for shape_idx, shape in shape_list:
-                   if tensor_list[buffer_idx].shape[shape_idx] != shape:
-                       raise ValueError(f"Static shape mismatch for parameter {param}: expected {shape} at index {shape_idx}, got {tensor_list[buffer_idx].shape}")
+       if not skip_tensor_validation:
+           self._check_buffer_device(tensor_list)
+           self._check_buffer_dtype(tensor_list)
+           self._check_static_shape(tensor_list)

        # Add dynamic dimension values to kernel arguments
        for _, (buffer_idx, shape_idx) in self.dynamic_symbolic_map.items():
...
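
The refactor moves the per-tensor checks into dedicated methods so the fast path can skip them wholesale via `skip_tensor_validation`. The dtype check, for example, boils down to comparing each tensor against a cached expectation; a standalone sketch with hypothetical names (requires PyTorch):

```python
import torch

def check_buffer_dtypes(tensors, buffer_dtype_map):
    """buffer_dtype_map: {param_name: (buffer_idx, expected_torch_dtype)}."""
    for param, (buffer_idx, torch_dtype) in buffer_dtype_map.items():
        t = tensors[buffer_idx]
        if isinstance(t, torch.Tensor) and t.dtype != torch_dtype:
            raise ValueError(
                f"Buffer dtype mismatch for parameter {param}: "
                f"expected {torch_dtype}, got {t.dtype}")

tensors = [torch.zeros(4, dtype=torch.float16), torch.zeros(4, dtype=torch.float32)]
check_buffer_dtypes(tensors, {"A": (0, torch.float16)})      # passes
try:
    check_buffer_dtypes(tensors, {"B": (1, torch.float16)})  # wrong dtype
except ValueError as e:
    print(e)
```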
@@ -29,6 +29,8 @@ class LibraryGenerator(object):
    def load_lib(self, lib_path: Optional[str] = None):
        if lib_path is None:
            lib_path = self.libpath
+       else:
+           self.libpath = lib_path
        return ctypes.CDLL(lib_path)

    def compile_lib(self, timeout: float = None):
...