Unverified commit 407117e1, authored by Lei Wang and committed via GitHub.
Browse files

[Layout] Introduce a new layout inference mechanism (#699)



* Implement new free stage layout inference.

* Fix bug

* Make replication upcasting and unnormalizable iterators safe.

* Better handling of updating with more replica

* Remove unnecessary check.

* Fix compilation.

* Fix setup.py.

* Simplify development mode.

* Allow ParallelOp layout when there's already a compatible layout specified

* lint fix

* Add ProveFragmentContains function to validate thread access between small and large fragments

This function checks if the threads accessing elements of a smaller fragment are a subset of those accessing a larger fragment, ensuring valid access during updates. The implementation includes deriving thread indices, computing logical indices, and verifying thread mappings.

* Update dependencies in requirements files

* Remove 'thefuzz' from requirements-dev.txt
* Specify exact versions for 'torch' and add 'flash_attn' in requirements-test.txt

* Update CI workflow to use SHA256 hash for requirements file

* Update requirements and CI workflow for flash attention

* Removed specific version for 'torch' in requirements-test.txt
* Added installation of 'flash_attn==2.5.8' in CI workflow to ensure compatibility

* Refactor flash attention import handling in examples

* Removed availability checks for 'flash_attn' in multiple example scripts.
* Simplified import statements for 'flash_attn' to ensure consistent usage across examples.

---------
Co-authored-by: Huanqi Cao <caohuanqi@deepseek.com>
parent 87aae294
...@@ -26,7 +26,7 @@ jobs: ...@@ -26,7 +26,7 @@ jobs:
- name: Ensure venv (local & persistent) - name: Ensure venv (local & persistent)
run: | run: |
set -e set -e
REQS_HASH=$(cat requirements-test.txt 2>/dev/null || true) REQS_HASH=$(sha256sum requirements-test.txt 2>/dev/null | awk '{print $1}' || echo "no_requirements")
MARKER="${{ runner.tool_cache }}/.venv_marker_${{ env.PYTHON_VERSION }}_${REQS_HASH:0:8}" MARKER="${{ runner.tool_cache }}/.venv_marker_${{ env.PYTHON_VERSION }}_${REQS_HASH:0:8}"
if [[ -f "$MARKER" ]] && [[ -f "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" ]]; then if [[ -f "$MARKER" ]] && [[ -f "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" ]]; then
...@@ -40,6 +40,7 @@ jobs: ...@@ -40,6 +40,7 @@ jobs:
python -m pip install --upgrade pip --no-user python -m pip install --upgrade pip --no-user
[[ -f requirements-test.txt ]] && \ [[ -f requirements-test.txt ]] && \
PIP_NO_BUILD_ISOLATION=1 pip install -r requirements-test.txt --no-user PIP_NO_BUILD_ISOLATION=1 pip install -r requirements-test.txt --no-user
pip install flash_attn==2.5.8 --no-user --no-build-isolation
touch "$MARKER" touch "$MARKER"
fi fi
...@@ -94,6 +95,8 @@ jobs: ...@@ -94,6 +95,8 @@ jobs:
python -m pip install --upgrade pip --no-user python -m pip install --upgrade pip --no-user
[[ -f requirements-test.txt ]] && \ [[ -f requirements-test.txt ]] && \
PIP_NO_BUILD_ISOLATION=1 pip install -r requirements-test.txt --no-user PIP_NO_BUILD_ISOLATION=1 pip install -r requirements-test.txt --no-user
# flash attention usually requires no isolation build
pip install flash_attn==2.5.8 --no-user --no-build-isolation
pip install . --no-user pip install . --no-user
touch "$MARKER" touch "$MARKER"
fi fi
......
...@@ -517,20 +517,11 @@ def main(args): ...@@ -517,20 +517,11 @@ def main(args):
output_sparse = sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, output_sparse = sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens,
block_table) block_table)
is_flash_attn_2_available = False import flash_attn # noqa: F401
try:
import flash_attn # noqa: F401
is_flash_attn_2_available = True
except:
pass
output_ref_torch = ref_program_torch_paged(Q, K_cache, V_cache, block_indices, cache_seqlens, output_ref_torch = ref_program_torch_paged(Q, K_cache, V_cache, block_indices, cache_seqlens,
block_table, page_block_size, block_N) block_table, page_block_size, block_N)
if not is_flash_attn_2_available:
print("FlashAttn 2 is not available, skipping FA reference and performance measurement")
return
output_ref_fa = ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table) output_ref_fa = ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table)
# Check correctness # Check correctness
if sparse_ratio == 0.0: if sparse_ratio == 0.0:
......
...@@ -439,16 +439,7 @@ def main(batch=8, ...@@ -439,16 +439,7 @@ def main(batch=8,
out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) out = sparse_kernel(Q, K, V, block_indices, cache_seqlens)
debug("output", ref, out, atol=1e-3, rtol=1e-3) debug("output", ref, out, atol=1e-3, rtol=1e-3)
is_flash_attn_2_available = False import flash_attn # noqa: F401
try:
import flash_attn # noqa: F401
is_flash_attn_2_available = True
except ImportError:
pass
if not is_flash_attn_2_available:
print("FlashAttn 2 is not available, skipping FA reference and performance measurement")
return
## latency reference ## latency reference
for _ in range(10): for _ in range(10):
......
...@@ -419,16 +419,7 @@ def main(batch=8, ...@@ -419,16 +419,7 @@ def main(batch=8,
out = model(Q, K, V, block_mask, cache_seqlens) out = model(Q, K, V, block_mask, cache_seqlens)
debug("output", ref, out, atol=1e-3, rtol=1e-3) debug("output", ref, out, atol=1e-3, rtol=1e-3)
is_flash_attn_2_available = False import flash_attn # noqa: F401
try:
import flash_attn # noqa: F401
is_flash_attn_2_available = True
except ImportError:
pass
if not is_flash_attn_2_available:
print("FlashAttn 2 is not available, skipping FA reference and performance measurement")
return
## latency reference ## latency reference
for _ in range(10): for _ in range(10):
......
...@@ -449,16 +449,7 @@ def main(batch=64, ...@@ -449,16 +449,7 @@ def main(batch=64,
print(f"Average time: {avg_time:.6f} seconds") print(f"Average time: {avg_time:.6f} seconds")
# Measure performance of reference implementation # Measure performance of reference implementation
is_flash_attn_2_available = False import flash_attn # noqa: F401
try:
import flash_attn # noqa: F401
is_flash_attn_2_available = True
except ImportError:
pass
if not is_flash_attn_2_available:
print("FlashAttn 2 is not available, skipping FA reference and performance measurement")
return
start = time.time() start = time.time()
for _ in range(1000): for _ in range(1000):
......
...@@ -429,17 +429,7 @@ def main(batch=64, ...@@ -429,17 +429,7 @@ def main(batch=64,
print(f"Average time: {avg_time:.6f} seconds") print(f"Average time: {avg_time:.6f} seconds")
print(f"Average flops: {avg_flops:.2f} GFLOPS") print(f"Average flops: {avg_flops:.2f} GFLOPS")
is_flash_attn_2_available = False import flash_attn # noqa: F401
try:
import flash_attn # noqa: F401
is_flash_attn_2_available = True
except ImportError:
pass
# Measure performance of reference implementation
if not is_flash_attn_2_available:
print("FlashAttn 2 is not available, skipping FA reference and performance measurement")
return
start = time.time() start = time.time()
for _ in range(1000): for _ in range(1000):
......
...@@ -412,16 +412,7 @@ def main(batch: int = 2, heads: int = 16, seq_len: int = 256, dim: int = 32): ...@@ -412,16 +412,7 @@ def main(batch: int = 2, heads: int = 16, seq_len: int = 256, dim: int = 32):
) )
torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
is_flash_attn_2_available = False import flash_attn
try:
import flash_attn
is_flash_attn_2_available = True
except:
pass
if not is_flash_attn_2_available:
print("FlashAttn 2 is not available, skipping FA reference and performance measurement")
return
fla_out_unpad = flash_attn.flash_attn_varlen_func( fla_out_unpad = flash_attn.flash_attn_varlen_func(
q_unpad, q_unpad,
......
...@@ -21,7 +21,6 @@ ml_dtypes ...@@ -21,7 +21,6 @@ ml_dtypes
psutil psutil
scipy scipy
torch torch
thefuzz
tabulate tabulate
wheel wheel
setuptools setuptools
\ No newline at end of file
...@@ -21,7 +21,6 @@ cloudpickle ...@@ -21,7 +21,6 @@ cloudpickle
ml_dtypes ml_dtypes
psutil psutil
torch torch
thefuzz
tabulate tabulate
wheel wheel
setuptools setuptools
......
...@@ -4,8 +4,6 @@ import shutil ...@@ -4,8 +4,6 @@ import shutil
from setuptools import setup, find_packages, Extension from setuptools import setup, find_packages, Extension
from setuptools.command.build_py import build_py from setuptools.command.build_py import build_py
from setuptools.command.sdist import sdist from setuptools.command.sdist import sdist
from setuptools.command.develop import develop
import distutils.dir_util
from typing import List, Optional from typing import List, Optional
import re import re
import tarfile import tarfile
...@@ -18,7 +16,7 @@ import hashlib ...@@ -18,7 +16,7 @@ import hashlib
import sysconfig import sysconfig
import functools import functools
import urllib.request import urllib.request
from distutils.version import LooseVersion from packaging.version import Version
import platform import platform
import multiprocessing import multiprocessing
from setuptools.command.build_ext import build_ext from setuptools.command.build_ext import build_ext
...@@ -117,7 +115,7 @@ def get_nvcc_cuda_version(): ...@@ -117,7 +115,7 @@ def get_nvcc_cuda_version():
nvcc_output = subprocess.check_output(["nvcc", "-V"], universal_newlines=True) nvcc_output = subprocess.check_output(["nvcc", "-V"], universal_newlines=True)
output = nvcc_output.split() output = nvcc_output.split()
release_idx = output.index("release") + 1 release_idx = output.index("release") + 1
nvcc_cuda_version = LooseVersion(output[release_idx].split(",")[0]) nvcc_cuda_version = Version(output[release_idx].split(",")[0])
return nvcc_cuda_version return nvcc_cuda_version
...@@ -128,7 +126,7 @@ def get_rocm_version(): ...@@ -128,7 +126,7 @@ def get_rocm_version():
# Example output: ROCM version: x.y.z-... # Example output: ROCM version: x.y.z-...
match = re.search(r'ROCm Version: (\d+\.\d+\.\d+)', rocm_output) match = re.search(r'ROCm Version: (\d+\.\d+\.\d+)', rocm_output)
if match: if match:
return LooseVersion(match.group(1)) return Version(match.group(1))
else: else:
rocm_path = os.environ.get("ROCM_PATH", "/opt/rocm") rocm_path = os.environ.get("ROCM_PATH", "/opt/rocm")
rocm_version_file = os.path.join(rocm_path, "lib", "cmake", "rocm", rocm_version_file = os.path.join(rocm_path, "lib", "cmake", "rocm",
...@@ -138,9 +136,9 @@ def get_rocm_version(): ...@@ -138,9 +136,9 @@ def get_rocm_version():
content = f.read() content = f.read()
match = re.search(r'set\(PACKAGE_VERSION "(\d+\.\d+\.\d+)"', content) match = re.search(r'set\(PACKAGE_VERSION "(\d+\.\d+\.\d+)"', content)
if match: if match:
return LooseVersion(match.group(1)) return Version(match.group(1))
# return a default # return a default
return LooseVersion("5.0.0") return Version("5.0.0")
def get_tilelang_version(with_cuda=True, with_system_info=True, with_commit_id=False) -> str: def get_tilelang_version(with_cuda=True, with_system_info=True, with_commit_id=False) -> str:
...@@ -418,7 +416,7 @@ class TileLangBuilPydCommand(build_py): ...@@ -418,7 +416,7 @@ class TileLangBuilPydCommand(build_py):
target_dir = os.path.join(self.build_lib, item) target_dir = os.path.join(self.build_lib, item)
if os.path.isdir(source_dir): if os.path.isdir(source_dir):
self.mkpath(target_dir) self.mkpath(target_dir)
distutils.dir_util.copy_tree(source_dir, target_dir) self.copy_tree(source_dir, target_dir)
else: else:
target_dir = os.path.dirname(target_dir) target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir): if not os.path.exists(target_dir):
...@@ -434,7 +432,7 @@ class TileLangBuilPydCommand(build_py): ...@@ -434,7 +432,7 @@ class TileLangBuilPydCommand(build_py):
target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item) target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item)
if os.path.isdir(source_dir): if os.path.isdir(source_dir):
self.mkpath(target_dir) self.mkpath(target_dir)
distutils.dir_util.copy_tree(source_dir, target_dir) self.copy_tree(source_dir, target_dir)
else: else:
target_dir = os.path.dirname(target_dir) target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir): if not os.path.exists(target_dir):
...@@ -511,7 +509,7 @@ class TileLangBuilPydCommand(build_py): ...@@ -511,7 +509,7 @@ class TileLangBuilPydCommand(build_py):
target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item) target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item)
if os.path.isdir(source_dir): if os.path.isdir(source_dir):
self.mkpath(target_dir) self.mkpath(target_dir)
distutils.dir_util.copy_tree(source_dir, target_dir) self.copy_tree(source_dir, target_dir)
else: else:
target_dir = os.path.dirname(target_dir) target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir): if not os.path.exists(target_dir):
...@@ -528,7 +526,7 @@ class TileLangBuilPydCommand(build_py): ...@@ -528,7 +526,7 @@ class TileLangBuilPydCommand(build_py):
target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item) target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item)
if os.path.isdir(source_dir): if os.path.isdir(source_dir):
self.mkpath(target_dir) self.mkpath(target_dir)
distutils.dir_util.copy_tree(source_dir, target_dir) self.copy_tree(source_dir, target_dir)
else: else:
target_dir = os.path.dirname(target_dir) target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir): if not os.path.exists(target_dir):
...@@ -544,7 +542,7 @@ class TileLangBuilPydCommand(build_py): ...@@ -544,7 +542,7 @@ class TileLangBuilPydCommand(build_py):
target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item) target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item)
if os.path.isdir(source_dir): if os.path.isdir(source_dir):
self.mkpath(target_dir) self.mkpath(target_dir)
distutils.dir_util.copy_tree(source_dir, target_dir) self.copy_tree(source_dir, target_dir)
else: else:
target_dir = os.path.dirname(target_dir) target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir): if not os.path.exists(target_dir):
...@@ -570,7 +568,7 @@ class TileLangBuilPydCommand(build_py): ...@@ -570,7 +568,7 @@ class TileLangBuilPydCommand(build_py):
if os.path.isdir(source_dir): if os.path.isdir(source_dir):
self.mkpath(target_dir) self.mkpath(target_dir)
distutils.dir_util.copy_tree(source_dir, target_dir) self.copy_tree(source_dir, target_dir)
else: else:
target_dir = os.path.dirname(target_dir) target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir): if not os.path.exists(target_dir):
...@@ -588,54 +586,6 @@ class TileLangSdistCommand(sdist): ...@@ -588,54 +586,6 @@ class TileLangSdistCommand(sdist):
super().make_distribution() super().make_distribution()
# ------------------------------------------------------------------------
# NEW: Add a custom 'develop' command so that `pip install -e .` works.
# ------------------------------------------------------------------------
class TileLangDevelopCommand(develop):
"""
Customized setuptools 'develop' command for an editable install.
Ensures the extension is built and all necessary assets are copied.
"""
def run(self):
logger.info("Running TileLangDevelopCommand")
# 1. Build the C/C++ extension modules
self.run_command("build_ext")
build_ext_cmd = self.get_finalized_command("build_ext")
ext_modules = build_ext_cmd.extensions
for ext in ext_modules:
extdir = build_ext_cmd.get_ext_fullpath(ext.name)
logger.info(f"Extension {ext.name} output directory: {extdir}")
ext_output_dir = os.path.dirname(extdir)
logger.info(f"Extension output directory (parent): {ext_output_dir}")
# Copy the built TVM to the package directory
TVM_PREBUILD_ITEMS = [
f"{ext_output_dir}/libtvm_runtime.so",
f"{ext_output_dir}/libtvm.so",
f"{ext_output_dir}/libtilelang.so",
f"{ext_output_dir}/libtilelang_module.so",
]
for item in TVM_PREBUILD_ITEMS:
source_lib_file = os.path.join(ROOT_DIR, item)
# only copy the file
file_name = os.path.basename(item)
target_dir = os.path.join(PACKAGE_NAME, file_name)
target_dir = os.path.dirname(target_dir)
target_dir = os.path.join(target_dir, "lib")
if not os.path.exists(target_dir):
os.makedirs(target_dir)
if os.path.exists(source_lib_file):
patch_libs(source_lib_file)
shutil.copy2(source_lib_file, target_dir)
# remove the original file
os.remove(source_lib_file)
else:
logger.info(f"INFO: {source_lib_file} does not exist.")
class CMakeExtension(Extension): class CMakeExtension(Extension):
""" """
A specialized setuptools Extension class for building a CMake project. A specialized setuptools Extension class for building a CMake project.
...@@ -811,18 +761,31 @@ class TilelangExtensionBuild(build_ext): ...@@ -811,18 +761,31 @@ class TilelangExtensionBuild(build_ext):
# Determine the directory where the final .so or .pyd library should go. # Determine the directory where the final .so or .pyd library should go.
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
# To make it compatible with in-place build and avoid redundant link during incremental build,
# we need to change the build destination to tilelang/lib, where it's actually loaded
if self.inplace:
extdir = os.path.abspath('./tilelang/lib/')
logger.info(f"{extdir=}")
# Prepare arguments for the CMake configuration step. # Prepare arguments for the CMake configuration step.
# -DCMAKE_LIBRARY_OUTPUT_DIRECTORY sets where built libraries go # -DCMAKE_LIBRARY_OUTPUT_DIRECTORY sets where built libraries go
# -DPYTHON_EXECUTABLE ensures that the correct Python is used # -DPYTHON_EXECUTABLE ensures that the correct Python is used
cmake_args = [ cmake_args = [
f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}", f"-DPython_EXECUTABLE={sys.executable}", f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}",
f"-DCMAKE_BUILD_TYPE={'Debug' if DEBUG_MODE else 'Release'}" f"-DPython_EXECUTABLE={sys.executable}",
f"-DCMAKE_BUILD_TYPE={'Debug' if DEBUG_MODE else 'Release'}",
"-G",
"Ninja",
] ]
if not USE_ROCM: if not USE_ROCM:
cmake_args.append(f"-DCMAKE_CUDA_COMPILER={os.path.join(CUDA_HOME, 'bin', 'nvcc')}") cmake_args.append(f"-DCMAKE_CUDA_COMPILER={os.path.join(CUDA_HOME, 'bin', 'nvcc')}")
# Create the temporary build directory (if it doesn't exist). # Create the temporary build directory (if it doesn't exist).
build_temp = os.path.abspath(self.build_temp) if self.inplace:
build_temp = os.path.abspath('./build')
else:
build_temp = os.path.abspath(self.build_temp)
os.makedirs(build_temp, exist_ok=True) os.makedirs(build_temp, exist_ok=True)
# Copy the default 'config.cmake' from the source tree into our build directory. # Copy the default 'config.cmake' from the source tree into our build directory.
...@@ -884,6 +847,5 @@ setup( ...@@ -884,6 +847,5 @@ setup(
"build_py": TileLangBuilPydCommand, "build_py": TileLangBuilPydCommand,
"sdist": TileLangSdistCommand, "sdist": TileLangSdistCommand,
"build_ext": TilelangExtensionBuild, "build_ext": TilelangExtensionBuild,
"develop": TileLangDevelopCommand,
}, },
) )
...@@ -124,7 +124,11 @@ Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr> &exprs, ...@@ -124,7 +124,11 @@ Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr> &exprs,
Array<IterSplitExpr> results; Array<IterSplitExpr> results;
for (const IterMark &mark : collector.visited_) { for (const IterMark &mark : collector.visited_) {
ICHECK(mark->source.as<Var>()) << "Not a normalized iterator: " << mark; if (!mark->source.as<Var>()) {
std::ostringstream oss;
oss << "Not a normalized iterator: " << mark;
throw NormalizeIterException(oss.str());
}
} }
for (const IterVar &iter : input_iters) { for (const IterVar &iter : input_iters) {
......
...@@ -14,6 +14,15 @@ namespace tl { ...@@ -14,6 +14,15 @@ namespace tl {
using namespace tir; using namespace tir;
class NormalizeIterException : public std::exception {
public:
const char *what() const noexcept override { return msg_.c_str(); }
NormalizeIterException(const std::string &msg) : msg_(msg) {}
private:
std::string msg_;
};
/*! /*!
* \brief Collect the IterSplit that is not used in expr. * \brief Collect the IterSplit that is not used in expr.
* *
......
...@@ -23,6 +23,19 @@ public: ...@@ -23,6 +23,19 @@ public:
static const Op &Get(); static const Op &Get();
AtomicAdd(const AtomicAdd &other)
: args_(other.args_), src(other.src), dst(other.dst),
src_range(other.src_range), dst_range(other.dst_range),
coalesced_width(other.coalesced_width) {
// No clone nullptr
if (other.par_op_)
par_op_ = std::unique_ptr<ParallelOp>(
static_cast<ParallelOp *>(other.par_op_->Clone().release()));
}
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<AtomicAdd>(*this);
}
protected: protected:
For MakeSIMTLoop(arith::Analyzer *analyzer) const; For MakeSIMTLoop(arith::Analyzer *analyzer) const;
Array<IterVar> MakeIterVars() const; Array<IterVar> MakeIterVars() const;
......
...@@ -51,6 +51,10 @@ public: ...@@ -51,6 +51,10 @@ public:
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final; Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
static const Op &Get(); static const Op &Get();
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<Conv2DIm2ColOp>(*this);
}
private: private:
Buffer src, dst; Buffer src, dst;
int stride, padding, dilation, kernel; int stride, padding, dilation, kernel;
......
...@@ -373,20 +373,6 @@ LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -373,20 +373,6 @@ LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) {
arith::Analyzer analyzer; arith::Analyzer analyzer;
par_op_ = std::make_unique<ParallelOp>(MakeSIMTLoop(&analyzer)); par_op_ = std::make_unique<ParallelOp>(MakeSIMTLoop(&analyzer));
} }
if (T.layout_map.count(src) && T.layout_map.count(dst)) {
// Only compare fragment layout
if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") {
const auto &src_layout = T.layout_map[src].as<Fragment>();
const auto &dst_layout = T.layout_map[dst].as<Fragment>();
if (src_layout && dst_layout) {
ICHECK((*src_layout)->IsEqual(dst_layout->get(), true))
<< "Get different layout for " << src << " and " << dst
<< "\nLHS = " << (*src_layout)->DebugOutput()
<< "\nRHS = " << (*dst_layout)->DebugOutput()
<< "\nYou may need to use a shared memory to transform the layout";
}
}
}
return par_op_->InferLayout(T, level); return par_op_->InferLayout(T, level);
} }
......
...@@ -23,6 +23,19 @@ public: ...@@ -23,6 +23,19 @@ public:
static const Op &Get(); static const Op &Get();
Copy(const Copy &other)
: args_(other.args_), src(other.src), dst(other.dst),
src_range(other.src_range), dst_range(other.dst_range),
coalesced_width(other.coalesced_width), disable_tma(other.disable_tma) {
// No clone nullptr
if (other.par_op_)
par_op_ = std::unique_ptr<ParallelOp>(
static_cast<ParallelOp *>(other.par_op_->Clone().release()));
}
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<Copy>(*this);
}
protected: protected:
Stmt LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const; Stmt LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
Stmt LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const; Stmt LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
...@@ -53,6 +66,10 @@ public: ...@@ -53,6 +66,10 @@ public:
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final; Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
static const Op &Get(); static const Op &Get();
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<Fill>(*this);
}
private: private:
For MakeSIMTLoop(arith::Analyzer *analyzer) const; For MakeSIMTLoop(arith::Analyzer *analyzer) const;
tir::Buffer dst; tir::Buffer dst;
......
...@@ -26,6 +26,10 @@ public: ...@@ -26,6 +26,10 @@ public:
kFullCol = 2, kFullCol = 2,
} policy; } policy;
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<Gemm>(*this);
}
private: private:
// Target GEMM instruction // Target GEMM instruction
enum class GemmInst { kMMA, kWGMMA, kUTCMMA, kMFMA }; enum class GemmInst { kMMA, kWGMMA, kUTCMMA, kMFMA };
......
...@@ -26,6 +26,10 @@ public: ...@@ -26,6 +26,10 @@ public:
kFullCol = 2, kFullCol = 2,
} policy; } policy;
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<GemmSP>(*this);
}
private: private:
std::pair<int, int> std::pair<int, int>
ComputeWarpPartition(int num_warps, Target target, ComputeWarpPartition(int num_warps, Target target,
......
...@@ -64,6 +64,7 @@ public: ...@@ -64,6 +64,7 @@ public:
virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const; virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
virtual LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level); virtual LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level);
virtual ~Operator() = default; virtual ~Operator() = default;
virtual std::unique_ptr<Operator> Clone() const = 0;
}; };
class RegionOp : public Operator { class RegionOp : public Operator {
...@@ -71,6 +72,10 @@ public: ...@@ -71,6 +72,10 @@ public:
RegionOp(Array<PrimExpr> args, BufferMap vmap); RegionOp(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get(); static const Op &Get();
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<RegionOp>(*this);
}
const Buffer &GetBuffer() const { return buffer_; } const Buffer &GetBuffer() const { return buffer_; }
const Array<Range> &GetRanges() const { return ranges_; } const Array<Range> &GetRanges() const { return ranges_; }
int GetAccessMask() const { return access_mask_; } int GetAccessMask() const { return access_mask_; }
......
...@@ -22,6 +22,64 @@ namespace attr { ...@@ -22,6 +22,64 @@ namespace attr {
constexpr const char *coalesced_width = "coalesced_width"; constexpr const char *coalesced_width = "coalesced_width";
} // namespace attr } // namespace attr
// ProveFragmentContains checks whether the threads that access elements of a
// smaller fragment (small_frag) are a subset of the threads that access
// elements of a larger fragment (large_frag) for any given loop index. This
// function ensures that if the small fragment's layout corresponds to the loop
// itself, accessing the large fragment's elements is valid. Additionally, if
// small is updated to large, the originally valid access remains valid. The
// proof is performed by:
//
// 1. Defining a variable `rep_small` to represent the replicate index of the
// small fragment that is being checked.
// 2. Using the `small_frag_indices` and `rep_small` to derive the thread
// accessing
// the element in the small fragment.
// 3. Using `large_frag_indices` to derive the physical index of the large
// fragment
// along with the thread information, and then feeding these into the inverse
// of the large fragment to obtain the logical index and replicate index.
// 4. Verifying the mapping by checking whether the computed thread using the
// inverse
// layout corresponds to the original thread calculated for the small
// fragment. If they don't match, this indicates that the inverse layout's
// domain does not include the thread and thus the access is invalid.
bool ProveFragmentContains(Fragment small_frag, Fragment large_frag,
Array<PrimExpr> small_frag_indices,
Array<PrimExpr> large_frag_indices,
arith::Analyzer &analyzer_) {
Var rep_small("__checking_frag_contains_rep");
analyzer_.Bind(rep_small,
Range(IntImm(small_frag->ReplicateExtent()->dtype, 0),
small_frag->ReplicateExtent()),
true); // Bind the replicate extent of small_frag.
// Derive thread for small_frag.
auto thread = small_frag->ForwardThread(small_frag_indices, rep_small);
// Get physical index and thread for large_frag.
auto large_frag_physical_and_thread = large_frag->Forward(large_frag_indices);
// Add small_frag's thread to the large fragment's thread info.
large_frag_physical_and_thread.push_back(thread);
// Get the inverse of the large fragment.
auto inv_large_frag = large_frag->Inverse();
// Compute logical index and replicate index using inverse layout.
auto inv_large_frag_logical_and_rep =
inv_large_frag->Forward(large_frag_physical_and_thread);
// Extract replicate index from the result.
auto inv_large_frag_rep =
inv_large_frag_logical_and_rep[inv_large_frag_logical_and_rep.size() - 1];
// Calculate thread based on the logical index and replicate index.
auto check_thread =
large_frag->ForwardThread(large_frag_indices, inv_large_frag_rep);
// Simplify the difference between the threads.
auto diff = analyzer_.Simplify(thread - check_thread);
// If the difference is zero, the threads match and the access is valid.
return is_zero(diff);
}
class IfBufferRemapLoopGenerator : public StmtExprMutator { class IfBufferRemapLoopGenerator : public StmtExprMutator {
public: public:
static For run(Stmt stmt, Map<Buffer, Buffer> buffer_remap, static For run(Stmt stmt, Map<Buffer, Buffer> buffer_remap,
...@@ -267,7 +325,8 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -267,7 +325,8 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
} }
// Step 2: Check that the loop's partition can correctly align with all source // Step 2: Check that the loop's partition can correctly align with all source
// fragment // fragment, and infer layout only when it's not yet layout-ed
LayoutMap results;
for (const auto &[buffer, _] : indice_map_) { for (const auto &[buffer, _] : indice_map_) {
if (T.layout_map.count(buffer)) { if (T.layout_map.count(buffer)) {
auto fragment = T.layout_map[buffer].as<Fragment>().value(); auto fragment = T.layout_map[buffer].as<Fragment>().value();
...@@ -278,54 +337,19 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -278,54 +337,19 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
continue; continue;
auto vars = auto vars =
loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); }); loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
auto lhs = loop_layout_->ForwardThread(vars, std::nullopt); if (!ProveFragmentContains(loop_layout_, fragment, vars,
auto rhs = fragment->ForwardThread(indice_map_[buffer], std::nullopt); indice_map_[buffer], analyzer_)) {
auto diff = analyzer_.Simplify(lhs - rhs); std::ostringstream oss;
ICHECK(is_zero(diff)) oss << "Layout infer conflict between " << buffer << " and "
<< "Layout infer conflict for " << buffer << " " << source_buffer << source_buffer << " in T.Parallel loop:" << std::endl
<< "\nLHS = " << lhs << "\nRHS = " << rhs; << " loop " << loop_layout_->DebugOutput() << std::endl
} << " fragment " << fragment->DebugOutput() << std::endl;
} throw LayoutConflictException(oss.str());
// Step 3: Infer other fragment's layout from the loop's partition
LayoutMap results;
for (const auto &[buffer, _] : indice_map_) {
if (!T.layout_map.count(buffer)) {
results.Set(buffer, CompleteBufferFragment(buffer)->BindThreadRange(
T.thread_bounds));
}
// Layout infer conflict for local.fragment can not be handled here
// because the source_buffer is not always available
// (zhengju) do not modify strict layout even if it is conflict with the
// dst layout. This will not influence the result because the strict
// layout is usually with rep = 1 Since the real layout map is
// controlled by layout_inference.cc, we should add this check there
if (buffer.scope() == "local.fragment" && source_buffer.defined() &&
source_buffer.scope() == "local.fragment") {
if (T.layout_map.count(buffer)) {
const FragmentNode *src_layout =
T.layout_map[buffer].as<FragmentNode>();
Fragment dst_layout_fragment =
CompleteBufferFragment(buffer)->BindThreadRange(T.thread_bounds);
const FragmentNode *dst_layout = dst_layout_fragment.as<FragmentNode>();
if (as_const_int(dst_layout->ReplicateExtent()) &&
as_const_int(src_layout->ReplicateExtent()) &&
(*as_const_int(dst_layout->ReplicateExtent()) >
*as_const_int(src_layout->ReplicateExtent()))) {
results.Set(buffer, dst_layout_fragment);
continue;
}
if (src_layout && dst_layout) {
ICHECK(src_layout->IsEqual(dst_layout, true))
<< "Layout may conflict with ParallelOp for buffer " << buffer
<< " vs. " << source_buffer << "\nError body begin:\n"
<< GetRoot()->body << "\nError body end"
<< "\nLHS = " << src_layout->DebugOutput()
<< "\nRHS = " << dst_layout->DebugOutput()
<< "\nYou may need to use a shared memory to transform the "
"layout";
}
} }
} else {
auto dst_layout =
CompleteBufferFragment(buffer)->BindThreadRange(T.thread_bounds);
results.Set(buffer, dst_layout);
} }
} }
return results; return results;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment