Unverified Commit 407117e1 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Layout] Introduce a new layout inference mechanism (#699)



* Implement new free stage layout inference.

* Fix bug

* Make replication upcasting and unnormalizable iterators safe.

* Better handling of updating with more replica

* Remove unnecessary check.

* Fix compilation.

* Fix setup.py.

* Simplify development mode.

* Allow ParallelOp layout when there's already a compatible layout specified

* lint fix

* Add ProveFragmentContains function to validate thread access between small and large fragments

This function checks if the threads accessing elements of a smaller fragment are a subset of those accessing a larger fragment, ensuring valid access during updates. The implementation includes deriving thread indices, computing logical indices, and verifying thread mappings.

* Update dependencies in requirements files

* Remove 'thefuzz' from requirements-dev.txt
* Specify exact versions for 'torch' and add 'flash_attn' in requirements-test.txt

* Update CI workflow to use SHA256 hash for requirements file

* Update requirements and CI workflow for flash attention

* Removed specific version for 'torch' in requirements-test.txt
* Added installation of 'flash_attn==2.5.8' in CI workflow to ensure compatibility

* Refactor flash attention import handling in examples

* Removed availability checks for 'flash_attn' in multiple example scripts.
* Simplified import statements for 'flash_attn' to ensure consistent usage across examples.

---------
Co-authored-by: Huanqi Cao <caohuanqi@deepseek.com>
parent 87aae294
......@@ -26,7 +26,7 @@ jobs:
- name: Ensure venv (local & persistent)
run: |
set -e
REQS_HASH=$(cat requirements-test.txt 2>/dev/null || true)
REQS_HASH=$(sha256sum requirements-test.txt 2>/dev/null | awk '{print $1}' || echo "no_requirements")
MARKER="${{ runner.tool_cache }}/.venv_marker_${{ env.PYTHON_VERSION }}_${REQS_HASH:0:8}"
if [[ -f "$MARKER" ]] && [[ -f "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" ]]; then
......@@ -40,6 +40,7 @@ jobs:
python -m pip install --upgrade pip --no-user
[[ -f requirements-test.txt ]] && \
PIP_NO_BUILD_ISOLATION=1 pip install -r requirements-test.txt --no-user
pip install flash_attn==2.5.8 --no-user --no-build-isolation
touch "$MARKER"
fi
......@@ -94,6 +95,8 @@ jobs:
python -m pip install --upgrade pip --no-user
[[ -f requirements-test.txt ]] && \
PIP_NO_BUILD_ISOLATION=1 pip install -r requirements-test.txt --no-user
# flash-attn typically must be installed with build isolation disabled
pip install flash_attn==2.5.8 --no-user --no-build-isolation
pip install . --no-user
touch "$MARKER"
fi
......
......@@ -517,20 +517,11 @@ def main(args):
output_sparse = sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens,
block_table)
is_flash_attn_2_available = False
try:
import flash_attn # noqa: F401
is_flash_attn_2_available = True
except:
pass
import flash_attn # noqa: F401
output_ref_torch = ref_program_torch_paged(Q, K_cache, V_cache, block_indices, cache_seqlens,
block_table, page_block_size, block_N)
if not is_flash_attn_2_available:
print("FlashAttn 2 is not available, skipping FA reference and performance measurement")
return
output_ref_fa = ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table)
# Check correctness
if sparse_ratio == 0.0:
......
......@@ -439,16 +439,7 @@ def main(batch=8,
out = sparse_kernel(Q, K, V, block_indices, cache_seqlens)
debug("output", ref, out, atol=1e-3, rtol=1e-3)
is_flash_attn_2_available = False
try:
import flash_attn # noqa: F401
is_flash_attn_2_available = True
except ImportError:
pass
if not is_flash_attn_2_available:
print("FlashAttn 2 is not available, skipping FA reference and performance measurement")
return
import flash_attn # noqa: F401
## latency reference
for _ in range(10):
......
......@@ -419,16 +419,7 @@ def main(batch=8,
out = model(Q, K, V, block_mask, cache_seqlens)
debug("output", ref, out, atol=1e-3, rtol=1e-3)
is_flash_attn_2_available = False
try:
import flash_attn # noqa: F401
is_flash_attn_2_available = True
except ImportError:
pass
if not is_flash_attn_2_available:
print("FlashAttn 2 is not available, skipping FA reference and performance measurement")
return
import flash_attn # noqa: F401
## latency reference
for _ in range(10):
......
......@@ -449,16 +449,7 @@ def main(batch=64,
print(f"Average time: {avg_time:.6f} seconds")
# Measure performance of reference implementation
is_flash_attn_2_available = False
try:
import flash_attn # noqa: F401
is_flash_attn_2_available = True
except ImportError:
pass
if not is_flash_attn_2_available:
print("FlashAttn 2 is not available, skipping FA reference and performance measurement")
return
import flash_attn # noqa: F401
start = time.time()
for _ in range(1000):
......
......@@ -429,17 +429,7 @@ def main(batch=64,
print(f"Average time: {avg_time:.6f} seconds")
print(f"Average flops: {avg_flops:.2f} GFLOPS")
is_flash_attn_2_available = False
try:
import flash_attn # noqa: F401
is_flash_attn_2_available = True
except ImportError:
pass
# Measure performance of reference implementation
if not is_flash_attn_2_available:
print("FlashAttn 2 is not available, skipping FA reference and performance measurement")
return
import flash_attn # noqa: F401
start = time.time()
for _ in range(1000):
......
......@@ -412,16 +412,7 @@ def main(batch: int = 2, heads: int = 16, seq_len: int = 256, dim: int = 32):
)
torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
is_flash_attn_2_available = False
try:
import flash_attn
is_flash_attn_2_available = True
except:
pass
if not is_flash_attn_2_available:
print("FlashAttn 2 is not available, skipping FA reference and performance measurement")
return
import flash_attn
fla_out_unpad = flash_attn.flash_attn_varlen_func(
q_unpad,
......
......@@ -21,7 +21,6 @@ ml_dtypes
psutil
scipy
torch
thefuzz
tabulate
wheel
setuptools
\ No newline at end of file
......@@ -21,7 +21,6 @@ cloudpickle
ml_dtypes
psutil
torch
thefuzz
tabulate
wheel
setuptools
......
......@@ -4,8 +4,6 @@ import shutil
from setuptools import setup, find_packages, Extension
from setuptools.command.build_py import build_py
from setuptools.command.sdist import sdist
from setuptools.command.develop import develop
import distutils.dir_util
from typing import List, Optional
import re
import tarfile
......@@ -18,7 +16,7 @@ import hashlib
import sysconfig
import functools
import urllib.request
from distutils.version import LooseVersion
from packaging.version import Version
import platform
import multiprocessing
from setuptools.command.build_ext import build_ext
......@@ -117,7 +115,7 @@ def get_nvcc_cuda_version():
nvcc_output = subprocess.check_output(["nvcc", "-V"], universal_newlines=True)
output = nvcc_output.split()
release_idx = output.index("release") + 1
nvcc_cuda_version = LooseVersion(output[release_idx].split(",")[0])
nvcc_cuda_version = Version(output[release_idx].split(",")[0])
return nvcc_cuda_version
......@@ -128,7 +126,7 @@ def get_rocm_version():
# Example output: ROCM version: x.y.z-...
match = re.search(r'ROCm Version: (\d+\.\d+\.\d+)', rocm_output)
if match:
return LooseVersion(match.group(1))
return Version(match.group(1))
else:
rocm_path = os.environ.get("ROCM_PATH", "/opt/rocm")
rocm_version_file = os.path.join(rocm_path, "lib", "cmake", "rocm",
......@@ -138,9 +136,9 @@ def get_rocm_version():
content = f.read()
match = re.search(r'set\(PACKAGE_VERSION "(\d+\.\d+\.\d+)"', content)
if match:
return LooseVersion(match.group(1))
return Version(match.group(1))
# return a default
return LooseVersion("5.0.0")
return Version("5.0.0")
def get_tilelang_version(with_cuda=True, with_system_info=True, with_commit_id=False) -> str:
......@@ -418,7 +416,7 @@ class TileLangBuilPydCommand(build_py):
target_dir = os.path.join(self.build_lib, item)
if os.path.isdir(source_dir):
self.mkpath(target_dir)
distutils.dir_util.copy_tree(source_dir, target_dir)
self.copy_tree(source_dir, target_dir)
else:
target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir):
......@@ -434,7 +432,7 @@ class TileLangBuilPydCommand(build_py):
target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item)
if os.path.isdir(source_dir):
self.mkpath(target_dir)
distutils.dir_util.copy_tree(source_dir, target_dir)
self.copy_tree(source_dir, target_dir)
else:
target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir):
......@@ -511,7 +509,7 @@ class TileLangBuilPydCommand(build_py):
target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item)
if os.path.isdir(source_dir):
self.mkpath(target_dir)
distutils.dir_util.copy_tree(source_dir, target_dir)
self.copy_tree(source_dir, target_dir)
else:
target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir):
......@@ -528,7 +526,7 @@ class TileLangBuilPydCommand(build_py):
target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item)
if os.path.isdir(source_dir):
self.mkpath(target_dir)
distutils.dir_util.copy_tree(source_dir, target_dir)
self.copy_tree(source_dir, target_dir)
else:
target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir):
......@@ -544,7 +542,7 @@ class TileLangBuilPydCommand(build_py):
target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item)
if os.path.isdir(source_dir):
self.mkpath(target_dir)
distutils.dir_util.copy_tree(source_dir, target_dir)
self.copy_tree(source_dir, target_dir)
else:
target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir):
......@@ -570,7 +568,7 @@ class TileLangBuilPydCommand(build_py):
if os.path.isdir(source_dir):
self.mkpath(target_dir)
distutils.dir_util.copy_tree(source_dir, target_dir)
self.copy_tree(source_dir, target_dir)
else:
target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir):
......@@ -588,54 +586,6 @@ class TileLangSdistCommand(sdist):
super().make_distribution()
# ------------------------------------------------------------------------
# NEW: Add a custom 'develop' command so that `pip install -e .` works.
# ------------------------------------------------------------------------
class TileLangDevelopCommand(develop):
"""
Customized setuptools 'develop' command for an editable install.
Ensures the extension is built and all necessary assets are copied.
"""
def run(self):
logger.info("Running TileLangDevelopCommand")
# 1. Build the C/C++ extension modules
self.run_command("build_ext")
build_ext_cmd = self.get_finalized_command("build_ext")
ext_modules = build_ext_cmd.extensions
for ext in ext_modules:
extdir = build_ext_cmd.get_ext_fullpath(ext.name)
logger.info(f"Extension {ext.name} output directory: {extdir}")
ext_output_dir = os.path.dirname(extdir)
logger.info(f"Extension output directory (parent): {ext_output_dir}")
# Copy the built TVM to the package directory
TVM_PREBUILD_ITEMS = [
f"{ext_output_dir}/libtvm_runtime.so",
f"{ext_output_dir}/libtvm.so",
f"{ext_output_dir}/libtilelang.so",
f"{ext_output_dir}/libtilelang_module.so",
]
for item in TVM_PREBUILD_ITEMS:
source_lib_file = os.path.join(ROOT_DIR, item)
# only copy the file
file_name = os.path.basename(item)
target_dir = os.path.join(PACKAGE_NAME, file_name)
target_dir = os.path.dirname(target_dir)
target_dir = os.path.join(target_dir, "lib")
if not os.path.exists(target_dir):
os.makedirs(target_dir)
if os.path.exists(source_lib_file):
patch_libs(source_lib_file)
shutil.copy2(source_lib_file, target_dir)
# remove the original file
os.remove(source_lib_file)
else:
logger.info(f"INFO: {source_lib_file} does not exist.")
class CMakeExtension(Extension):
"""
A specialized setuptools Extension class for building a CMake project.
......@@ -811,18 +761,31 @@ class TilelangExtensionBuild(build_ext):
# Determine the directory where the final .so or .pyd library should go.
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
# To make it compatible with in-place build and avoid redundant link during incremental build,
# we need to change the build destination to tilelang/lib, where it's actually loaded
if self.inplace:
extdir = os.path.abspath('./tilelang/lib/')
logger.info(f"{extdir=}")
# Prepare arguments for the CMake configuration step.
# -DCMAKE_LIBRARY_OUTPUT_DIRECTORY sets where built libraries go
# -DPYTHON_EXECUTABLE ensures that the correct Python is used
cmake_args = [
f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}", f"-DPython_EXECUTABLE={sys.executable}",
f"-DCMAKE_BUILD_TYPE={'Debug' if DEBUG_MODE else 'Release'}"
f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}",
f"-DPython_EXECUTABLE={sys.executable}",
f"-DCMAKE_BUILD_TYPE={'Debug' if DEBUG_MODE else 'Release'}",
"-G",
"Ninja",
]
if not USE_ROCM:
cmake_args.append(f"-DCMAKE_CUDA_COMPILER={os.path.join(CUDA_HOME, 'bin', 'nvcc')}")
# Create the temporary build directory (if it doesn't exist).
build_temp = os.path.abspath(self.build_temp)
if self.inplace:
build_temp = os.path.abspath('./build')
else:
build_temp = os.path.abspath(self.build_temp)
os.makedirs(build_temp, exist_ok=True)
# Copy the default 'config.cmake' from the source tree into our build directory.
......@@ -884,6 +847,5 @@ setup(
"build_py": TileLangBuilPydCommand,
"sdist": TileLangSdistCommand,
"build_ext": TilelangExtensionBuild,
"develop": TileLangDevelopCommand,
},
)
......@@ -124,7 +124,11 @@ Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr> &exprs,
Array<IterSplitExpr> results;
for (const IterMark &mark : collector.visited_) {
ICHECK(mark->source.as<Var>()) << "Not a normalized iterator: " << mark;
if (!mark->source.as<Var>()) {
std::ostringstream oss;
oss << "Not a normalized iterator: " << mark;
throw NormalizeIterException(oss.str());
}
}
for (const IterVar &iter : input_iters) {
......
......@@ -14,6 +14,15 @@ namespace tl {
using namespace tir;
/*!
 * \brief Exception thrown when an iterator cannot be normalized during
 *        layout inference (e.g. an IterMark whose source is not a plain Var).
 *
 * Carries a human-readable diagnostic retrievable via what().
 */
class NormalizeIterException : public std::exception {
public:
  // Takes the message by value so callers can move a temporary in;
  // explicit to prevent accidental implicit conversion from std::string.
  explicit NormalizeIterException(std::string msg) : msg_(std::move(msg)) {}

  // Return the stored diagnostic; guaranteed non-throwing.
  const char *what() const noexcept override { return msg_.c_str(); }

private:
  std::string msg_;  // owned copy of the diagnostic message
};
/*!
* \brief Collect the IterSplit that is not used in expr.
*
......
......@@ -23,6 +23,19 @@ public:
static const Op &Get();
  // Copy constructor. A default member-wise copy is not possible because
  // par_op_ is a unique_ptr; instead, deep-copy the held ParallelOp via its
  // Clone() so each AtomicAdd instance owns an independent copy.
  AtomicAdd(const AtomicAdd &other)
      : args_(other.args_), src(other.src), dst(other.dst),
        src_range(other.src_range), dst_range(other.dst_range),
        coalesced_width(other.coalesced_width) {
    // Only clone when a ParallelOp is actually held; otherwise leave
    // par_op_ null rather than cloning through a null pointer.
    // NOTE(review): the static_cast assumes ParallelOp::Clone() always
    // returns a ParallelOp — confirm against its declaration.
    if (other.par_op_)
      par_op_ = std::unique_ptr<ParallelOp>(
          static_cast<ParallelOp *>(other.par_op_->Clone().release()));
  }

  // Polymorphic deep copy; delegates to the copy constructor above.
  std::unique_ptr<Operator> Clone() const final {
    return std::make_unique<AtomicAdd>(*this);
  }
protected:
For MakeSIMTLoop(arith::Analyzer *analyzer) const;
Array<IterVar> MakeIterVars() const;
......
......@@ -51,6 +51,10 @@ public:
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
static const Op &Get();
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<Conv2DIm2ColOp>(*this);
}
private:
Buffer src, dst;
int stride, padding, dilation, kernel;
......
......@@ -373,20 +373,6 @@ LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) {
arith::Analyzer analyzer;
par_op_ = std::make_unique<ParallelOp>(MakeSIMTLoop(&analyzer));
}
if (T.layout_map.count(src) && T.layout_map.count(dst)) {
// Only compare fragment layout
if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") {
const auto &src_layout = T.layout_map[src].as<Fragment>();
const auto &dst_layout = T.layout_map[dst].as<Fragment>();
if (src_layout && dst_layout) {
ICHECK((*src_layout)->IsEqual(dst_layout->get(), true))
<< "Get different layout for " << src << " and " << dst
<< "\nLHS = " << (*src_layout)->DebugOutput()
<< "\nRHS = " << (*dst_layout)->DebugOutput()
<< "\nYou may need to use a shared memory to transform the layout";
}
}
}
return par_op_->InferLayout(T, level);
}
......
......@@ -23,6 +23,19 @@ public:
static const Op &Get();
  // Copy constructor. A default member-wise copy is not possible because
  // par_op_ is a unique_ptr; instead, deep-copy the held ParallelOp via its
  // Clone() so each Copy instance owns an independent copy.
  Copy(const Copy &other)
      : args_(other.args_), src(other.src), dst(other.dst),
        src_range(other.src_range), dst_range(other.dst_range),
        coalesced_width(other.coalesced_width), disable_tma(other.disable_tma) {
    // Only clone when a ParallelOp is actually held; otherwise leave
    // par_op_ null rather than cloning through a null pointer.
    // NOTE(review): the static_cast assumes ParallelOp::Clone() always
    // returns a ParallelOp — confirm against its declaration.
    if (other.par_op_)
      par_op_ = std::unique_ptr<ParallelOp>(
          static_cast<ParallelOp *>(other.par_op_->Clone().release()));
  }

  // Polymorphic deep copy; delegates to the copy constructor above.
  std::unique_ptr<Operator> Clone() const final {
    return std::make_unique<Copy>(*this);
  }
protected:
Stmt LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
Stmt LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
......@@ -53,6 +66,10 @@ public:
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
static const Op &Get();
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<Fill>(*this);
}
private:
For MakeSIMTLoop(arith::Analyzer *analyzer) const;
tir::Buffer dst;
......
......@@ -26,6 +26,10 @@ public:
kFullCol = 2,
} policy;
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<Gemm>(*this);
}
private:
// Target GEMM instruction
enum class GemmInst { kMMA, kWGMMA, kUTCMMA, kMFMA };
......
......@@ -26,6 +26,10 @@ public:
kFullCol = 2,
} policy;
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<GemmSP>(*this);
}
private:
std::pair<int, int>
ComputeWarpPartition(int num_warps, Target target,
......
......@@ -64,6 +64,7 @@ public:
virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
virtual LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level);
virtual ~Operator() = default;
virtual std::unique_ptr<Operator> Clone() const = 0;
};
class RegionOp : public Operator {
......@@ -71,6 +72,10 @@ public:
RegionOp(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<RegionOp>(*this);
}
const Buffer &GetBuffer() const { return buffer_; }
const Array<Range> &GetRanges() const { return ranges_; }
int GetAccessMask() const { return access_mask_; }
......
......@@ -22,6 +22,64 @@ namespace attr {
constexpr const char *coalesced_width = "coalesced_width";
} // namespace attr
// ProveFragmentContains checks whether the threads that access elements of a
// smaller fragment (small_frag) are a subset of the threads that access
// elements of a larger fragment (large_frag) for any given loop index. This
// function ensures that if the small fragment's layout corresponds to the loop
// itself, accessing the large fragment's elements is valid. Additionally, if
// small is updated to large, the originally valid access remains valid. The
// proof is performed by:
//
// 1. Defining a variable `rep_small` to represent the replicate index of the
// small fragment that is being checked.
// 2. Using the `small_frag_indices` and `rep_small` to derive the thread
// accessing
// the element in the small fragment.
// 3. Using `large_frag_indices` to derive the physical index of the large
// fragment
// along with the thread information, and then feeding these into the inverse
// of the large fragment to obtain the logical index and replicate index.
// 4. Verifying the mapping by checking whether the computed thread using the
// inverse
// layout corresponds to the original thread calculated for the small
// fragment. If they don't match, this indicates that the inverse layout's
// domain does not include the thread and thus the access is invalid.
// Prove that, for every loop index, the thread touching an element of
// small_frag is also a thread touching the corresponding element of
// large_frag (see the detailed proof sketch in the comment above).
// Returns true only when the analyzer can simplify the thread difference
// to exactly zero; an inconclusive simplification yields false, so the
// result is conservative (may reject valid containments, never the reverse).
bool ProveFragmentContains(Fragment small_frag, Fragment large_frag,
                           Array<PrimExpr> small_frag_indices,
                           Array<PrimExpr> large_frag_indices,
                           arith::Analyzer &analyzer_) {
  // Symbolic replicate index of the small fragment; ranges over
  // [0, small_frag->ReplicateExtent()).
  Var rep_small("__checking_frag_contains_rep");
  analyzer_.Bind(rep_small,
                 Range(IntImm(small_frag->ReplicateExtent()->dtype, 0),
                       small_frag->ReplicateExtent()),
                 true); // Bind the replicate extent of small_frag.
  // Thread that accesses the small fragment element at small_frag_indices
  // for replicate index rep_small.
  auto thread = small_frag->ForwardThread(small_frag_indices, rep_small);
  // Physical index of the large fragment at the same logical position.
  auto large_frag_physical_and_thread = large_frag->Forward(large_frag_indices);
  // Append the small fragment's thread so the pair (physical index, thread)
  // can be fed through the large fragment's inverse mapping.
  large_frag_physical_and_thread.push_back(thread);
  // Inverse layout of the large fragment: (physical, thread) -> (logical, rep).
  auto inv_large_frag = large_frag->Inverse();
  // Recover the logical index and replicate index via the inverse layout.
  auto inv_large_frag_logical_and_rep =
      inv_large_frag->Forward(large_frag_physical_and_thread);
  // The replicate index is the last element of the inverse's output.
  auto inv_large_frag_rep =
      inv_large_frag_logical_and_rep[inv_large_frag_logical_and_rep.size() - 1];
  // Recompute the thread from the logical index and recovered replicate
  // index; if the inverse's domain did not contain `thread`, this round-trip
  // produces a different thread.
  auto check_thread =
      large_frag->ForwardThread(large_frag_indices, inv_large_frag_rep);
  // The containment holds iff the round-tripped thread provably equals the
  // original one.
  auto diff = analyzer_.Simplify(thread - check_thread);
  // If the difference is zero, the threads match and the access is valid.
  return is_zero(diff);
}
class IfBufferRemapLoopGenerator : public StmtExprMutator {
public:
static For run(Stmt stmt, Map<Buffer, Buffer> buffer_remap,
......@@ -267,7 +325,8 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
}
// Step 2: Check that the loop's partition can correctly align with all source
// fragment
// fragment, and infer layout only when it's not yet layout-ed
LayoutMap results;
for (const auto &[buffer, _] : indice_map_) {
if (T.layout_map.count(buffer)) {
auto fragment = T.layout_map[buffer].as<Fragment>().value();
......@@ -278,54 +337,19 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
continue;
auto vars =
loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
auto lhs = loop_layout_->ForwardThread(vars, std::nullopt);
auto rhs = fragment->ForwardThread(indice_map_[buffer], std::nullopt);
auto diff = analyzer_.Simplify(lhs - rhs);
ICHECK(is_zero(diff))
<< "Layout infer conflict for " << buffer << " " << source_buffer
<< "\nLHS = " << lhs << "\nRHS = " << rhs;
}
}
// Step 3: Infer other fragment's layout from the loop's partition
LayoutMap results;
for (const auto &[buffer, _] : indice_map_) {
if (!T.layout_map.count(buffer)) {
results.Set(buffer, CompleteBufferFragment(buffer)->BindThreadRange(
T.thread_bounds));
}
// Layout infer conflict for local.fragment can not be handled here
// because the source_buffer is not always available
// (zhengju) do not modify strict layout even if it is conflict with the
// dst layout. This will not influence the result because the strict
// layout is usually with rep = 1 Since the real layout map is
// controlled by layout_inference.cc, we should add this check there
if (buffer.scope() == "local.fragment" && source_buffer.defined() &&
source_buffer.scope() == "local.fragment") {
if (T.layout_map.count(buffer)) {
const FragmentNode *src_layout =
T.layout_map[buffer].as<FragmentNode>();
Fragment dst_layout_fragment =
CompleteBufferFragment(buffer)->BindThreadRange(T.thread_bounds);
const FragmentNode *dst_layout = dst_layout_fragment.as<FragmentNode>();
if (as_const_int(dst_layout->ReplicateExtent()) &&
as_const_int(src_layout->ReplicateExtent()) &&
(*as_const_int(dst_layout->ReplicateExtent()) >
*as_const_int(src_layout->ReplicateExtent()))) {
results.Set(buffer, dst_layout_fragment);
continue;
}
if (src_layout && dst_layout) {
ICHECK(src_layout->IsEqual(dst_layout, true))
<< "Layout may conflict with ParallelOp for buffer " << buffer
<< " vs. " << source_buffer << "\nError body begin:\n"
<< GetRoot()->body << "\nError body end"
<< "\nLHS = " << src_layout->DebugOutput()
<< "\nRHS = " << dst_layout->DebugOutput()
<< "\nYou may need to use a shared memory to transform the "
"layout";
}
if (!ProveFragmentContains(loop_layout_, fragment, vars,
indice_map_[buffer], analyzer_)) {
std::ostringstream oss;
oss << "Layout infer conflict between " << buffer << " and "
<< source_buffer << " in T.Parallel loop:" << std::endl
<< " loop " << loop_layout_->DebugOutput() << std::endl
<< " fragment " << fragment->DebugOutput() << std::endl;
throw LayoutConflictException(oss.str());
}
} else {
auto dst_layout =
CompleteBufferFragment(buffer)->BindThreadRange(T.thread_bounds);
results.Set(buffer, dst_layout);
}
}
return results;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment