"git@developer.sourcefind.cn:yangql/composable_kernel.git" did not exist on "b2589957f3ceaa47e34c3fa8586d8b15021618da"
Commit ca730c0a authored by Lei Wang, committed by LeiWang1999

[Enhancement] Introduce a smarter warp partition strategy (#396)

* Make the code compatible with Python 3.8 and earlier

* [Enhancement] Improve loop partitioning and vectorization logic in layout inference and loop vectorization

- Enhanced the VisitStmt_ method to support local-buffer handling in parallel loops, allowing register usage without explicit thread binding.
- Updated the loop vectorization logic to simplify expressions and compute vector sizes accurately, improving both performance and the clarity of the vectorization pass (see the sketch below).
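For illustration only, here is a minimal Python sketch of the kind of vector-size calculation the second bullet refers to; the helper name and signature are hypothetical and this is not the tilelang implementation. The idea is to pick the largest power-of-two width that divides both the innermost loop extent and the number of contiguously aligned elements.

def max_vector_size(extent: int, aligned_elems: int, limit: int = 8) -> int:
    # Grow the vector width while it still divides both the loop extent
    # and the contiguous, aligned element count of the buffer access.
    vec = 1
    while vec * 2 <= limit and extent % (vec * 2) == 0 and aligned_elems % (vec * 2) == 0:
        vec *= 2
    return vec

assert max_vector_size(extent=64, aligned_elems=8) == 8   # fully vectorizable
assert max_vector_size(extent=24, aligned_elems=4) == 4   # limited by alignment
assert max_vector_size(extent=6, aligned_elems=8) == 2    # limited by extent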

* lint fix

* [Refactor] Update warp size checks and enhance warp partitioning logic in GEMM

- Changed the warp_n size check from 16 to 8 in gemm_layouts.cc so that narrower warp tiles along N are accepted.
- Refactored the warp partitioning logic in gemm.cc to prioritize the N dimension, based on the block's aspect ratio, for better performance (illustrated in the sketch after this list).
- Introduced a new CompileArgs dataclass in the autotuner to streamline compile-argument management and improve code clarity.
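As an illustrative sketch (not the committed C++), the new preference can be summarized in Python: when both dimensions can absorb a warp factor, grow the warp count along N first, since the MMA tile is typically narrower along N than along M. The helper name is hypothetical; only the decision rule mirrors the diff below.

def assign_warp_factor(M: int, N: int, m_warp: int, n_warp: int, factor: int):
    # Mirrors the updated ComputeWarpPartition branch ordering.
    M_divisible = M % (factor * m_warp) == 0
    N_divisible = N % (factor * n_warp) == 0
    if M_divisible and N_divisible:
        if N / n_warp >= M / m_warp:   # prefer N on ties and when N is larger
            n_warp *= factor
        else:
            m_warp *= factor
    elif N_divisible:
        n_warp *= factor
    elif M_divisible:
        m_warp *= factor
    else:
        raise ValueError(f"cannot partition {M}x{N} with factor {factor}")
    return m_warp, n_warp

# For a square 128x128 tile and a factor of 2, the old rule grew m_warp
# first, giving (2, 1); the new rule prefers (1, 2).
assert assign_warp_factor(128, 128, 1, 1, 2) == (1, 2)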

* lint fix

* [Enhancement] Initialize jit_compile in AutoTuner class

- Added initialization of the jit_compile attribute in the AutoTuner constructor so that it defaults to None.
- Updated the assignment logic so that an existing compile function is never overwritten, making the AutoTuner's compilation flow more flexible (a small illustration follows).
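A small self-contained illustration of the initialization behavior described above (simplified, not the actual AutoTuner code): jit_compile starts out as None, and the default compile closure is only installed if the user has not already supplied one.

class _TunerSketch:
    def __init__(self):
        self.jit_compile = None                 # explicit default, as in the commit

    def run(self, default_compile):
        if self.jit_compile is None:            # never overwrite a user-supplied function
            self.jit_compile = default_compile
        return self.jit_compile

tuner = _TunerSketch()
tuner.jit_compile = lambda cfg: f"custom({cfg})"
assert tuner.run(lambda cfg: f"default({cfg})")("cfg0") == "custom(cfg0)"

The same pattern appears in AutoTuner.run() in the diff below: _compile is defined inside run() and assigned only while self.jit_compile is still None.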
parent 8c5b1341
@@ -81,7 +81,7 @@ Fragment makeGemmFragmentC_F64(const int block_m, const int block_n,
   ICHECK(block_m % warp_m == 0);
   ICHECK(block_n % warp_n == 0);
   ICHECK(warp_m % 16 == 0);
-  ICHECK(warp_n % 16 == 0);
+  ICHECK(warp_n % 8 == 0);
   auto base_layout = makeGemmFragment8x8();
   auto warp_layout =
       base_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
@@ -98,7 +98,7 @@ Fragment makeGemmFragmentC(const int block_m, const int block_n,
   ICHECK(block_m % warp_m == 0);
   ICHECK(block_n % warp_n == 0);
   ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;
-  ICHECK(warp_n % 16 == 0) << "warp_n=" << warp_n;
+  ICHECK(warp_n % 8 == 0) << "warp_n=" << warp_n;
   auto base_layout = makeGemmFragment8x8()->Repeat({2, 1}, false);
   auto warp_layout =
       base_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
...
@@ -87,14 +87,17 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
       bool M_divisible = (this->M % (factor * m_warp)) == 0;
       bool N_divisible = (this->N % (factor * n_warp)) == 0;
       if (M_divisible && N_divisible) {
-        if (this->M / m_warp >= this->N / n_warp)
-          m_warp *= factor;
-        else
+        // put N dimension first
+        // because usually n in mma
+        // is more smaller than m
+        if (this->N / n_warp >= this->M / m_warp)
           n_warp *= factor;
-      } else if (M_divisible) {
-        m_warp *= factor;
+        else
+          m_warp *= factor;
       } else if (N_divisible) {
         n_warp *= factor;
+      } else if (M_divisible) {
+        m_warp *= factor;
       } else {
         ICHECK(0) << "Cannot compute warp partition for shape" << M << " " << N
                   << " with num_warps " << num_warps;
@@ -103,7 +106,6 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
   } else {
     ICHECK(0) << "Unknown GemmWarpPolicy";
   }
-  // TODO: perform more checks here
   return {m_warp, n_warp};
 }
...
@@ -87,6 +87,39 @@ class AutotuneResult:
     kernel: Callable
 
 
+@dataclass(frozen=True)
+class CompileArgs:
+    """Compile arguments for the auto-tuner.
+
+    Attributes:
+        out_idx: List of output tensor indices.
+        supply_type: Type of tensor supply mechanism.
+        ref_prog: Reference program for correctness validation.
+        supply_prog: Supply program for input tensors.
+    """
+
+    out_idx: Union[List[int], int] = -1
+    supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto
+    ref_prog: Callable = None
+    supply_prog: Callable = None
+    rtol: float = 1e-2
+    atol: float = 1e-2
+    max_mismatched_ratio: float = 0.01
+    skip_check: bool = False
+    cache_input_tensors: bool = True
+    target: Literal['auto', 'cuda', 'hip'] = 'auto'
+
+
 class AutoTuner:
     """Auto-tuner for tilelang programs.
@@ -104,6 +137,8 @@ class AutoTuner:
         self.ref_latency_cache = None
         self.jit_input_tensors = None
         self.ref_input_tensors = None
+        self.jit_compile = None
+        self.compile_args = CompileArgs()
 
     @classmethod
     def from_kernel(cls, kernel: Callable, configs):
@@ -146,6 +181,17 @@ class AutoTuner:
         Returns:
             AutoTuner: Self for method chaining.
         """
+        self.compile_args = CompileArgs(
+            out_idx=out_idx,
+            supply_type=supply_type,
+            ref_prog=ref_prog,
+            supply_prog=supply_prog,
+            rtol=rtol,
+            atol=atol,
+            max_mismatched_ratio=max_mismatched_ratio,
+            skip_check=skip_check,
+            cache_input_tensors=cache_input_tensors,
+            target=target)
 
         # If a custom `supply_prog`` is provided, the profiler's `supply_type` setting
         # becomes ineffective. The custom supply program will be used instead.
@@ -153,23 +199,6 @@ class AutoTuner:
             logger.warning("Ignoring `supply_type` passed to `set_compile_args` because "
                            "`ref_prog` is not None.")
 
-        def _compile(*config_arg):
-            kernel = tilelang.compile(self.fn(*config_arg), out_idx=out_idx, target=target)
-            jit_context = JITContext(
-                out_idx=out_idx,
-                ref_prog=ref_prog,
-                supply_prog=supply_prog,
-                rtol=rtol,
-                atol=atol,
-                max_mismatched_ratio=max_mismatched_ratio,
-                skip_check=skip_check,
-                cache_input_tensors=cache_input_tensors,
-                kernel=kernel,
-                supply_type=supply_type,
-                target=target)
-            return jit_context
-
-        self.jit_compile = _compile
         return self
 
     def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30):
@@ -191,6 +220,27 @@ class AutoTuner:
         best_config = None
         best_jit_context = None
 
+        def _compile(*config_arg):
+            compile_args = self.compile_args
+            kernel = tilelang.compile(
+                self.fn(*config_arg), out_idx=compile_args.out_idx, target=compile_args.target)
+            jit_context = JITContext(
+                out_idx=compile_args.out_idx,
+                ref_prog=compile_args.ref_prog,
+                supply_prog=compile_args.supply_prog,
+                rtol=compile_args.rtol,
+                atol=compile_args.atol,
+                max_mismatched_ratio=compile_args.max_mismatched_ratio,
+                skip_check=compile_args.skip_check,
+                cache_input_tensors=compile_args.cache_input_tensors,
+                kernel=kernel,
+                supply_type=compile_args.supply_type,
+                target=compile_args.target)
+            return jit_context
+
+        if self.jit_compile is None:
+            self.jit_compile = _compile
+
         def target_fn(jit_context: JITContext):
             # Unpack the context
             kernel = jit_context.kernel
...
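For context, here is a hedged end-to-end usage sketch of the refactored flow: set_compile_args now only records a CompileArgs snapshot, and the JIT compile closure is built lazily inside run(). Only from_kernel, set_compile_args, run, and AutotuneResult.kernel appear in this diff; my_kernel, configs, ref_program, and the import path are hypothetical placeholders, and defaults may differ.

import tilelang
from tilelang.autotuner import AutoTuner   # import path assumed

def my_kernel(block_M: int, block_N: int):
    ...  # hypothetical: returns a tilelang program parameterized by tuning knobs

def ref_program(*tensors):
    ...  # hypothetical reference implementation for correctness checks

configs = [(64, 64), (128, 64)]            # hypothetical tuning space

tuner = AutoTuner.from_kernel(kernel=my_kernel, configs=configs)
tuner.set_compile_args(                    # stores CompileArgs; no compilation happens here
    out_idx=-1,
    ref_prog=ref_program,
    supply_type=tilelang.TensorSupplyType.Auto,
    rtol=1e-2,
    atol=1e-2,
    target="auto",
)
result = tuner.run(warmup=25, rep=100)     # _compile is created here if none was set
best_kernel = result.kernel                # assuming run() returns an AutotuneResult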
@@ -104,10 +104,10 @@ class GemmWarpPolicy(IntEnum):
         # Assign the factor to either m_warp or n_warp based on divisibility and aspect ratio.
         if M_divisible and N_divisible:
             # Prefer to assign to rows if M is larger, otherwise to columns.
-            if M / m_warp >= N / n_warp:
-                m_warp *= factor
-            else:
+            if N / n_warp >= M / m_warp:
                 n_warp *= factor
+            else:
+                m_warp *= factor
         elif M_divisible:
             m_warp *= factor
         elif N_divisible:
...