Commit 9ba96f19 authored by Lei Wang, committed by GitHub

[Refactor] Set default log level from WARNING to INFO (#132)

* Change default log level from WARNING to INFO in TileLang initialization
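
  A minimal sketch of how downstream code can restore the quieter previous default after this change; it assumes the logger configured in `_init_logger` is named "tilelang", and whether `set_log_level` is re-exported at package level is an assumption:

  ```python
  import logging

  import tilelang  # importing the package runs _init_logger(), which now defaults to INFO

  # Standard-library route: raise the level on the package logger
  # (assumes the logger configured in _init_logger is named "tilelang").
  logging.getLogger("tilelang").setLevel(logging.WARNING)

  # Assumed alternative: if set_log_level is exposed at package level,
  # the old default can be restored directly.
  # tilelang.set_log_level("WARNING")
  ```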

* Refactor Flash Attention Variable-Length MHA Example with Cython Backend Support

- Update `example_mha_fwd_varlen.py` to use Cython backend for kernel compilation
- Remove unused imports and simplify function signature
- Modify `flashattn` function to handle max sequence length as a separate argument
- Update the kernel call to pass the max sequence length at runtime (see the sketch after this list)
- Improve code readability and remove commented-out code
- Add print statement to confirm successful assertion
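
  The resulting call pattern, sketched from the updated example (the unpadding setup that produces `q_unpad`, `cu_seqlens_q`, `max_seqlen_q`, etc. is elided; names match the script):

  ```python
  import tilelang

  # max_seqlen_q is no longer baked into the compiled program; it is passed to
  # the kernel at call time alongside the unpadded tensors and cu_seqlens buffers.
  program = flashattn(batch, UQ, UKV, heads, dim, causal)
  kernel = tilelang.compile(program, out_idx=-1, execution_backend="cython")
  print(kernel.get_kernel_source())
  out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
  ```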

* Refactor code formatting in TileLang lowering and example files

- Improve line breaks and code formatting in `lower.py`, `wrapper.py`, and `tensor.py`
- Simplify line breaks and reduce unnecessary whitespace
- Enhance code readability by adjusting indentation and line breaks
- Update example MHA forward pass script with cleaner tensor initialization
parent dd5d955c
@@ -3,7 +3,6 @@
 # ruff: noqa
 import torch
 import tilelang
-from tilelang.autotuner import *
 import tilelang.language as T
 import tilelang.testing
 import argparse
@@ -220,7 +219,7 @@ def attention_ref(
     return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)

-def flashattn(batch_size, UQ, UKV, heads, dim, is_causal, max_seqlen_q):
+def flashattn(batch_size, UQ, UKV, heads, dim, is_causal):
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     q_shape = [UQ, heads, dim]
     k_shape = [UKV, heads, dim]
@@ -243,6 +242,7 @@ def flashattn(batch_size, UQ, UKV, heads, dim, is_causal, max_seqlen_q):
             V_unpad: T.Buffer(v_shape, dtype),
             cu_seqlens_q: T.Buffer([batch_size + 1], "int32"),
             cu_seqlens_k: T.Buffer([batch_size + 1], "int32"),
+            max_seqlen_q: T.int32,
             Output_unpad: T.Buffer(o_shape, dtype),
     ):
         with T.Kernel(
@@ -382,16 +382,10 @@ if __name__ == "__main__":
     device = torch.device("cuda")
     window_size = (-1, -1)

-    # q = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
-    # k = torch.randn(
-    #     batch, seq_len, heads, dim, dtype=dtype, requires_grad=True
-    # ).to(device)
+    q = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
+    k = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
     v = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
-    q = torch.ones(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
-    k = torch.ones(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
-    # v = torch.ones(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)

     query_padding_mask = generate_random_padding_mask(seq_len, batch, device, mode="random")
     key_padding_mask = generate_random_padding_mask(seq_len, batch, device, mode="random")
     (
@@ -415,20 +409,11 @@ if __name__ == "__main__":
     UK = k_unpad.shape[0]  # unpadded key length
     UKV = k_unpad.shape[0]  # unpadded query key length

-    # TODO(lei): max_seqlen_q should be a dynamic argument.
-    program = flashattn(batch, UQ, UKV, heads, dim, causal, max_seqlen_q)
-    # print(program)
-    kernel = tilelang.compile(program, out_idx=-1)
-    # print(kernel.get_kernel_source())
-    profiler = kernel.get_profiler()
-    tilelang_latency = profiler.do_bench()
-    print(f"Tilelang latency: {tilelang_latency} ms")
-    # tflops
-    tflops = total_flops / tilelang_latency / 1e9
-    out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k)
+    program = flashattn(batch, UQ, UKV, heads, dim, causal)
+    kernel = tilelang.compile(program, out_idx=-1, execution_backend="cython")
+    print(kernel.get_kernel_source())
+    out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
     out = output_pad_fn(out_unpad)
     out_ref, _ = attention_ref(
@@ -451,6 +436,6 @@ if __name__ == "__main__":
         0.0,
         causal=causal,
     )
-    # TODO: Benchmark flash_attn and tilelang
     fla_out = output_pad_fn(fla_out_unpad)
     torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
+    print("Assert Equal Passed")
@@ -51,7 +51,7 @@ def _init_logger():
     handler.setFormatter(formatter)
     logger.addHandler(handler)
     logger.propagate = False
-    set_log_level("WARNING")
+    set_log_level("INFO")

_init_logger()
......
@@ -118,8 +118,13 @@ def tilelang_callback_hip_compile(code, target):

 def extrac_params(func: tir.PrimFunc):
-    buffers = [func.buffer_map[var] for var in func.params]
-    tensor_types = [relay.TensorType(buffer.shape, buffer.dtype) for buffer in buffers]
+    tensor_types = []
+    for var in func.params:
+        if var in func.buffer_map:
+            tensor_types.append(
+                relay.TensorType(func.buffer_map[var].shape, func.buffer_map[var].dtype))
+        else:
+            tensor_types.append(relay.scalar_type(var.dtype))
     return tensor_types
......
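
This hunk and the adapter/wrapper hunks below share one pattern: a PrimFunc parameter may now be a bare scalar variable (such as the `max_seqlen_q: T.int32` added in the example) with no entry in `buffer_map`, so each code path branches on `buffer_map` membership instead of assuming every parameter is a tensor handle. A hedged sketch of that distinction in plain TVMScript (the function and names below are illustrative, not part of the commit):

```python
import tvm
from tvm.script import tir as T


@T.prim_func
def copy_first_n(a: T.handle, b: T.handle, n: T.int32):
    # Two buffer parameters plus one scalar parameter; only `a` and `b`
    # receive entries in buffer_map, while `n` stays a bare tir.Var.
    A = T.match_buffer(a, (128,), "float32")
    B = T.match_buffer(b, (128,), "float32")
    for i in range(128):
        with T.block("copy"):
            vi = T.axis.spatial(128, i)
            if vi < n:
                B[vi] = A[vi]


for var in copy_first_n.params:
    if var in copy_first_n.buffer_map:
        buf = copy_first_n.buffer_map[var]
        print("buffer param:", buf.shape, buf.dtype)  # handled as a tensor type
    else:
        print("scalar param:", var.dtype)  # handled as a scalar type after this change
```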
@@ -22,7 +22,7 @@ import os
 from pathlib import Path
 import logging

-logger = logging.getLogger("tilelang")
+logger = logging.getLogger(__name__)

 def get_cython_compiler() -> Optional[str]:
@@ -198,6 +198,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
         buffer_map = func.buffer_map
         dynamic_symbolic_map = {}
         for i, param in enumerate(params):
+            if param in buffer_map:
                 buffer = buffer_map[param]
                 for j, shape in enumerate(buffer.shape):
                     if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map):
......
@@ -59,7 +59,15 @@ cdef class CythonKernelWrapper:
             tensor_list.append(tensor)

         # Convert tensor pointers to C void pointers for kernel call
-        call_args = [ctypes.c_void_p(tensor_list[i].data_ptr()) for i in range(len(tensor_list))]
+        call_args = []
+        for i in range(len(tensor_list)):
+            if isinstance(tensor_list[i], torch.Tensor):
+                call_args.append(ctypes.c_void_p(tensor_list[i].data_ptr()))
+            elif isinstance(tensor_list[i], int):
+                # Dynamic symbolics which are passed as integer arguments
+                call_args.append(tensor_list[i])
+            else:
+                raise ValueError(f"Unsupported tensor type: {type(tensor_list[i])}")

         # Add dynamic dimension values to kernel arguments
         for _, (buffer_idx, shape_idx) in self.dynamic_symbolic_map.items():
......
@@ -102,11 +102,17 @@ class TLCUDASourceWrapper(object):
         function_args = []
         # Collect function arguments based on primary function's parameters and buffer mappings
         for param in self.prim_func.params:
+            if param in self.prim_func.buffer_map:
                 buffer = self.prim_func.buffer_map[param]
                 function_args.append({
                     "name": buffer.name,
                     "type": self._TYPE_MAP[buffer.dtype] + "* __restrict__",
                 })
+            elif isinstance(param, tvm.tir.Var):
+                function_args.append({"name": param.name, "type": self._TYPE_MAP[param.dtype]})
+            else:
+                raise ValueError(
+                    f"Parameter {param} is not in the buffer map of the primary function.")
         # Add dynamic symbols as integer arguments
         for dyn_sym in dynamic_symbolic_set:
             function_args.append({"name": dyn_sym, "type": "int"})
@@ -284,6 +290,7 @@ class TLCUDASourceWrapper(object):
         # Determine the set of dynamic symbols used in the function
         dynamic_symbolic_set: List[str] = []
         for param in prim_func.params:
+            if param in prim_func.buffer_map:
                 buffer = prim_func.buffer_map[param]
                 for dim in buffer.shape:
                     if isinstance(dim, tvm.tir.Var) and (dim.name not in dynamic_symbolic_set):
......
@@ -52,6 +52,11 @@ def get_tensor_supply(supply_type: TensorSupplyType):
         dtype = map_torch_type(str(tensor.dtype))
         device = torch.cuda.current_device()

+        if hasattr(tensor, "shape") and not tensor.shape:
+            raise ValueError(
+                f"TensorType must have a shape, but got {type(tensor)}, "
+                "likely you are trying to generate a random tensor with a dynamic symbolic shape.")
+
         shape = list(map(int, tensor.shape))
         if supply_type == TensorSupplyType.Auto:
             if dtype == torch.float16 or dtype == torch.float32:
......