Commit 319bc6b1 authored by Lei Wang, committed by LeiWang1999

[AMD][Enhancement] Add support for Vectorized FP8 DataPacking (#542)

* [Enhancement] Add support for new FP8 types in HIP code generation

* Updated `PrintConst` function in `codegen_hip.cc` to handle `float8_e4m3fnuz` type.
* Introduced new functions in `hip_fp8.h` for creating FP8 types, including `make_fp8_e4_4_t` and `make_fp8_e4_8_t`, enhancing type handling for FP8 data structures.
* Improved overall compatibility and performance for FP8 data types in HIP.

* workaround for competition

* enhance autotune

* autotune cache fix

* Implement validation for unused keys in AutoTuner configuration

* Added a check in the AutoTuner class to raise a ValueError if there are unused keys in the configuration, enhancing error handling and ensuring configuration integrity.

* lint fix

* revert changes of threads

* Update pipelining in `example_mla_decode.py` to improve performance

* Changed the number of stages in the pipelined loop from 0 to 2, enhancing the efficiency of the attention mechanism in the decoding process.

* Enhance Cython kernel validation by adding tensor attribute checks

* Updated the `CythonKernelWrapper` to include dedicated methods for validating tensor device, dtype, and static shape.
* Modified the `forward` method to utilize these new validation methods, improving error handling and ensuring input integrity.
* Updated the `lambda_forward` function in `CythonKernelAdapter` to reflect changes in validation parameters.
parent 3ca3a8af
......@@ -1067,7 +1067,7 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode *op,
for (int i = 0; i < lanes / 2; ++i) {
if (i != 0)
os << ", ";
os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
os << "__pack_bfloat162(" << v << ", " << v << ")";
}
os << ')';
return;
......@@ -1151,6 +1151,10 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
os << "bfloat16_t";
os << '(' << std::scientific << op->value << 'f' << ')';
return;
} else if (op->dtype.is_float8_e4m3fnuz()) {
os << "fp8_e4_t";
os << '(' << std::scientific << op->value << 'f' << ')';
return;
}
// Type code is kFloat
switch (op->dtype.bits()) {
......
......@@ -97,3 +97,15 @@ TL_DEVICE unsigned __pack_half2(const half_t x, const half_t y) {
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}
// Pack two bfloat16_t values.
TL_DEVICE unsigned __pack_bfloat162(const bfloat16_t x, const bfloat16_t y) {
unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}
template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 *address, T2 val) {
atomicAdd(reinterpret_cast<T1 *>(address), static_cast<T1>(val));
}
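For reference, a minimal Python sketch of the lane layout these helpers produce: the first argument lands in the low 16 bits of the packed word. The name pack_16x2 and the use of raw 16-bit bit patterns (rather than half_t/bfloat16_t objects) are illustrative only.

# Illustrative only: emulate the 16-bit packing done by __pack_half2 and
# __pack_bfloat162, operating on raw 16-bit bit patterns.
def pack_16x2(x_bits: int, y_bits: int) -> int:
    """Pack two 16-bit patterns into one 32-bit word, x in the low half."""
    return ((y_bits & 0xFFFF) << 16) | (x_bits & 0xFFFF)

# The bfloat16 bit pattern of 1.0 is 0x3F80; broadcasting it gives 0x3F803F80.
assert pack_16x2(0x3F80, 0x3F80) == 0x3F803F80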
......@@ -13,3 +13,33 @@ struct __align__(16) fp8_e4_16_t {
fp8_e4_8_t x;
fp8_e4_8_t y;
};
__device__ fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z,
                                       fp8_e4_t w) {
  // Reinterpret the 4 fp8_e4_t values as raw bytes and pack them into a single
  // 32-bit word, with x in the lowest byte. Unsigned bytes are used so that
  // values with the sign bit set are not sign-extended into the other lanes.
  unsigned char x_char = *reinterpret_cast<unsigned char *>(&x);
  unsigned char y_char = *reinterpret_cast<unsigned char *>(&y);
  unsigned char z_char = *reinterpret_cast<unsigned char *>(&z);
  unsigned char w_char = *reinterpret_cast<unsigned char *>(&w);
  unsigned int res = ((unsigned int)w_char << 24) | ((unsigned int)z_char << 16) |
                     ((unsigned int)y_char << 8) | (unsigned int)x_char;
  return *reinterpret_cast<fp8_e4_4_t *>(&res);
}
__device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z,
                                       fp8_e4_t w, fp8_e4_t v, fp8_e4_t u,
                                       fp8_e4_t t, fp8_e4_t s) {
  // Same packing as make_fp8_e4_4_t, producing two 32-bit lanes:
  // (x, y, z, w) fill res.x and (v, u, t, s) fill res.y.
  unsigned char x_char = *reinterpret_cast<unsigned char *>(&x);
  unsigned char y_char = *reinterpret_cast<unsigned char *>(&y);
  unsigned char z_char = *reinterpret_cast<unsigned char *>(&z);
  unsigned char w_char = *reinterpret_cast<unsigned char *>(&w);
  unsigned char v_char = *reinterpret_cast<unsigned char *>(&v);
  unsigned char u_char = *reinterpret_cast<unsigned char *>(&u);
  unsigned char t_char = *reinterpret_cast<unsigned char *>(&t);
  unsigned char s_char = *reinterpret_cast<unsigned char *>(&s);
  unsigned int a = ((unsigned int)w_char << 24) | ((unsigned int)z_char << 16) |
                   ((unsigned int)y_char << 8) | (unsigned int)x_char;
  unsigned int b = ((unsigned int)s_char << 24) | ((unsigned int)t_char << 16) |
                   ((unsigned int)u_char << 8) | (unsigned int)v_char;
  fp8_e4_8_t res;
  res.x = *reinterpret_cast<fp8_e4_4_t *>(&a);
  res.y = *reinterpret_cast<fp8_e4_4_t *>(&b);
  return res;
}
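A byte-level sketch of the same packing in Python may help when checking lane order: the first argument always occupies the lowest byte, and make_fp8_e4_8_t simply produces two such 32-bit words. The helper names pack_fp8x4 and pack_fp8x8 are illustrative, not part of the codebase.

# Illustrative only: FP8 values are treated as raw bytes; masking with 0xFF
# keeps each lane inside its own byte regardless of the sign bit.
def pack_fp8x4(x: int, y: int, z: int, w: int) -> int:
    """Pack four 8-bit patterns into one 32-bit word, x in the lowest byte."""
    return ((w & 0xFF) << 24) | ((z & 0xFF) << 16) | ((y & 0xFF) << 8) | (x & 0xFF)

def pack_fp8x8(b):
    """Pack eight 8-bit patterns into two 32-bit words, first four in the low word."""
    return pack_fp8x4(*b[:4]), pack_fp8x4(*b[4:])

# 0xC8 has the FP8 sign bit set and must stay confined to its own byte.
assert pack_fp8x4(0xC8, 0x01, 0x02, 0x03) == 0x030201C8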
......@@ -98,6 +98,7 @@ class AutoTuner:
compile_args = CompileArgs()
profile_args = ProfileArgs()
_kernel_parameters: Optional[Tuple[str, ...]] = None
_lock = threading.Lock() # For thread safety
_memory_cache = {} # In-memory cache dictionary
cache_dir: Path = Path(TILELANG_CACHE_DIR)
......@@ -165,7 +166,7 @@ class AutoTuner:
max_mismatched_ratio: float = 0.01,
skip_check: bool = False,
manual_check_prog: Callable = None,
cache_input_tensors: bool = True):
cache_input_tensors: bool = False):
"""Set profiling arguments for the auto-tuner.
Args:
......@@ -207,12 +208,26 @@ class AutoTuner:
return self
def generate_cache_key(self) -> Optional[AutotuneResult]:
def set_kernel_parameters(self, parameters: Tuple[str, ...]):
# for cache key generation
self._kernel_parameters = parameters
def generate_cache_key(self, parameters: Dict[str, Any]) -> Optional[AutotuneResult]:
"""Generate a cache key for the auto-tuning process.
"""
# extract parameters from the function signature
op_parameters = []
for _, default_value in parameters.items():
if default_value.default is not inspect.Parameter.empty:
op_parameters.append(default_value.default)
if self._kernel_parameters is not None:
op_parameters += self._kernel_parameters
func_source = inspect.getsource(self.fn)
key_data = {
"version": __version__,
"op_parameters": tuple(op_parameters),
"func_source": func_source,
"configs": self.configs,
"compile_args": hash(self.compile_args),
......@@ -242,12 +257,16 @@ class AutoTuner:
"""
_init_logger_handlers()
key = self.generate_cache_key()
sig = inspect.signature(self.fn)
parameters = sig.parameters
key = self.generate_cache_key(parameters)
with self._lock:
if is_cache_enabled():
# First check in-memory cache
if key in self._memory_cache:
self.logger.warning("Found kernel in memory cache. For better performance," \
logger.warning("Found kernel in memory cache. For better performance," \
" consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel.")
return self._memory_cache[key]
......@@ -258,8 +277,6 @@ class AutoTuner:
self._memory_cache[key] = result
return result
sig = inspect.signature(self.fn)
parameters = sig.parameters
best_latency: float = 1e8
best_config: Optional[Dict[str, Any]] = None
best_kernel: Optional[tilelang.JITKernel] = None
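The surrounding (partially elided) logic follows a common two-level caching pattern. The sketch below only captures the shape of it; load_from_disk_cache and run_tuning are stand-ins for the elided pieces, not real methods of AutoTuner.

import threading

class _TwoLevelCacheSketch:
    # Class-level lock and dict, mirroring the AutoTuner attributes above.
    _lock = threading.Lock()
    _memory_cache = {}

    def load_from_disk_cache(self, key):
        return None  # stand-in for the on-disk cache layer

    def run_tuning(self):
        return "tuned-kernel"  # stand-in for the actual tuning run

    def lookup_or_tune(self, key):
        with self._lock:
            if key in self._memory_cache:  # fast path: hit within the same process
                return self._memory_cache[key]
            result = self.load_from_disk_cache(key)
            if result is not None:
                self._memory_cache[key] = result
                return result
        result = self.run_tuning()  # tune outside the lock; it can be slow
        with self._lock:
            self._memory_cache[key] = result
        return result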
......@@ -343,9 +360,13 @@ class AutoTuner:
config_args = []
for config in self.configs:
new_kwargs = {}
keys = config.keys()
for name, _ in parameters.items():
if name in config:
new_kwargs[name] = config[name]
unused_keys = set(keys) - set(new_kwargs.keys())
if len(unused_keys) > 0:
raise ValueError(f"Unused keys in config: {unused_keys}")
config_args.append(new_kwargs)
num_workers = max(1, int(get_available_cpu_count() * 0.9))
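Restated as a standalone helper for clarity (build_config_args is an illustrative name, not a function in the repository): config keys that match no parameter of the tuned function now raise a ValueError instead of being silently dropped.

import inspect

def build_config_args(fn, configs):
    parameters = inspect.signature(fn).parameters
    config_args = []
    for config in configs:
        new_kwargs = {name: config[name] for name in parameters if name in config}
        unused_keys = set(config.keys()) - set(new_kwargs.keys())
        if unused_keys:
            raise ValueError(f"Unused keys in config: {unused_keys}")
        config_args.append(new_kwargs)
    return config_args

def kernel(block_M=64, block_N=64):
    pass

build_config_args(kernel, [{"block_M": 128, "block_N": 128}])   # accepted
# build_config_args(kernel, [{"block_K": 32}])                  # raises ValueError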
......@@ -461,8 +482,30 @@ class _AutoTunerImplementation:
rep: int = 100
timeout: int = 100
configs: Any = None
def __init__(self, configs: Any, warmup: int = 25, rep: int = 100, timeout: int = 100) -> None:
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto
ref_prog: Callable = None
supply_prog: Callable = None
rtol: float = 1e-2
atol: float = 1e-2
max_mismatched_ratio: float = 0.01
skip_check: bool = False
manual_check_prog: Callable = None
cache_input_tensors: bool = False
def __init__(self,
configs: Any,
warmup: int = 25,
rep: int = 100,
timeout: int = 100,
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
ref_prog: Callable = None,
supply_prog: Callable = None,
rtol: float = 1e-2,
atol: float = 1e-2,
max_mismatched_ratio: float = 0.01,
skip_check: bool = False,
manual_check_prog: Callable = None,
cache_input_tensors: bool = False) -> None:
"""Initialize the AutoTunerImplementation.
Args:
......@@ -507,8 +550,21 @@ class _AutoTunerImplementation:
def jit_compile(**config_arg):
return fn(*args, **kwargs, __tune_params=config_arg)
autotuner = AutoTuner(fn, configs=configs)
autotuner = AutoTuner(
fn, configs=configs).set_profile_args(
supply_type=self.supply_type,
ref_prog=self.ref_prog,
supply_prog=self.supply_prog,
rtol=self.rtol,
atol=self.atol,
max_mismatched_ratio=self.max_mismatched_ratio,
skip_check=self.skip_check,
manual_check_prog=self.manual_check_prog,
cache_input_tensors=self.cache_input_tensors,
)
autotuner.jit_compile = jit_compile
autotuner.set_kernel_parameters(key)
autotuner.run = partial(autotuner.run, warmup, rep, timeout)
artifact = autotuner.run()
......@@ -523,9 +579,21 @@ def autotune( # This is the new public interface
func: Union[Callable[_P, _RProg], PrimFunc, None] = None,
*, # Indicates subsequent arguments are keyword-only
configs: Any,
# profile arguments
warmup: int = 25,
rep: int = 100,
timeout: int = 100):
timeout: int = 100,
# compile arguments
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
ref_prog: Callable = None,
supply_prog: Callable = None,
rtol: float = 1e-2,
atol: float = 1e-2,
max_mismatched_ratio: float = 0.01,
skip_check: bool = False,
manual_check_prog: Callable = None,
cache_input_tensors: bool = False,
):
"""
Just-In-Time (JIT) compiler decorator for TileLang functions.
......@@ -569,5 +637,18 @@ def autotune( # This is the new public interface
# Create a _AutoTunerImplementation instance with the provided/defaulted arguments.
# This instance is a decorator that will be applied to the function later.
configured_decorator = _AutoTunerImplementation(
configs=configs, warmup=warmup, rep=rep, timeout=timeout)
configs=configs,
warmup=warmup,
rep=rep,
timeout=timeout,
supply_type=supply_type,
ref_prog=ref_prog,
supply_prog=supply_prog,
rtol=rtol,
atol=atol,
max_mismatched_ratio=max_mismatched_ratio,
skip_check=skip_check,
manual_check_prog=manual_check_prog,
cache_input_tensors=cache_input_tensors,
)
return configured_decorator
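A hypothetical usage sketch of the extended interface; the kernel body, the config values and the reference program are placeholders rather than code from this repository.

import tilelang

def ref_program(A, B):
    return A @ B

@tilelang.autotune(
    configs=[{"block_M": 64, "block_N": 64}, {"block_M": 128, "block_N": 64}],
    warmup=25,
    rep=100,
    ref_prog=ref_program,              # profile arguments now forwarded to the AutoTuner
    max_mismatched_ratio=0.01,
    cache_input_tensors=False,
)
def matmul(block_M: int = 64, block_N: int = 64):
    ...  # returns the TileLang program for the given tile sizes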
......@@ -355,6 +355,8 @@ class MatrixCoreIntrinEmitter(object):
BLOCK_M = block_row_warps * warp_rows
BLOCK_N = block_col_warps * warp_cols
M_DIM, N_DIM = self.M_DIM, self.N_DIM
C_buf_dims = len(C_buf.shape)
assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D"
# STS
# MMA store must be simulated instead of using TVM intrinsics
......@@ -366,9 +368,15 @@ class MatrixCoreIntrinEmitter(object):
for i, j in T.grid(warp_rows, warp_cols):
for local_id in T.vectorized(local_size_out):
row, col = T.meta_var(mfma_store_index_map(tx, local_id))
if C_buf_dims == 2:
C_buf[(warp_m * warp_rows + i) * M_DIM + row,
(warp_n * warp_cols + j) * N_DIM +
col] = C_local_buf[i * (warp_cols * local_size_out) +
j * local_size_out + local_id]
else:
C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row,
col] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out +
local_id]
col] = C_local_buf[i * warp_cols * local_size_out +
j * local_size_out + local_id]
@T.macro
def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
......
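A worked check of the two layouts handled in `_warp_stmatrix_shared` above: the 2D store address is just the row-major flattening of the 4D coordinate. All values below are chosen for illustration.

# Illustrative only; M_DIM = N_DIM = 16 is assumed for the MFMA tile here.
M_DIM = N_DIM = 16
warp_rows, warp_cols = 2, 2

def idx_2d(warp_m, warp_n, i, j, row, col):
    return ((warp_m * warp_rows + i) * M_DIM + row,
            (warp_n * warp_cols + j) * N_DIM + col)

def idx_4d(warp_m, warp_n, i, j, row, col):
    return (warp_m * warp_rows + i, warp_n * warp_cols + j, row, col)

r2 = idx_2d(1, 0, 1, 1, 5, 7)
r4 = idx_4d(1, 0, 1, 1, 5, 7)
assert r2 == (r4[0] * M_DIM + r4[2], r4[1] * N_DIM + r4[3])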
......@@ -430,8 +430,17 @@ class CythonKernelAdapter(BaseKernelAdapter):
def _convert_torch_func(self) -> Callable:
"""Returns a PyTorch-compatible function wrapper for the kernel."""
def lambda_forward(*args, stream: int = -1):
return self.cython_wrapper.forward([*args], stream=stream)
def lambda_forward(*args, stream: int = -1, skip_tensor_validation: bool = False):
"""
Args:
args: List of input tensors
stream: CUDA stream ID; defaults to -1, which uses the current stream
skip_tensor_validation: Whether to skip validation of tensor attributes
(shape, dtype, device, etc.)
"""
return self.cython_wrapper.forward([*args],
stream=stream,
skip_tensor_validation=skip_tensor_validation)
return lambda_forward
......
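A brief usage sketch of the new flag; `kernel` stands for a compiled TileLang kernel going through this Cython adapter, and A and B are placeholder torch tensors matching its signature.

out = kernel(A, B)                                # full device/dtype/shape validation
out = kernel(A, B, skip_tensor_validation=True)   # skip the checks on a hot path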
......@@ -23,6 +23,7 @@ cdef class CythonKernelWrapper:
list param_dtypes # Cache for parameter dtypes
list param_shapes # Cache for parameter shapes as native Python lists
object get_current_device
def __cinit__(self, result_idx, params, lib):
# Initialize wrapper with kernel configuration
self.result_idx = result_idx
......@@ -64,7 +65,30 @@ cdef class CythonKernelWrapper:
self.buffer_device_map = buffer_device_map
return self
cpdef forward(self, list inputs, int64_t stream = -1):
cpdef void _check_buffer_device(self, list tensor_list):
if isinstance(tensor_list[0], torch.Tensor):
tensor_list_device_type = tensor_list[0].device.type
for param, (buffer_idx, device) in self.buffer_device_map.items():
if isinstance(tensor_list[buffer_idx], torch.Tensor):
tensor_device = tensor_list[buffer_idx].device
if (tensor_list_device_type != device.type or
(tensor_device.index is not None and device.index is not None and tensor_device.index != device.index)):
raise ValueError(f"Buffer device mismatch for parameter {param}: expected {device}, got {tensor_device}")
cpdef void _check_buffer_dtype(self, list tensor_list):
for param, (buffer_idx, torch_dtype) in self.buffer_dtype_map.items():
if isinstance(tensor_list[buffer_idx], torch.Tensor):
if tensor_list[buffer_idx].dtype != torch_dtype:
raise ValueError(f"Buffer dtype mismatch for parameter {param}: expected {torch_dtype}, got {tensor_list[buffer_idx].dtype}")
cpdef void _check_static_shape(self, list tensor_list):
for param, (buffer_idx, shape_list) in self.static_shape_map.items():
if isinstance(tensor_list[buffer_idx], torch.Tensor):
for shape_idx, shape in shape_list:
if tensor_list[buffer_idx].shape[shape_idx] != shape:
raise ValueError(f"Static shape mismatch for parameter {param}: expected {shape} at index {shape_idx}, got {tensor_list[buffer_idx].shape}")
cpdef forward(self, list inputs, int64_t stream = -1, bint skip_tensor_validation = False):
# Validate input dimensions and prepare for kernel execution
cdef int total_params = len(self.params)
cdef int total_inputs = len(inputs)
......@@ -133,29 +157,10 @@ cdef class CythonKernelWrapper:
raise ValueError(f"Unsupported tensor type: {type(tensor)}")
# Check buffer device
# cdef str tensor_list_device_type = tensor_list[0].device.type
if isinstance(tensor_list[0], torch.Tensor):
tensor_list_device_type = tensor_list[0].device.type
for param, (buffer_idx, device) in self.buffer_device_map.items():
if isinstance(tensor_list[buffer_idx], torch.Tensor):
tensor_device = tensor_list[buffer_idx].device
# Compare device types and indices separately to handle both string and torch.device objects
if (tensor_list_device_type != device.type or
(tensor_device.index is not None and device.index is not None and tensor_device.index != device.index)):
raise ValueError(f"Buffer device mismatch for parameter {param}: expected {device}, got {tensor_device}")
# Check buffer dtype map
for param, (buffer_idx, torch_dtype) in self.buffer_dtype_map.items():
if isinstance(tensor_list[buffer_idx], torch.Tensor):
if tensor_list[buffer_idx].dtype != torch_dtype:
raise ValueError(f"Buffer dtype mismatch for parameter {param}: expected {torch_dtype}, got {tensor_list[buffer_idx].dtype}")
# Check static shape map
for param, (buffer_idx, shape_list) in self.static_shape_map.items():
if isinstance(tensor_list[buffer_idx], torch.Tensor):
for shape_idx, shape in shape_list:
if tensor_list[buffer_idx].shape[shape_idx] != shape:
raise ValueError(f"Static shape mismatch for parameter {param}: expected {shape} at index {shape_idx}, got {tensor_list[buffer_idx].shape}")
if not skip_tensor_validation:
self._check_buffer_device(tensor_list)
self._check_buffer_dtype(tensor_list)
self._check_static_shape(tensor_list)
# Add dynamic dimension values to kernel arguments
for _, (buffer_idx, shape_idx) in self.dynamic_symbolic_map.items():
......
......@@ -29,6 +29,8 @@ class LibraryGenerator(object):
def load_lib(self, lib_path: Optional[str] = None):
if lib_path is None:
lib_path = self.libpath
else:
self.libpath = lib_path
return ctypes.CDLL(lib_path)
def compile_lib(self, timeout: float = None):
......