Commit 57ab687c authored by Lei Wang, committed by GitHub

[Initialization] Migration of Codebase from Dev Branch into Main (#10)



* Add format.sh script for code formatting and linting

* docs update

* center align the title

* lint fix

* add ignore

* Add .gitignore for 3rdparty directory

* Add requirements-dev.txt, requirements-test.txt, and requirements.txt

* 3rdparty

* Add gemm.h, CMakeLists.txt, _ffi_api.py, __init__.py, runtime.h, reduce.h, loop_partition.h, utils.h, and loop_vectorize.h

* Refactor CMakeLists.txt and include statements

- Update CMakeLists.txt to use a newer version of CMake and add project name
- Remove unnecessary include directories

Fix include paths in layout.cc, codegen.cc, codegen.h, rt_mod.cc, frontend_legalize.cc, inject_pipeline.cc, layout_inference.cc, loop_vectorize.cc, and lower_tile_op.cc

- Update include paths to use relative paths instead of absolute paths

* Update submodule for 3rdparty/tvm

* update

* load dll first

* Refactor CMakeLists.txt and include statements

* Refactor CMakeLists.txt and include statements

* git keep update

* Refactor CMakeLists.txt and include statements

* Refactor CMakeLists.txt and include statements

* refactor code structure

* Update Readme

* CMakeLists Customized

* update readme

* update README

* update readme

* update usage

* with TVM_IMPORT_PYTHON_PATH to handle own tvm build python import

* annotate lower transform global func with `transform` prefix

* Migrate Simplify Pass from tilelang tvm branch

* enhance system environment handling with __init__ and CMake

* Initial commit

* CODE_OF_CONDUCT.md committed

* LICENSE committed

* README.md committed

* SECURITY.md committed

* SUPPORT.md committed

* CODE_OF_CONDUCT Commit

* LICENSE Commit

* SECURITY Commit

* SUPPORT Commit

* Modify Support

* Update README.md

* security ci update

* remove examples

* Update and implement clang-format

* add composable kernel components

* Migrate from latest update

* submodule update

* Test update

* Update License

* Spell check

* lint fix

* add clang-tidy to apply static analysis for c source

* update tilelang examples

* Update Install Docs

* Refactor filetree

* Enhance Install

* conflict resolved

* annotate_version

* Initial Update

* test fix

* install

* Implement setup.py

* lint fix

* Separate Init

* Separate test

* docker file commit

* add logo

* Update Readme and Examples

* update readme

* update logo

* Implement AMD Installation

* Add License

* Update AMD MI300x Benchmark

* update README

* update mi300 benchmark scripts

* update ignore

* enhance build script

* update image

* enhance setup.py to remove duplicated libraries

* remove debug files

* update readme

* update image

* update gemm examples

* update flashattention README

* readme update

* add cmake into requirements

* libinfo fix

* auto update submodule

* lint fix

* Fix AMD Build and Test

* Update check for transpose attribute for CDNA Arch

* typo fix for amd

* Implement Matmul Benchmark

* Refactor Code

* [TypoFix] Fix GEMM Example

* [Docs] Init Linear Attention README

* [TYPO] Typo fix

* [Lint] Lint Fix

* enhance example with intrinsics

* [Enhancement] Improve Buffer Collection during IR Parser

* [Dev] Introduce Current classmethod to get current frame

* submodule update

* fake test pass update

* support thread_extent_api

* code optimize

* Add GEMM function implementation for matrix multiplication

* Update logging format to reflect TileLang in logger messages

* Refactor CMakeLists.txt for improved readability and set default build type to Release

* Support Gemm SS Primitives Implementation

* [README] Upload Tile Language Logo (#5)

* update logo

* Update README.md to enhance formatting and center the title

---------
Co-authored-by: microsoft-github-operations[bot] <55726097+microsoft-github-operations[bot]@users.noreply.github.com>
Co-authored-by: Microsoft Open Source <microsoftopensource@users.noreply.github.com>
Co-authored-by: Yu Cheng <yu.cheng@pku.edu.cn>
parent 64f17c2f
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <rocwmma/rocwmma.hpp>
#include <ck_tile/core.hpp>
using ck_tile::half_t;
#define HIPRT_INF_F __int_as_float(0x7f800000)
#define HIPRT_NEGINF_F __int_as_float(0xff800000)
#define HIPRT_NAN_F __int_as_float(0x7fffffff)
#define HIPRT_MIN_DENORM_F __int_as_float(0x00000001)
#define HIPRT_MAX_NORMAL_F __int_as_float(0x7f7fffff)
#define HIPRT_NEG_ZERO_F __int_as_float(0x80000000)
#define HIPRT_ZERO_F 0.0f
#define HIPRT_ONE_F 1.0f
/* double precision constants */
#define HIPRT_INF __hiloint2double(0x7ff00000, 0x00000000)
#define HIPRT_NAN __hiloint2double(0xfff80000, 0x00000000)
#define uint unsigned int
#define uchar unsigned char
#define ushort unsigned short
#define TL_DEVICE __forceinline__ __device__
#define half _Float16
#define __float2half_rn(x) half(x)
#define hpow __ocml_pown_f16
#define hsqrt __ocml_sqrt_f16
using float16_t = _Float16;
using float16x2 = __attribute__((__vector_size__(2 * sizeof(float16_t)))) float16_t;
using float16x4 = __attribute__((__vector_size__(4 * sizeof(float16_t)))) float16_t;
using float16x8 = __attribute__((__vector_size__(8 * sizeof(float16_t)))) float16_t;
using float16x16 = __attribute__((__vector_size__(16 * sizeof(float16_t)))) float16_t;
using int32x4 = __attribute__((__vector_size__(4 * sizeof(int)))) int;
using float32x4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
using float32x16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
using int8x4 = __attribute__((__vector_size__(4 * sizeof(int8_t)))) int8_t;
// Pack two half_t values.
TL_DEVICE unsigned __pack_half2(const half_t x, const half_t y) {
unsigned v0 = *((unsigned short*)&x);
unsigned v1 = *((unsigned short*)&y);
return (v1 << 16) | v0;
}
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include "common.h"
using f32 = float;
// using f16 = _Float16;
using u8 = std::uint8_t;
using u16 = std::uint16_t;
using u32 = std::uint32_t;
using index_t = u32;
using ck_tile::int32x4_t;
struct __attribute__((packed)) buffer_resource {
const void* ptr;
uint32_t range;
uint32_t config;
};
CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t size = 0xffffffff) {
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
r.x = __builtin_amdgcn_readfirstlane(r.x);
r.y = __builtin_amdgcn_readfirstlane(r.y);
r.z = __builtin_amdgcn_readfirstlane(r.z);
r.w = __builtin_amdgcn_readfirstlane(r.w);
return r;
}
__device__ void init_m0(uint32_t m0_value) {
asm volatile("s_mov_b32 m0, %0" : : "s"(m0_value) : "memory");
}
__device__ void inc_m0(uint32_t m0_inc) {
asm volatile("s_add_u32 m0, %0, m0" : : "n"(m0_inc) : "memory");
}
namespace tl {
// AMD GPUs commit memory fences automatically, so this is a no-op.
TL_DEVICE void cp_async_commit() {}
// Global Memory only fence
__device__ void async_gld_fence(index_t cnt) {
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
// Global Memory and Shared Memory fence
__device__ void async_gld_sld_fence(index_t cnt) {
asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory");
}
__device__ void wave_barrier() { asm volatile("s_barrier" : : : "memory"); }
template <int N = 0>
TL_DEVICE void cp_async_wait() {
async_gld_fence(N);
// or
// async_gld_sld_fence(N);
}
template <bool pre_nop = false>
CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem, int32x4_t rsrc, index_t voffset) {
auto const lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(smem)));
asm volatile(
"s_mov_b32 m0, %0; \n\t"
"buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
"v"(voffset), "s"(rsrc)
: "memory");
}
template <int N>
TL_DEVICE void cp_async_gs(void* lds_base_ptr, void* global_base_ptr) {
if constexpr(N == 16) {
*(uint4*)lds_base_ptr = *(uint4*)global_base_ptr;
} else if constexpr(N == 8) {
*(uint2*)lds_base_ptr = *(uint2*)global_base_ptr;
} else if constexpr(N == 4) {
async_buffer_load_dword_v(lds_base_ptr, make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x), threadIdx.x * N /*assume 4 bytes*/);
}
}
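// Usage sketch (an assumed codegen pattern, mirroring the CUDA cp.async API
// these helpers emulate; `smem`, `gmem`, `dst`, and `src` are hypothetical):
//   tl::cp_async_gs<16>(&smem[dst], &gmem[src]); // copy 16 bytes per thread
//   tl::cp_async_commit();                       // no-op on AMD
//   tl::cp_async_wait<0>();                      // s_waitcnt vmcnt(0)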
template <int N>
TL_DEVICE void cp_async_gs_conditional(void* lds_base_ptr, void* global_base_ptr, bool cond) {
if constexpr (N == 16) {
*(uint4*)lds_base_ptr = cond ? *(uint4*)global_base_ptr : make_uint4(0, 0, 0, 0);
} else if constexpr (N == 8) {
*(uint2*)lds_base_ptr = cond ? *(uint2*)global_base_ptr : make_uint2(0, 0);
} else {
if (cond) {
async_buffer_load_dword_v(lds_base_ptr, make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x), threadIdx.x * N /*assume 4 bytes*/);
} else {
*(uint4*)lds_base_ptr = make_uint4(0, 0, 0, 0);
}
}
}
} // namespace tl
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include "common.h"
namespace tl {
// ref to bitblas/tl/mfma_macro_generator.py::kPack
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool TransposeA, bool TransposeB, int kPack,
typename A_type, typename B_type, typename C_type, typename AccDataType = float>
class GemmTensorOp {
public:
static constexpr int micro_size_x = 16;
static constexpr int micro_size_y = 16;
static constexpr int micro_size_k = 16;
// This part comes from the Codegen
static constexpr int M_Tile = M;
static constexpr int N_Tile = N;
static constexpr int K_Tile = K;
static constexpr int block_row_warps = num_warp_m;
static constexpr int block_col_warps = num_warp_n;
static constexpr int inner_k = K_Tile / (micro_size_k * kPack);
static constexpr int warp_rows = M_Tile / (block_row_warps * micro_size_x);
static constexpr int warp_cols = N_Tile / (block_col_warps * micro_size_y);
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
static constexpr bool kPadA = true;
static constexpr bool kPadB = true;
static constexpr bool kPadC = true;
static constexpr int BANK_SIZE_BYTES = 128;
static constexpr int warp_size = 64;
TL_DEVICE static constexpr auto reverse_index_map(int thread_id, int local_id) {
return std::make_pair(thread_id % 16, (thread_id / 16) * (4 * kPack) + local_id);
}
TL_DEVICE static constexpr auto reverse_index_map_transposed(int thread_id, int local_id) {
return std::make_pair((thread_id / 16) * (4 * kPack) + local_id, thread_id % 16);
}
/*
* For the detailed implementation, see
* bitblas/tl/utils.py:get_swizzle_layout
*/
template <int continuous = 32, int element_size = 2>
TL_DEVICE static auto make_mfma_swizzle_layout(const int row, const int col) {
const auto dtype_bits = element_size * 8;
const int numBanks = 32;
const int bankBitWidth = 32;
const int SIMDWidth = 16;
const int vecSize = 4 * kPack;
const int innerDimLength = continuous;
const int typeWidthInBit = dtype_bits;
const int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit;
const int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
const int maxPhase = std::min(SIMDWidth / perPhase, innerDimLength / vecSize);
const int phase = (row / perPhase) % maxPhase;
const int colOffSwizzled = (((col / vecSize) ^ phase) * vecSize);
const int colOffOrdered = col % vecSize;
const int colOff = colOffSwizzled + colOffOrdered;
return std::make_pair(row, colOff);
}
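// Worked example (assuming kPack == 1, fp16 elements, continuous == 32):
// vecSize = 4, perPhase = 2, maxPhase = 8, so phase = (row / 2) % 8 and the
// 4-element column block (col / 4) is XOR-ed with the phase. For instance,
// (row = 3, col = 8) has phase = 1 and maps to (3, (2 ^ 1) * 4 + 0) = (3, 12),
// while rows 0 and 1 (phase 0) keep their columns unchanged.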
template <int continuous = 32, int element_size = 2>
TL_DEVICE static constexpr auto make_layout_padded(const int row, const int col) {
return std::make_pair(row, col);
}
template <int continuous = 32, int element_size = 2>
TL_DEVICE static constexpr auto make_swizzle_layout(const int row, const int col) {
constexpr auto vector_size = BANK_SIZE_BYTES / (element_size * 8);
if (continuous % (vector_size * 4) == 0) {
auto [n_row, n_col] = make_mfma_swizzle_layout<continuous, element_size>(row, col);
return n_row * continuous + n_col;
} else {
auto [n_row, n_col] = make_layout_padded(row, col);
int padded = continuous;
if ((element_size * 8 * continuous) % 256 == 0)
padded += BANK_SIZE_BYTES / (element_size * 8);
return n_row * padded + n_col;
}
}
static TL_DEVICE void body(A_type* A_shared, B_type* B_shared, C_type* C_local) {
auto tid = threadIdx.x;
auto warp_id = tid / warp_size;
auto warp_n = warp_id / block_row_warps;
auto warp_m = warp_id % block_row_warps;
auto warp_row_tiles = warp_rows * micro_size_x;
auto warp_col_tiles = warp_cols * micro_size_y;
auto lane_id = tid % warp_size;
auto tx = lane_id;
constexpr auto local_size_a = (micro_size_x * micro_size_k) / warp_size;
constexpr auto local_size_b = (micro_size_y * micro_size_k) / warp_size;
constexpr auto local_size_c = (micro_size_x * micro_size_y) / warp_size;
constexpr auto last_dim_a = TransposeA ? M_Tile : K_Tile;
constexpr auto last_dim_b = TransposeB ? K_Tile : N_Tile;
A_type A_local[warp_rows * kPack * local_size_a];
B_type B_local[warp_cols * kPack * local_size_b];
for (int ki = 0; ki < inner_k; ki++) {
// Fetch A into register
for (int i = 0; i < warp_rows; i++) {
const auto l = warp_m * warp_row_tiles + i * micro_size_x;
const auto r = ki * (kPack * micro_size_k);
for (int local_id = 0; local_id < (kPack * local_size_a); local_id++) {
auto [row, col] = reverse_index_map(lane_id, local_id);
A_local[i * kPack * local_size_a + local_id] =
A_shared[make_swizzle_layout<last_dim_a, sizeof(A_type)>(l + row, r + col)];
}
}
// Fetch B into register
for (int j = 0; j < warp_cols; j++) {
const auto l = warp_n * warp_col_tiles + j * micro_size_y;
const auto r = ki * (kPack * micro_size_k);
for (int local_id = 0; local_id < (kPack * local_size_b); local_id++) {
auto [row, col] = reverse_index_map(lane_id, local_id);
B_local[j * kPack * local_size_b + local_id] =
B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(l + row, r + col)];
}
}
// Compute
for (int kp = 0; kp < kPack; kp++) {
for (int i = 0; i < warp_rows; ++i) {
for (int j = 0; j < warp_cols; ++j) {
*(((float32x4*)C_local) + ((i * warp_cols) + j)) = __builtin_amdgcn_mfma_f32_16x16x16f16(
*(((float16x4*)B_local) + j * kPack + kp),
*(((float16x4*)A_local) + i * kPack + kp),
*(((float32x4*)C_local) + ((i * warp_cols) + j)), 0, 0, 0);
}
}
}
}
}
static TL_DEVICE void body_rs(A_type* A_local, B_type* B_shared, C_type* C_local) {
auto tid = threadIdx.x;
auto warp_id = tid / warp_size;
auto warp_n = warp_id / block_row_warps;
auto warp_m = warp_id % block_row_warps;
auto warp_row_tiles = warp_rows * micro_size_x;
auto warp_col_tiles = warp_cols * micro_size_y;
auto lane_id = tid % warp_size;
auto tx = lane_id;
constexpr auto local_size_a = (micro_size_x * micro_size_k) / warp_size;
constexpr auto local_size_b = (micro_size_y * micro_size_k) / warp_size;
constexpr auto local_size_c = (micro_size_x * micro_size_y) / warp_size;
constexpr auto last_dim_a = TransposeA ? M_Tile : K_Tile;
constexpr auto last_dim_b = TransposeB ? K_Tile : N_Tile;
B_type B_local[warp_cols * kPack * local_size_b];
for (int ki = 0; ki < inner_k; ki++) {
// Fetch B into register
for (int j = 0; j < warp_cols; j++) {
const auto l = warp_n * warp_col_tiles + j * micro_size_y;
const auto r = ki * kPack * micro_size_k;
for (int local_id = 0; local_id < kPack * local_size_b; local_id++) {
auto [row, col] = reverse_index_map(lane_id, local_id);
B_local[j * kPack * local_size_b + local_id] =
B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(l + row, r + col)];
}
}
// Compute
for (int kp = 0; kp < kPack; kp++) {
for (int i = 0; i < warp_rows; ++i) {
for (int j = 0; j < warp_cols; ++j) {
*(((float32x4*)C_local) + ((i * warp_cols) + j)) = __builtin_amdgcn_mfma_f32_16x16x16f16(
*(((float16x4*)B_local) + j * kPack + kp), *(((float16x4*)A_local) + ki * warp_rows * kPack + i * kPack + kp),
*(((float32x4*)C_local) + ((i * warp_cols) + j)), 0, 0, 0);
}
}
}
}
}
};
} // namespace tl
namespace tl {
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B, int kPack,
typename A_type, typename B_type, typename C_type>
TL_DEVICE void gemm_ss(A_type* pA, B_type* pB, C_type* accum) {
using Compute =
GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B, kPack, A_type, B_type, C_type>;
Compute::body(pA, pB, accum);
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B, int kPack,
typename A_type, typename B_type, typename C_type>
TL_DEVICE void gemm_rs(A_type* pA, B_type* pB, C_type* accum) {
using Compute =
GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B, kPack, A_type, B_type, C_type>;
Compute::body_rs(pA, pB, accum);
}
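// Illustrative instantiation (hypothetical shapes and buffer names, as such a
// call would typically be emitted by codegen): a 128x128x32 fp16 tile with a
// 2x2 warp grid, kPack = 1, and fp32 accumulation:
//   tl::gemm_ss<128, 128, 32, 2, 2, false, false, 1, half_t, half_t, float>(
//       A_shared, B_shared, C_local);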
} // namespace tl
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include "common.h"
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include "common.h"
namespace tl {
struct SumOp {
template <typename T>
TL_DEVICE T operator()(T const& x, T const& y) {
return x + y;
}
};
struct MaxOp {
template <typename T>
TL_DEVICE T operator()(T const& x, T const& y) {
return ck_tile::max(x, y);
}
};
struct MinOp {
template <typename T>
TL_DEVICE T operator()(T const& x, T const& y) {
return ck_tile::min(x, y);
}
};
template <class Reducer, int threads, int scale>
struct AllReduce {
static_assert(threads == 1024 || threads == 512 || threads == 256 || threads == 128 ||
threads == 64 || threads == 32 || threads == 16 || threads == 8 || threads == 4 ||
threads == 2);
static_assert(threads % scale == 0);
template <typename T>
static __device__ T run(T x, T* red_buf = nullptr) {
constexpr int offset = threads / 2;
constexpr int warpSize = 64;
if constexpr (offset >= warpSize) {
__syncthreads();
red_buf[threadIdx.x] = x;
__syncthreads();
x = Reducer()(x, red_buf[threadIdx.x ^ offset]);
} else {
x = Reducer()(x, __shfl_xor(x, offset));
}
if constexpr (offset == scale) {
return x;
} else {
return AllReduce<Reducer, offset, scale>::run(x, red_buf);
}
}
};
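// Usage sketch (an assumed call pattern from generated code): reduce a float
// across all 256 threads of a block; the shared-memory scratch buffer is only
// used while the exchange distance spans more than one 64-lane wave.
//   __shared__ float red_buf[256];
//   float total = tl::AllReduce<tl::SumOp, 256, 1>::run(x, red_buf);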
} // namespace tl
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include "common.h"
namespace tl {
template <int panel_width>
TL_DEVICE dim3 rasterization2DRow() {
auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
const unsigned int grid_size = gridDim.x * gridDim.y;
const unsigned int panel_size = panel_width * gridDim.x;
const unsigned int panel_offset = block_idx % panel_size;
const unsigned int panel_idx = block_idx / panel_size;
const unsigned int total_panel = ceil_div(grid_size, panel_size);
const unsigned int stride =
panel_idx + 1 < total_panel ? panel_width : (grid_size - panel_idx * panel_size) / gridDim.x;
const unsigned int col_idx =
(panel_idx & 1) ? gridDim.x - 1 - panel_offset / stride : panel_offset / stride;
const unsigned int row_idx = panel_offset % stride + panel_idx * panel_width;
return {col_idx, row_idx, blockIdx.z};
}
template <int panel_width>
TL_DEVICE dim3 rasterization2DColumn() {
auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
const unsigned int grid_size = gridDim.x * gridDim.y;
const unsigned int panel_size = panel_width * gridDim.y;
const unsigned int panel_offset = block_idx % panel_size;
const unsigned int panel_idx = block_idx / panel_size;
const unsigned int total_panel = ceil_div(grid_size, panel_size);
const unsigned int stride =
panel_idx + 1 < total_panel ? panel_width : (grid_size - panel_idx * panel_size) / gridDim.y;
const unsigned int row_idx =
(panel_idx & 1) ? gridDim.y - 1 - panel_offset / stride : panel_offset / stride;
const unsigned int col_idx = panel_offset % stride + panel_idx * panel_width;
return {col_idx, row_idx, blockIdx.z};
}
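// Usage sketch (an assumed pattern in generated kernels): remap block indices
// to improve L2 locality, then index tiles with the remapped coordinates
// instead of blockIdx.x / blockIdx.y:
//   const dim3 bid = tl::rasterization2DColumn<8>(); // panel_width = 8
//   // use bid.x and bid.y where blockIdx.x and blockIdx.y would be used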
} // namespace tl
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file cluster_planning.cc
* \brief Plan thread block clusters for GPU (sm90+) blocks
*/
#include <tvm/arith/analyzer.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
namespace tvm {
namespace tir {
class ClusterPlanner {
public:
static PrimFunc Substitute(PrimFunc& f) {
// Step 1: Collect the read region of the function
Map<Var, Buffer> buffer_data_to_buffer_;
for (const auto& [_, buffer] : f->buffer_map) {
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ f->body);
Array<Array<BufferRegion>> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
auto reads = access[0];
BlockIdxVisitor blockIdx_visitor;
blockIdx_visitor(f->body);
auto dom_map = blockIdx_visitor.dom_map_;
// Step 2: Collect mem reuse count for clustering on each dimension.
std::unordered_map<const IterVarNode*, size_t> mem_reuse_count;
for (auto iv : dom_map) mem_reuse_count[iv] = 0;
for (const auto& buffer_region : reads) {
PrimExpr size = buffer_region->buffer->dtype.bits();
RegionVisitor visitor;
for (const auto& range : buffer_region->region) {
size = size * range->extent;
visitor(range->min);
}
size = arith::Analyzer().Simplify(size);
if (auto imm = size.as<IntImmNode>()) {
for (auto iv : dom_map) {
if (visitor.seen_.count(iv->var.get()) == 0) mem_reuse_count[iv] += imm->value;
}
}
}
// Step 3: Pick the cluster dimension with the largest mem_reuse.
size_t mem_reuse_max = 0;
String cluster_tag;
for (auto iv : dom_map) {
if (auto extent = iv->dom->extent.as<IntImmNode>()) {
if (extent->value % cluster_size_ == 0 && mem_reuse_count[iv] > mem_reuse_max) {
cluster_tag = iv->thread_tag;
mem_reuse_max = mem_reuse_count[iv];
}
}
}
if (mem_reuse_max > 0) {
cluster_tag = "clusterIdx" + String(cluster_tag.c_str() + strlen("blockIdx"));
return WithAttr(f, cluster_tag, Integer(cluster_size_));
} else {
return f;
}
}
private:
ClusterPlanner() = default;
class RegionVisitor : public ExprVisitor {
public:
RegionVisitor(){};
void VisitExpr_(const VarNode* var) { seen_.insert(var); }
std::unordered_set<const VarNode*> seen_;
};
class BlockIdxVisitor : public StmtVisitor {
public:
BlockIdxVisitor(){};
void VisitStmt_(const AttrStmtNode* attr) final {
if (attr->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(attr->node);
String tag = iv->thread_tag;
if (tag == "blockIdx.x" || tag == "blockIdx.y" || tag == "blockIdx.z")
dom_map_.insert(iv.get());
}
StmtVisitor::VisitStmt_(attr);
}
/*! \brief The set of blockIdx iter vars collected from the function. */
std::unordered_set<const IterVarNode*> dom_map_;
};
/*! \brief Currently the possible cluster size is fixed at 2 */
const static int cluster_size_ = 2;
};
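// Example (illustrative): if every read region of a large buffer A is
// invariant along blockIdx.y, A's size is credited to blockIdx.y's reuse
// count; when that count is the largest and blockIdx.y's extent is divisible
// by cluster_size_, the function is annotated with "clusterIdx.y" = 2.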
PrimFunc ClusterPlanning(PrimFunc f) { return ClusterPlanner::Substitute(f); }
namespace transform {
tvm::transform::Pass ClusterPlanning() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return ClusterPlanning(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tl.ClusterPlanning", {});
}
TVM_REGISTER_GLOBAL("tl.transform.ClusterPlanning").set_body_typed(ClusterPlanning);
} // namespace transform
} // namespace tir
} // namespace tvm
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file common.h
* \brief Common utilities for TL transforms
*/
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>
#include <queue>
#include "arith/ir_mutator_with_analyzer.h"
#include "../../op/parallel.h"
#include "../loop_partition.h"
#include "../loop_vectorize.h"
namespace tvm {
namespace tl {
using namespace tir;
using arith::IRMutatorWithAnalyzer;
class FragmentAccessDetector : public StmtExprVisitor {
public:
FragmentAccessDetector() = default;
void Collect(Stmt stmt) { VisitStmt(stmt); }
bool HasFragmentAccess() { return has_fragment_access_; }
private:
void VisitExpr_(const BufferLoadNode* op) final {
// Check if the buffer is in local.fragment scope
if (IsFragmentBuffer(op->buffer)) {
has_fragment_access_ = true;
}
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode* op) final {
// Check if the buffer is in local.fragment scope
if (IsFragmentBuffer(op->buffer)) {
has_fragment_access_ = true;
}
StmtExprVisitor::VisitStmt_(op);
}
// Helper function to determine if a buffer is local.fragment
bool IsFragmentBuffer(const Buffer& buffer) {
// The storage scope is often encoded in the buffer->data var name or associated attributes.
String scope = buffer.scope();
return scope == "local.fragment";
}
bool has_fragment_access_{false};
};
/*!
* \brief ParallelLoopFuser
* This class is used to fuse a chain of parallel loops into one loop.
* The loops must:
* - All be parallel (ForKind::kParallel)
* - Have bounds from 0 to their extent
* Once fused, a single loop variable will replace the chain, and the
* original loop variables will be derived by division and modulo operations.
*
* This can be helpful for inferring layout for the fragment in a subsequent pass.
*/
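// Example (illustrative): fusing
//   for (i, 0, 4) "parallel":
//     for (j, 0, 8) "parallel":
//       body(i, j)
// yields
//   for (i_j_fused, 0, 32) "parallel":
//     body((i_j_fused / 8) % 4, i_j_fused % 8)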
class ParallelLoopFuser : public IRMutatorWithAnalyzer {
public:
static Stmt Fuse(Stmt stmt) {
arith::Analyzer analyzer;
ParallelLoopFuser substituter(&analyzer);
return substituter.VisitStmt(stmt);
}
private:
ParallelLoopFuser(arith::Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer) {};
Stmt VisitStmt_(const ForNode* op) final {
// Gather consecutive parallel loops
std::vector<const ForNode*> loop_chain;
const ForNode* current = op;
// check if has fragment access
FragmentAccessDetector detector;
detector.Collect(op->body);
// Do not fuse if there is a fragment access
if (detector.HasFragmentAccess()) {
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
while (true) {
if (current->kind != ForKind::kParallel) break;
if (!is_zero(current->min)) break;
loop_chain.push_back(current);
const ForNode* inner_for = current->body.as<ForNode>();
if (!inner_for) {
break;
}
current = inner_for;
}
// If the chain contains at most one loop, no fusion is needed.
if (loop_chain.size() <= 1) {
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
// At this point we have multiple nested parallel loops starting at zero
// We will fuse them all.
PrimExpr fused_extent = make_const(DataType::Int(32), 1);
for (auto it = loop_chain.rbegin(); it != loop_chain.rend(); ++it) {
fused_extent = fused_extent * (*it)->extent;
}
std::string fused_name;
for (auto it = loop_chain.begin(); it != loop_chain.end(); ++it) {
fused_name += (*it)->loop_var->name_hint + "_";
}
fused_name += "fused";
// Create a new fused loop var
Var fused_var(fused_name, DataType::Int(32));
// The body of the last loop in the chain:
const ForNode* innermost_loop = loop_chain.back();
Stmt body = innermost_loop->body;
// We need to substitute all loop variables in the chain.
// The scheme:
// Suppose we have loops (i in [0,M], j in [0,N], k in [0,O])
// fused loop var f in [0, M*N*O]
// i = f / (N*O)
// j = (f % (N*O)) / O
// k = f % O
//
// Generalizing to a chain of length L:
// extents: E_0, E_1, ... E_{L-1}
// index_i = (f / (E_{i+1}*...*E_{L-1})) % E_i
// For the last one, it's just f % E_{L-1} if i == L-1.
// Compute the "stride" products for each loop variable
// stride[i] = product of extents of loops after i
// for L loops: stride[L-1] = 1
// stride[L-2] = E_{L-1}
// stride[L-3] = E_{L-1} * E_{L-2}
// ...
std::vector<PrimExpr> extents;
extents.reserve(loop_chain.size());
for (auto l : loop_chain) {
extents.push_back(l->extent);
}
std::vector<PrimExpr> strides(loop_chain.size(), make_const(DataType::Int(32), 1));
for (int i = static_cast<int>(loop_chain.size()) - 2; i >= 0; i--) {
strides[i] = strides[i + 1] * extents[i + 1];
}
// We'll create a substitution map for all loop variables
// index_i = (f / strides[i]) % extents[i]
// We'll define a helper lambda:
auto create_index_expr = [&](int i) {
return FloorMod(FloorDiv(fused_var, strides[i]), extents[i]);
};
Map<Var, PrimExpr> var_map;
for (size_t i = 0; i < loop_chain.size(); i++) {
const ForNode* loop = loop_chain[i];
var_map.Set(loop->loop_var, analyzer_->Simplify(create_index_expr(static_cast<int>(i))));
}
// Perform the substitution
body = Substitute(body, var_map);
// Create the fused loop
For fused_for = For(fused_var, 0, fused_extent, ForKind::kParallel, body);
return fused_for;
}
};
} // namespace tl
} // namespace tvm
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file common.h
* \brief Common utilities for TL transforms
*/
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>
#include <queue>
#include "arith/ir_mutator_with_analyzer.h"
#include "../../op/parallel.h"
#include "../loop_partition.h"
#include "../loop_vectorize.h"
namespace tvm {
namespace tl {
using namespace tir;
// Vectorization part
// Reuses the same logic as tir.transform.VectorizeLoop
inline PrimExpr CreateNewLanes(bool is_scalable, int lanes_or_vscale_factor) {
if (is_scalable) {
return Mul(Call(DataType::Int(32), builtin::vscale(), {}), lanes_or_vscale_factor);
} else {
return lanes_or_vscale_factor;
}
}
inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) {
// Check if e is already in the expected form
if (e.dtype().get_lanes_or_vscale_factor() == lanes &&
e.dtype().is_scalable_vector() == is_scalable)
return e;
if (const BroadcastNode* op = e.as<BroadcastNode>()) {
ICHECK(op->dtype.is_scalable_vector() == is_scalable)
<< "Can't broadcast between scalable and fixed length vectors.";
int e_lanes = op->dtype.get_lanes_or_vscale_factor();
if (lanes % e_lanes == 0) {
return Broadcast(op->value, CreateNewLanes(is_scalable, lanes));
}
}
ICHECK(e.dtype().is_scalar()) << "Cannot broadcast lanes="
<< e.dtype().get_lanes_or_vscale_factor()
<< " is_scalable=" << e.dtype().is_scalable_vector() << " to "
<< lanes;
return Broadcast(e, CreateNewLanes(is_scalable, lanes));
}
// Rewrite vectorized allocation access
// This is necessary so that each vector lane gets its own workspace.
// Originates from Halide's loop vectorizer
//
// s[i] = s[i * lanes + var]
//
// The same principle applies when using one thread to simulate multiple contexts.
//
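// Example (illustrative): with var_lanes_ = 4, an allocation of float s[16]
// inside the vectorized loop is widened to float s[64], and an access s[i]
// becomes s[i * 4 + var], giving each of the four lanes its own slot.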
class VecAllocAccess : public StmtExprMutator {
public:
VecAllocAccess(const VarNode* buf, Var var, PrimExpr var_lanes)
: buf_(buf), var_(var), var_lanes_(var_lanes) {}
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
return UpdateBufferAccess(load);
}
Stmt VisitStmt_(const BufferStoreNode* op) final {
auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
return UpdateBufferAccess(store);
}
private:
template <typename Node>
Node UpdateBufferAccess(Node node) {
// Only update the buffer that's being replaced.
if (node->buffer->data.get() != buf_) {
return node;
}
// Find/make a Buffer object with the correct updated shape.
Buffer buf;
auto it = buffer_map_.find(node->buffer.get());
if (it != buffer_map_.end()) {
buf = it->second;
} else {
// Extend the least significant dimension by a factor of
// var_lanes_. Typically, this will be a 1-d index into a flat
// memory space.
Array<PrimExpr> shape = node->buffer->shape;
shape.Set(shape.size() - 1, analyzer_.Simplify(shape[shape.size() - 1] * var_lanes_));
// TODO(Lunderberg): Move this pass to be prior to
// StorageFlatten/FlattenBuffer, implement by appending a
// dimension to the buffer. Since it is currently after the
// flattening, the strides are not technically necessary, but
// are updated for consistency.
// Update strides if defined.
Array<PrimExpr> strides;
for (size_t i = 0; i < node->buffer->strides.size(); i++) {
PrimExpr stride = node->buffer->strides[i];
if (i != node->buffer->strides.size() - 1) {
stride *= var_lanes_;
}
strides.push_back(analyzer_.Simplify(stride));
}
// Copy everything into the new buffer.
buf = node->buffer;
auto buf_writer = buf.CopyOnWrite();
buf_writer->shape = shape;
buf_writer->strides = strides;
buffer_map_[buf.get()] = buf;
}
// Extend the last index by the number of lanes in the vectorized
// variable.
Array<PrimExpr> indices = node->indices;
indices.Set(indices.size() - 1,
analyzer_.Simplify(indices[indices.size() - 1] * var_lanes_ + var_));
auto writer = node.CopyOnWrite();
writer->buffer = buf;
writer->indices = indices;
return node;
}
// buffer var
const VarNode* buf_;
// Updated buffer objects.
std::unordered_map<const BufferNode*, Buffer> buffer_map_;
// variable to be replaced
Var var_;
// the lanes.
PrimExpr var_lanes_;
// Analyzer for simplifications
arith::Analyzer analyzer_;
};
// We use ExprFunctor directly instead of StmtExprMutator
// This is because the transformation can change the dtype of the Expr
// The existing ExprMutator transformation rules may not be well defined.
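// Example (illustrative): with var_ = i and var_lanes_ = 4, the loop body
//   B[i] = A[i] + 1.f
// is rewritten into the single vector statement
//   B[Ramp(0, 1, 4)] = A[Ramp(0, 1, 4)] + Broadcast(1.f, 4)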
class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExpr&)> {
public:
using ExprFunctor::VisitExpr;
using StmtMutator::operator();
Vectorizer(Var var, PrimExpr var_lanes) : var_(var), var_lanes_(var_lanes) {
ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes);
}
Stmt VisitStmt(const Stmt& stmt) final {
ICHECK(!need_scalarize_);
Stmt ret = StmtMutator::VisitStmt(stmt);
if (need_scalarize_) {
need_scalarize_ = false;
return Scalarize(stmt);
} else {
return ret;
}
}
PrimExpr VisitExpr(const PrimExpr& e) final { return ExprFunctor::VisitExpr(e); }
PrimExpr VisitExpr_(const AddNode* op) final {
return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a + b; });
}
PrimExpr VisitExpr_(const SubNode* op) final {
return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a - b; });
}
PrimExpr VisitExpr_(const MulNode* op) final {
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
} else {
bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector();
bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector();
if (is_vec_a && is_vec_b) {
// Let's not multiply scalable and fixed length vectors
ICHECK(a.dtype().is_scalable_vector() == b.dtype().is_scalable_vector())
<< "Fixed length and scalable vectors can't be mixed in multiplication.";
}
if (is_vec_a || is_vec_b) {
const RampNode* b_ramp = b.as<RampNode>();
const RampNode* a_ramp = a.as<RampNode>();
if (a_ramp && b.dtype().is_scalar() && analyzer_.CanProve(b > 0)) {
PrimExpr lanes = a_ramp->lanes;
return Ramp(a_ramp->base * b, a_ramp->stride * b, lanes);
}
if (b_ramp && a.dtype().is_scalar() && analyzer_.CanProve(a > 0)) {
PrimExpr lanes = b_ramp->lanes;
return Ramp(b_ramp->base * a, b_ramp->stride * a, lanes);
}
int a_lanes = a.dtype().get_lanes_or_vscale_factor();
int b_lanes = b.dtype().get_lanes_or_vscale_factor();
int max_lanes = std::max(a_lanes, b_lanes);
bool is_scalable = a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
return Mul(BroadcastTo(a, max_lanes, is_scalable), BroadcastTo(b, max_lanes, is_scalable));
}
}
return BinaryVec<Mul>(op);
}
PrimExpr VisitExpr_(const DivNode* op) final { return BinaryVec<Div>(op); }
PrimExpr VisitExpr_(const ModNode* op) final { return BinaryVec<Mod>(op); }
PrimExpr VisitExpr_(const FloorDivNode* op) final { return BinaryVec<FloorDiv>(op); }
PrimExpr VisitExpr_(const FloorModNode* op) final { return BinaryVec<FloorMod>(op); }
PrimExpr VisitExpr_(const MinNode* op) final { return BinaryVec<Min>(op); }
PrimExpr VisitExpr_(const MaxNode* op) final { return BinaryVec<Max>(op); }
PrimExpr VisitExpr_(const EQNode* op) final { return BinaryVec<EQ>(op); }
PrimExpr VisitExpr_(const NENode* op) final { return BinaryVec<NE>(op); }
PrimExpr VisitExpr_(const LTNode* op) final { return BinaryVec<LT>(op); }
PrimExpr VisitExpr_(const LENode* op) final { return BinaryVec<LE>(op); }
PrimExpr VisitExpr_(const GTNode* op) final { return BinaryVec<GT>(op); }
PrimExpr VisitExpr_(const GENode* op) final { return BinaryVec<GE>(op); }
PrimExpr VisitExpr_(const AndNode* op) final { return BinaryVec<And>(op); }
PrimExpr VisitExpr_(const OrNode* op) final { return BinaryVec<Or>(op); }
PrimExpr VisitExpr_(const NotNode* op) final {
PrimExpr a = this->VisitExpr(op->a);
if (a.same_as(op->a)) {
return GetRef<PrimExpr>(op);
} else {
return !(a);
}
}
PrimExpr VisitExpr_(const RampNode* op) final {
PrimExpr base = this->VisitExpr(op->base);
PrimExpr stride = this->VisitExpr(op->stride);
ICHECK(!base.dtype().is_scalable_vector())
<< "Creating scalable vectors from existing vectors is not supported.";
ICHECK(!stride.dtype().is_scalable_vector())
<< "Ramp stride with scalable dtype is not supported";
if (base.dtype().is_fixed_length_vector() && stride.dtype().is_scalar()) {
ICHECK(op->lanes->IsInstance<IntImmNode>())
<< "Vectorizing over existing scalable vectors is not supported.";
const RampNode* base_ramp = base.as<RampNode>();
int op_lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
int base_ramp_lanes = static_cast<int>(Downcast<IntImm>(base_ramp->lanes)->value);
if (analyzer_.CanProve(base_ramp->stride ==
stride * make_const(stride.dtype(), base_ramp_lanes))) {
return Ramp(base_ramp->base, stride, op_lanes * base_ramp_lanes);
}
}
int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes());
base = BroadcastTo(base, lanes, false);
stride = BroadcastTo(stride, lanes, false);
Array<PrimExpr> elems;
for (int i = 0; i < lanes; ++i) {
elems.push_back(
Ramp(Shuffle::ExtractElement(base, i), Shuffle::ExtractElement(stride, i), op->lanes));
}
return Shuffle::Concat(elems);
}
PrimExpr VisitExpr_(const BroadcastNode* op) final {
PrimExpr value = this->VisitExpr(op->value);
if (value.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true;
return GetRef<PrimExpr>(op);
}
if (value.same_as(op->value)) {
return GetRef<PrimExpr>(op);
} else {
return Broadcast(op->value, op->lanes);
}
}
PrimExpr VisitExpr_(const SelectNode* op) final {
PrimExpr cond = this->VisitExpr(op->condition);
PrimExpr t = this->VisitExpr(op->true_value);
PrimExpr f = this->VisitExpr(op->false_value);
if (cond.same_as(op->condition) && t.same_as(op->true_value) && f.same_as(op->false_value)) {
return GetRef<PrimExpr>(op);
} else {
int cond_lanes = cond.dtype().get_lanes_or_vscale_factor();
int t_lanes = t.dtype().get_lanes_or_vscale_factor();
int f_lanes = f.dtype().get_lanes_or_vscale_factor();
int lanes = std::max(std::max(cond_lanes, t_lanes), f_lanes);
bool is_scalable = cond.dtype().is_scalable_vector() || t.dtype().is_scalable_vector() ||
f.dtype().is_scalable_vector();
return Select(BroadcastTo(cond, lanes, is_scalable), BroadcastTo(t, lanes, is_scalable),
BroadcastTo(f, lanes, is_scalable));
}
}
PrimExpr VisitExpr_(const CastNode* op) final {
PrimExpr value = this->VisitExpr(op->value);
if (value.same_as(op->value)) {
return GetRef<PrimExpr>(op);
} else {
if (value.dtype().is_scalable_vector()) {
return Cast(op->dtype.with_scalable_vscale_factor(value.dtype().vscale_factor()), value);
} else {
return Cast(op->dtype.with_lanes(value.dtype().lanes()), value);
}
}
}
PrimExpr VisitExpr_(const FloatImmNode* op) final { return GetRef<PrimExpr>(op); }
PrimExpr VisitExpr_(const IntImmNode* op) final { return GetRef<PrimExpr>(op); }
PrimExpr VisitExpr_(const StringImmNode* op) final { return GetRef<PrimExpr>(op); }
// Variable
PrimExpr VisitExpr_(const VarNode* op) final {
Var var = GetRef<Var>(op);
if (var.same_as(var_)) {
return ramp_;
}
auto it = let_binding_.find(var);
if (it != let_binding_.end()) {
return it->second;
} else {
return std::move(var);
}
}
// IfThenElse expr
PrimExpr MutateIfThenElseExpr_(const CallNode* op) {
PrimExpr cond = this->VisitExpr(op->args[0]);
if (cond.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true;
return GetRef<PrimExpr>(op);
}
PrimExpr t = this->VisitExpr(op->args[1]);
PrimExpr f = this->VisitExpr(op->args[2]);
if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) && f.same_as(op->args[2])) {
return GetRef<PrimExpr>(op);
} else {
int t_lanes = t.dtype().get_lanes_or_vscale_factor();
int f_lanes = f.dtype().get_lanes_or_vscale_factor();
int lanes = std::max(t_lanes, f_lanes);
bool is_scalable = t.dtype().is_scalable_vector() || f.dtype().is_scalable_vector();
t = BroadcastTo(t, lanes, is_scalable);
f = BroadcastTo(f, lanes, is_scalable);
if (is_scalable) {
return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, {cond, t, f});
} else {
return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f});
}
}
}
// Reinterpret expr
PrimExpr MutateReinterpretExpr_(const CallNode* op) {
ICHECK(op->op.same_as(builtin::reinterpret()));
PrimExpr value = this->VisitExpr(op->args[0]);
if (value.same_as(op->args[0])) {
return GetRef<PrimExpr>(op);
} else {
int lanes = value.dtype().get_lanes_or_vscale_factor();
if (value.dtype().is_scalable_vector()) {
return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, {value});
} else {
return Call(op->dtype.with_lanes(lanes), op->op, {value});
}
}
}
// Call
PrimExpr VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::if_then_else())) {
return MutateIfThenElseExpr_(op);
} else if (op->op.same_as(builtin::texture2d_load())) {
int lane = 0;
Array<PrimExpr> fcd = MutateArray({op->args.back()}, &lane);
auto new_args = op->args;
new_args.pop_back();
new_args.push_back(fcd[0]);
return Call(op->dtype.with_lanes(4), op->op, new_args);
} else if (op->op.same_as(builtin::texture2d_store())) {
int lane = 0;
// Vectorize the value to store
Array<PrimExpr> value{op->args.back()};
Array<PrimExpr> mutated_value = MutateArray(value, &lane);
Array<PrimExpr> new_args{op->args[0], op->args[1], op->args[2], mutated_value[0]};
return Call(op->dtype.with_lanes(lane), op->op, new_args);
} else if (op->op.same_as(builtin::reinterpret())) {
return MutateReinterpretExpr_(op);
}
auto optional_op = op->op.as<Op>();
bool vectorizable = optional_op && op_vectorizable_.get(optional_op.value(), false) &&
!op->dtype.is_scalable_vector();
if (!vectorizable) {
// Cannot vectorize this op
Array<PrimExpr> new_args;
for (auto arg : op->args) {
auto new_arg = this->VisitExpr(arg);
if (new_arg.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true;
return GetRef<PrimExpr>(op);
}
new_args.push_back(new_arg);
}
if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op);
} else {
return Call(op->dtype, op->op, new_args);
}
} else {
int lane = 0;
Array<PrimExpr> new_args = MutateArray(op->args, &lane);
// normal code path.
if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op);
} else {
return Call(op->dtype.with_lanes(lane), op->op, new_args);
}
}
}
// BufferLoad
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
auto load = GetRef<BufferLoad>(op);
auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); };
Array<PrimExpr> indices = op->indices.Map(fmutate);
if (!indices.same_as(op->indices)) {
auto writer = load.CopyOnWrite();
writer->indices = indices;
}
return std::move(load);
}
// Let
PrimExpr VisitExpr_(const LetNode* op) final {
PrimExpr value = this->VisitExpr(op->value);
// Weaker SSA condition
// A single var can be bound in multiple lets
// but they have to bind to the same value.
// This is used to allow cases when we reuse a single let
// expression to construct a nested expr.
// (let x = 1 in x + 1) * (let x = 1 in x + 1)
auto it = let_binding_.find(op->var);
if (it != let_binding_.end()) {
ICHECK(deep_equal_(it->second, value))
<< "Let cannot bind the same var to two different values";
}
if (value.dtype().get_lanes_or_vscale_factor() !=
op->value.dtype().get_lanes_or_vscale_factor()) {
Var new_var(op->var->name_hint, value.dtype());
let_binding_[op->var] = new_var;
return Let(new_var, value, this->VisitExpr(op->body));
} else {
let_binding_[op->var] = op->var;
PrimExpr body = this->VisitExpr(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<PrimExpr>(op);
} else {
return Let(op->var, value, body);
}
}
}
// BufferStore
Stmt VisitStmt_(const BufferStoreNode* op) final {
auto store = GetRef<BufferStore>(op);
auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); };
Array<PrimExpr> indices = op->indices.Map(fmutate);
PrimExpr value = this->VisitExpr(op->value);
if (!indices.same_as(op->indices) || !value.same_as(op->value)) {
ICHECK(!op->buffer->dtype.is_scalable_vector())
<< "Vectorizing over scalable buffer elements is not supported in vectorizer.";
// How many lanes of indexing are present in the index and
// buffer element type, excluding the last index.
int other_index_lanes = op->buffer->dtype.lanes();
for (size_t i = 0; i < indices.size() - 1; i++) {
other_index_lanes *= indices[i].dtype().lanes();
// Only allow the last index to be scalable
ICHECK(!indices[i].dtype().is_scalable_vector()) << "Only the last index can be scalable.";
}
// The total number of lanes of indexing, including the last index.
auto last_index_dtype = indices[indices.size() - 1].dtype();
int lanes_in_last_index = last_index_dtype.get_lanes_or_vscale_factor();
int index_lanes = other_index_lanes * lanes_in_last_index;
// The total number of lanes in this store operation. Either
// the index or the value will be broadcast out to this number
// of lanes, depending on which has more lanes.
int value_dtype_lanes = value.dtype().get_lanes_or_vscale_factor();
bool is_last_index_scalable = last_index_dtype.is_scalable_vector();
int total_lanes = std::max(index_lanes, value_dtype_lanes);
ICHECK_EQ(total_lanes % other_index_lanes, 0)
<< "When storing to buffer " << op->buffer->name << ", cannot produce " << total_lanes
<< " lanes of storage location by changing the last index.";
int last_index_lanes = total_lanes / other_index_lanes;
// Broadcast the last index such that the total number of index
// lanes matches the desired number.
indices.Set(indices.size() - 1, BroadcastTo(indices[indices.size() - 1], last_index_lanes,
is_last_index_scalable));
auto writer = store.CopyOnWrite();
writer->indices = indices;
writer->value = BroadcastTo(value, total_lanes, is_last_index_scalable);
}
return std::move(store);
}
// For
Stmt VisitStmt_(const ForNode* op) final {
if (op->kind == ForKind::kVectorized) {
LOG(WARNING) << "Detect vectorize inside vectorized loop, ignoring...";
}
ICHECK(is_zero(op->min));
ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector());
PrimExpr extent = this->VisitExpr(op->extent);
if (extent.dtype().is_scalable_or_fixed_length_vector()) {
return Scalarize(GetRef<Stmt>(op));
}
Stmt body = this->VisitStmt(op->body);
if (extent.same_as(op->extent) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
return For(op->loop_var, op->min, extent, op->kind, body, op->thread_binding,
op->annotations);
}
}
// IfThenElse
Stmt VisitStmt_(const IfThenElseNode* op) final {
ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector());
PrimExpr condition = this->VisitExpr(op->condition);
if (condition.dtype().is_scalable_or_fixed_length_vector()) {
return Scalarize(GetRef<Stmt>(op));
}
Stmt then_case = this->VisitStmt(op->then_case);
Optional<Stmt> else_case = NullOpt;
if (op->else_case) {
else_case = this->VisitStmt(op->else_case.value());
}
if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op);
} else {
return IfThenElse(condition, then_case, else_case);
}
}
// While
Stmt VisitStmt_(const WhileNode* op) final {
LOG(FATAL) << "A while loop inside a vectorized loop not supported.";
}
// LetStmt
Stmt VisitStmt_(const LetStmtNode* op) final {
PrimExpr value = this->VisitExpr(op->value);
ICHECK(!let_binding_.count(op->var)) << "SSA violation, a single var is bound twice";
let_binding_[op->var] = value;
if (value.dtype().get_lanes_or_vscale_factor() !=
op->value.dtype().get_lanes_or_vscale_factor()) {
Var new_var(op->var->name_hint, value.dtype());
let_binding_[op->var] = new_var;
return LetStmt(new_var, value, this->VisitStmt(op->body));
} else {
let_binding_[op->var] = op->var;
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
return LetStmt(op->var, value, body);
}
}
}
// Allocate
Stmt VisitStmt_(const AllocateNode* op) final {
// Mutate the condition
PrimExpr condition = this->VisitExpr(op->condition);
if (condition.dtype().is_scalable_or_fixed_length_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc of " << op->buffer_var->name_hint;
return Scalarize(GetRef<Stmt>(op));
}
// Mutate the extents
Array<PrimExpr> extents;
for (const auto& extent : op->extents) {
PrimExpr new_ext = this->VisitExpr(extent);
if (new_ext.dtype().is_scalable_or_fixed_length_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc of " << op->buffer_var->name_hint;
return Scalarize(GetRef<Stmt>(op));
}
extents.push_back(new_ext);
}
// TODO(Lunderberg): Move this pass to be prior to
// StorageFlatten/FlattenBuffer. That will allow this pass to be
// implemented as adding a new buffer dimension, which is later
// flattened.
// Extend the least significant dimension by a factor of
// var_lanes_. Typically, this will be a 1-d index into a flat
// memory space.
extents.Set(extents.size() - 1, extents[extents.size() - 1] * var_lanes_);
// Rewrite access to the buffer in the body.
Stmt body = VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body);
body = this->VisitStmt(body);
return Allocate(op->buffer_var, op->dtype, extents, condition, body);
}
// scalarize the statement
Stmt Scalarize(Stmt stmt) {
Var idx(var_->name_hint + ".s", var_->dtype);
stmt = Substitute(stmt, {{var_, idx}});
return For(idx, IntImm(var_->dtype, 0), var_lanes_, ForKind::kSerial, stmt);
}
// ProducerStore
Stmt VisitStmt_(const ProducerStoreNode* op) final {
LOG(FATAL) << "ProducerProvide cannot appear in a TIR PrimFunc";
}
private:
// analyzer
arith::Analyzer analyzer_;
// deep equal
ExprDeepEqual deep_equal_;
// variable to be replaced
Var var_;
// the lanes.
PrimExpr var_lanes_;
// ramp representing the var.
PrimExpr ramp_;
// flag to mark requirement of scalarization.
bool need_scalarize_{false};
// Let binding
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_binding_;
// vectorizable property
OpAttrMap<TVectorizable> op_vectorizable_ = Op::GetAttrMap<TVectorizable>("TVectorizable");
// mutate array, with given lane requirement
// when finished, p_lane updates the lane requirement.
Array<PrimExpr> MutateArray(Array<PrimExpr> arr, int* p_lanes) {
if (arr.size() == 0) return arr;
int& lanes = *p_lanes;
bool changed = false;
std::vector<PrimExpr> new_arr(arr.size());
for (size_t i = 0; i < arr.size(); i++) {
PrimExpr old_elem = arr[i];
PrimExpr new_elem = this->VisitExpr(old_elem);
if (!new_elem.same_as(old_elem)) changed = true;
new_arr[i] = new_elem;
lanes = std::max(lanes, new_elem.dtype().lanes());
}
for (size_t i = 0; i < arr.size(); ++i) {
if (new_arr[i].dtype().lanes() != lanes) {
new_arr[i] = BroadcastTo(new_arr[i], lanes, false);
changed = true;
}
}
if (!changed) return arr;
return Array<PrimExpr>(new_arr);
}
template <typename TOp, typename T>
PrimExpr BinaryVec(const T* op) {
static_assert(std::is_same<typename TOp::ContainerType, T>::value, "constraint");
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
} else {
int a_lanes = a.dtype().get_lanes_or_vscale_factor();
int b_lanes = b.dtype().get_lanes_or_vscale_factor();
int lanes = std::max(a_lanes, b_lanes);
bool is_scalable = a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
return TOp(BroadcastTo(a, lanes, is_scalable), BroadcastTo(b, lanes, is_scalable));
}
}
template <typename T, typename FCompute>
PrimExpr AddSubVec(const T* op, FCompute fcompute) {
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
} else {
int a_lanes = a.dtype().get_lanes_or_vscale_factor();
int b_lanes = b.dtype().get_lanes_or_vscale_factor();
int lanes = std::max(a_lanes, b_lanes);
if (lanes != 1) {
const RampNode* b_ramp = b.as<RampNode>();
const RampNode* a_ramp = a.as<RampNode>();
if (a.dtype().is_scalar() && b_ramp) {
return Ramp(fcompute(a, b_ramp->base),
fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride), b_ramp->lanes);
}
if (b.dtype().is_scalar() && a_ramp) {
return Ramp(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
}
}
bool is_scalable = a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
return fcompute(BroadcastTo(a, lanes, is_scalable), BroadcastTo(b, lanes, is_scalable));
}
}
};
} // namespace tl
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file frontend_legalize.cc
* \brief Legalize the program from frontend
*/
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "arith/ir_mutator_with_analyzer.h"
namespace tvm {
namespace tl {
using namespace tir;
class FrontendLegalizer : public arith::IRMutatorWithAnalyzer {
public:
static PrimFunc Substitute(PrimFunc f) {
arith::Analyzer analyzer;
FrontendLegalizer substituter(&analyzer);
PrimFuncNode* fptr = f.CopyOnWrite();
fptr->body = substituter.VisitStmt(f->body);
return f;
}
private:
using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer;
Stmt VisitStmt_(const ForNode* node) final {
if (node->kind == ForKind::kParallel) {
parallel_for_scope_++;
}
auto n = StmtExprMutator::VisitStmt_(node);
if (node->kind == ForKind::kParallel) {
parallel_for_scope_--;
}
return n;
}
PrimExpr VisitExpr_(const VarNode* node) final {
if (let_bindings_.count(node)) {
return arith::IRMutatorWithAnalyzer::VisitExpr(let_bindings_[node]);
} else {
return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
}
}
Stmt VisitStmt_(const LetStmtNode* node) final {
let_bindings_[node->var.get()] = node->value;
return arith::IRMutatorWithAnalyzer::VisitStmt(node->body);
}
PrimExpr VisitExpr_(const LetNode* node) final {
let_bindings_[node->var.get()] = node->value;
return arith::IRMutatorWithAnalyzer::VisitExpr(node->body);
}
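  // Illustrative effect of the Let/LetStmt handling above (a sketch, not an exhaustive list):
  //   let x = n * 4 in A[x + 1]   ==>   A[n * 4 + 1]
  // i.e. bindings are inlined into their uses so later passes only see plain expressions.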
int parallel_for_scope_ = 0;
std::unordered_map<const VarNode*, PrimExpr> let_bindings_;
};
using namespace tir::transform;
Pass FrontendLegalize() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return FrontendLegalizer::Substitute(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tl.FrontendLegalize", {});
}
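// Illustrative usage (a sketch, assuming this translation unit is linked into the build): the
// pass composes with others through the regular TVM pass infrastructure, e.g.
//   IRModule mod = ...;
//   mod = tvm::transform::Sequential({FrontendLegalize()})(mod);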
TVM_REGISTER_GLOBAL("tl.transform.FrontendLegalize")
.set_body_typed(FrontendLegalize);
} // namespace tl
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file inject_fence_proxy.cc
* \brief Inject fence between generic and async proxies (sm90+)
*/
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
namespace tvm {
namespace tl {
using namespace tir;
enum class Proxy { kGeneric, kAsync, kBoth };
class ProxyMarker : public StmtVisitor {
public:
ProxyMarker() = default;
Proxy GetProxy(const StmtNode* stmt) const {
auto it = map_.find(stmt);
    // ICHECK(it != map_.end());
    // TODO: falling back to kGeneric for unmarked statements is a temporary workaround to
    // avoid the ICHECK failure above; a complete implementation should mark every statement.
if (it == map_.end()) {
return Proxy::kGeneric;
}
return it->second;
}
Proxy GetProxy(const Stmt& stmt) const { return GetProxy(stmt.get()); }
void VisitStmt_(const EvaluateNode* op) final {
Proxy proxy = Proxy::kAsync;
if (auto call = op->value.as<CallNode>()) {
if (call->op.same_as(LDMatrixOp()) || call->op.same_as(STMatrixOp())) {
proxy = Proxy::kGeneric;
}
}
SetProxy(op, proxy);
}
void VisitStmt_(const BufferStoreNode* op) final {
Proxy proxy = Proxy::kGeneric;
SetProxy(op, proxy);
}
void VisitStmt_(const SeqStmtNode* op) final {
StmtVisitor::VisitStmt_(op);
auto role = GetProxy(op->seq[0]);
for (auto stmt : op->seq) {
if (role != GetProxy(stmt)) {
role = Proxy::kBoth;
break;
}
}
SetProxy(op, role);
}
void VisitStmt_(const IfThenElseNode* op) final {
StmtVisitor::VisitStmt_(op);
auto role = GetProxy(op->then_case);
if (op->else_case.defined()) {
auto role_else = GetProxy(op->else_case.value());
if (role != role_else) role = Proxy::kBoth;
}
SetProxy(op, role);
}
void VisitStmt_(const BlockRealizeNode* op) final {
StmtVisitor::VisitStmt_(op);
SetProxy(op, GetProxy(op->block));
}
template <class NodeType>
void HandleBodyStmt(const NodeType* op) {
StmtVisitor::VisitStmt_(op);
SetProxy(op, GetProxy(op->body));
}
void VisitStmt_(const ForNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const LetStmtNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const AttrStmtNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const AssertStmtNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const BlockNode* op) final { HandleBodyStmt(op); }
private:
void SetProxy(const StmtNode* stmt, Proxy proxy) { map_[stmt] = proxy; }
std::unordered_map<const StmtNode*, Proxy> map_;
};
class InjectFenceProxy : public StmtExprMutator {
public:
static PrimFunc Substitute(PrimFunc f) {
auto T = InjectFenceProxy();
f.CopyOnWrite()->body = T(f->body);
return f;
}
private:
Proxy get_generic_proxy(const Stmt& stmt) {
auto marker = ProxyMarker();
marker(stmt);
return marker.GetProxy(stmt);
}
Stmt VisitStmt_(const SeqStmtNode* op) final {
ICHECK(op->seq.size() > 0);
Array<Stmt> new_body;
Proxy cur_proxy, prev_proxy;
auto fence_stmt = Evaluate(Call(DataType::Handle(), FenceProxyAsyncOp(), {}));
prev_proxy = get_generic_proxy(op->seq[0]);
new_body.push_back(VisitStmt(op->seq[0]));
if (op->seq.size() > 1) {
for (int i = 1; i < static_cast<int>(op->seq.size()); i++) {
cur_proxy = get_generic_proxy(op->seq[i]);
if (cur_proxy == Proxy::kAsync && prev_proxy == Proxy::kGeneric) {
new_body.push_back(fence_stmt);
}
new_body.push_back(VisitStmt(op->seq[i]));
prev_proxy = cur_proxy;
}
}
ICHECK(new_body.size() > 0);
return new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body));
}
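  // Illustrative effect: for a sequence { BufferStore to shared memory (generic proxy),
  // async copy/GEMM call (async proxy) }, an Evaluate(FenceProxyAsyncOp()) statement is
  // inserted between the two so the async proxy observes the preceding generic-proxy writes.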
InjectFenceProxy() = default;
};
using namespace tir::transform;
tvm::transform::Pass InjectFenceProxy() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return InjectFenceProxy::Substitute(f);
};
return CreatePrimFuncPass(pass_func, 0, "tl.InjectFenceProxy", {});
}
TVM_REGISTER_GLOBAL("tl.transform.InjectFenceProxy")
.set_body_typed(InjectFenceProxy);
} // namespace tl
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file inject_software_pipeline.cc
 * \brief Transform annotated loops into a pipelined one that parallelizes producers and consumers
*/
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/transform.h>
#include <unordered_set>
#include "support/utils.h"
#include "tir/schedule/utils.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace tl {
using namespace tir;
/*!
* \brief Create a block and infer the access region with the given body.
*
 * The result is an opaque block that doesn't contain any block iter vars. If the body is a
 * block realize without a predicate, no new block is created and the block of that block
 * realize is returned directly.
*
* \param body The body of the block.
* \param buffer_data_to_buffer The map from buffer data to buffer.
* \return The result block.
*/
Block MakeBlock(const Stmt& body, const Map<Var, Buffer>& buffer_data_to_buffer) {
if (const BlockRealizeNode* block_realize = body.as<BlockRealizeNode>()) {
if (is_one(block_realize->predicate)) {
// no need to create a new block
return block_realize->block;
}
}
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ body);
Array<Array<BufferRegion>> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer);
BlockNode* n = block.CopyOnWrite();
n->reads = access[0];
n->writes = access[1];
return block;
}
/*! Structure that represents the provided annotation per block or loop. */
struct PipelineAnnotation {
int stage;
int order;
bool async;
};
using PipelineInfo = std::unordered_map<Block, PipelineAnnotation, ObjectPtrHash, ObjectPtrEqual>;
struct BufferAccessInfo {
int def = -1; // the defining stage of the buffer
int use = -1; // the last using stage of the buffer
};
/*!
* \brief Rewriter for the body of the software pipeline. This pass inserts `floormod` to indices
* of the remapped buffer to select the version corresponding to the pipeline stage.
*/
class PipelineBodyRewriter : public StmtExprMutator {
public:
/*!
* \brief Constructor of PipelineBodyRewriter.
* \param buffer_data_to_buffer The map from buffer data to buffer.
* \param buffer_remap The map from original buffer to the buffer with updated shape for
* multi-versioning in the software pipeline.
* \param pipeline_loop The original loop to be software pipelined.
   * \param access_all_versions Whether all versions of the buffers in the software pipeline are
   * accessed. This is used to update the block access regions. In the prologue and epilogue
   * of a two-stage software pipeline, only one version of these buffers is accessed.
*/
PipelineBodyRewriter(const Map<Var, Buffer>& buffer_data_to_buffer,
const Map<Buffer, Buffer>& buffer_remap, For pipeline_loop,
bool access_all_versions)
: buffer_data_to_buffer_(buffer_data_to_buffer),
buffer_remap_(buffer_remap),
pipeline_loop_(pipeline_loop),
access_all_versions_(access_all_versions) {}
private:
BufferRegion RewritePipelineBufferRegion(const BufferRegion& buffer_region) const {
auto it = buffer_remap_.find(buffer_region->buffer);
if (it != buffer_remap_.end()) {
Region new_region = buffer_region->region;
const Buffer& new_buffer = (*it).second;
// For pipeline buffers, relax the access region of the first dimension to full extent
// if access_all_versions == true
Range accessed_version =
access_all_versions_
? Range::FromMinExtent(0, new_buffer->shape[0])
: Range::FromMinExtent(floormod((pipeline_loop_->loop_var - pipeline_loop_->min),
new_buffer->shape[0]),
Integer(1));
new_region.insert(new_region.begin(), accessed_version);
return BufferRegion(new_buffer, new_region);
}
return buffer_region;
}
PrimExpr RewriteBufferAccess(const Call& call, const std::vector<int> arg_indices) {
auto product = [](const Array<PrimExpr>& input) {
return foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
make_const(DataType::Int(32), 1), input);
};
Array<PrimExpr> new_args = call->args;
for (int i : arg_indices) {
const Buffer& buffer = buffer_data_to_buffer_.at(Downcast<Var>(call->args[i]));
auto it = buffer_remap_.find(buffer);
if (it != buffer_remap_.end()) {
const Buffer& new_buffer = (*it).second;
const PrimExpr& old_index = call->args[i + 1];
PrimExpr offset;
if (new_buffer->strides.empty()) {
offset = product(buffer->shape);
} else {
offset = new_buffer->strides[0];
}
PrimExpr new_index =
old_index + floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset;
new_args.Set(i + 1, new_index);
}
}
return Call(call->dtype, call->op, new_args, call->span);
}
Stmt VisitStmt_(const BlockNode* op) final {
for (const Buffer& alloc_buffer : op->alloc_buffers) {
buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer);
}
Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
BlockNode* n = block.CopyOnWrite();
n->reads.MutateByApply([this](const BufferRegion& buffer_region) {
return RewritePipelineBufferRegion(buffer_region);
});
n->writes.MutateByApply([this](const BufferRegion& buffer_region) {
return RewritePipelineBufferRegion(buffer_region);
});
for (const Buffer& alloc_buffer : op->alloc_buffers) {
buffer_data_to_buffer_.erase(alloc_buffer->data);
}
return std::move(block);
}
Stmt VisitStmt_(const BufferStoreNode* op) final {
BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
auto it = buffer_remap_.find(store->buffer);
if (it == buffer_remap_.end()) {
return std::move(store);
}
const Buffer& new_buffer = (*it).second;
auto* n = store.CopyOnWrite();
n->buffer = new_buffer;
PrimExpr version =
floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
n->indices.insert(n->indices.begin(), version);
return std::move(store);
}
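  // Illustrative example (assuming buffer B was remapped to 2 versions): inside the pipelined
  // loop `k`,
  //   B[i] = ...   ==>   B[floormod(k - min, 2), i] = ...
  // so consecutive pipeline iterations write alternating versions of B.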
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
auto it = buffer_remap_.find(load->buffer);
if (it == buffer_remap_.end()) {
return std::move(load);
}
const Buffer& new_buffer = (*it).second;
auto* n = load.CopyOnWrite();
n->buffer = new_buffer;
PrimExpr version =
floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
n->indices.insert(n->indices.begin(), version);
return std::move(load);
}
PrimExpr VisitExpr_(const CallNode* op) final {
Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
if (call->op.same_as(builtin::tvm_access_ptr())) {
return RewriteBufferAccess(call, {1});
}
return call;
}
Map<Var, Buffer> buffer_data_to_buffer_;
Map<Buffer, Buffer> buffer_remap_;
For pipeline_loop_;
bool access_all_versions_;
};
/*!
* \brief Rewriter for the software pipeline that rewrite a loop into a pipelined one.
*/
class PipelineRewriter : public StmtExprMutator {
public:
PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer, const Array<Buffer>& pipeline_allocs,
const For& pipeline_loop, const PipelineInfo& pipeline_info)
: buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
pipeline_allocs_(pipeline_allocs),
pipeline_loop_(pipeline_loop),
pipeline_info_(pipeline_info) {}
Stmt BuildPipeline() {
// Step 1: Analyze accesses to the buffers in the pipeline and compute the number of versions
// need to maintain for each buffer.
std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual> infos =
GetBufferAccessInfo();
for (const Buffer& buffer : pipeline_allocs_) {
int num_versions = ComputeBufferVersions(buffer, infos.at(buffer));
if (num_versions > 1) {
buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions));
}
}
ordered_stmts_.resize(pipeline_info_.size());
for (const auto& [block, anno] : pipeline_info_) {
ordered_stmts_.Set(anno.order, block);
}
for (const Block& block : ordered_stmts_) {
int stage = pipeline_info_[block].stage;
if (pipeline_info_[block].async) {
auto& state = async_states[stage];
state.producer_head = pipeline_loop_->min - 1;
for (auto write_region : block->writes) {
auto buffer = write_region->buffer;
state.dst_buffers.insert(buffer.get());
if (buffer_remap_.count(buffer)) state.dst_buffers.insert(buffer_remap_[buffer].get());
}
}
}
std::unordered_set<int> consumed;
for (const Block& block : ordered_stmts_) {
int stage = pipeline_info_[block].stage;
if (pipeline_info_[block].async) {
auto& state = async_states[stage];
if (state.commit_groups.empty() || consumed.count(stage)) {
state.commit_groups.push_back({});
}
state.commit_groups.back().push_back(pipeline_info_[block].order);
consumed.erase(stage);
for (auto write_region : block->writes) {
auto buffer = buffer_remap_.count(write_region->buffer)
? buffer_remap_[write_region->buffer]
: write_region->buffer;
state.buffer_to_commit_group_[buffer.get()] = state.commit_groups.size() - 1;
}
}
for (auto read_region : block->reads) {
for (const auto& [producer_stage_id, producer_state] : async_states) {
if (producer_stage_id <= stage && producer_state.writes(read_region->buffer)) {
consumed.insert(producer_stage_id);
}
}
}
}
// Step 2: Emit the pipeline prologue, body and epilogue.
Stmt prologue = EmitImpl(pipeline_loop_->min, pipeline_loop_->min + max_stage_, true, true);
Stmt body = EmitImpl(pipeline_loop_->min + max_stage_,
pipeline_loop_->min + pipeline_loop_->extent, false, false);
Stmt epilogue = EmitImpl(pipeline_loop_->min + pipeline_loop_->extent,
pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true, true);
SeqStmt stmt = SeqStmt({prologue, body, epilogue});
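    // Illustrative schedule (max_stage_ == 1, loop over k in [min, min + N)):
    //   prologue: k in [min, min + 1)          -- only stage-0 producers are in bounds
    //   body:     k in [min + 1, min + N)      -- steady state, stages overlap
    //   epilogue: k in [min + N, min + N + 1)  -- remaining consumers drain the last versions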
// Step 3: Make a new block that contains new buffer allocations after pipeline rewriting.
Array<Buffer> alloc_buffers;
for (const auto& alloc : pipeline_allocs_) {
alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc));
buffer_data_to_buffer_.erase(alloc->data);
}
Block block = MakeBlock(stmt, buffer_data_to_buffer_);
block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers);
return BlockRealize({}, Bool(true), block);
}
private:
/*!
* \brief Analyze accesses to the buffers in the software pipeline.
*
   * This method checks the 'define' and 'use' stages of the buffers in the software pipeline,
   * which are used to compute the number of versions that need to be maintained after rewriting.
*/
std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
GetBufferAccessInfo() {
std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual> infos;
for (const auto& pair : pipeline_info_) {
const Block& block = pair.first;
int stage = pair.second.stage;
max_stage_ = std::max(max_stage_, stage);
for (const BufferRegion& write : block->writes) {
if (!infos.count(write->buffer)) {
infos.emplace(write->buffer, BufferAccessInfo{});
}
auto& info = infos.at(write->buffer);
if (info.def == -1) {
info.def = stage;
} else {
info.def = std::min(info.def, stage);
}
}
for (const BufferRegion& read : block->reads) {
if (!infos.count(read->buffer)) {
infos.emplace(read->buffer, BufferAccessInfo{});
}
auto& info = infos.at(read->buffer);
info.use = std::max(info.use, stage);
}
}
return infos;
}
/*!
* \brief Check whether two regions have intersections.
* \param region1 The first region.
* \param region2 The second region.
* \return Whether region1 and region2 have intersections.
*/
bool MayConflict(Region region1, Region region2) {
ICHECK(region1.size() == region2.size());
for (size_t i = 0; i < region1.size(); i++) {
Range dim1 = region1[i];
Range dim2 = region2[i];
auto int_set1 = arith::IntSet::FromRange(dim1);
auto int_set2 = arith::IntSet::FromRange(dim2);
if (arith::Intersect({int_set1, int_set2}).IsNothing()) {
return false;
}
}
return true;
}
/*!
   * \brief Compute the number of versions that need to be maintained for a buffer accessed in
   * the software pipeline.
   *
   * This method applies liveness analysis to the target buffer to compute the number of versions
   * that need to be maintained during the software pipeline.
* Annotation `attr::double_buffer_scope` is handled here which provides a way to override the
* result of the analysis. Additional double buffering in the software pipeline can be useful
* to eliminate synchronizations in GPU devices.
*
* \param buffer The target buffer
* \param buffer_info The access information of the target buffer.
* \return The number of versions required for the target buffer.
*/
int ComputeBufferVersions(const Buffer& buffer, const BufferAccessInfo& buffer_info) {
if (buffer_info.def == -1) {
// Keep the original number of versions as buffers defined outside the software pipeline
// should not be mutated.
return 1;
}
    // `use - def + 1` is an upper bound on the number of versions needed.
    // We optimize a few cases where the number of versions can be smaller than this upper bound.
int num_versions = buffer_info.use - buffer_info.def + 1;
if (num_versions >= 2) {
      // Even when `use - def + 1 >= 2`, multi-versioning (double buffering) is only needed when
      // there exist a writer block_i and a reader block_j such that
      // order(block_i) < order(block_j) and stage(block_i) < stage(block_j) and the access regions
      // of block_i and block_j overlap.
bool need_multi_version = false;
for (const auto& pair1 : pipeline_info_) {
const Block& writer_block = pair1.first;
const auto& writer_info = pair1.second;
auto it1 = std::find_if(writer_block->writes.begin(), writer_block->writes.end(),
[&](const BufferRegion& buffer_region) {
return buffer_region->buffer.same_as(buffer);
});
if (it1 == writer_block->writes.end()) {
continue;
}
for (const auto& pair2 : pipeline_info_) {
const Block& reader_block = pair2.first;
const auto& reader_info = pair2.second;
auto it2 = std::find_if(reader_block->reads.begin(), reader_block->reads.end(),
[&](const BufferRegion& buffer_region) {
return buffer_region->buffer.same_as(buffer);
});
if (it2 == reader_block->reads.end()) {
continue;
}
if (writer_info.order < reader_info.order && writer_info.stage < reader_info.stage &&
MayConflict((*it1)->region, (*it2)->region)) {
need_multi_version = true;
break;
}
}
}
if (!need_multi_version) {
num_versions--;
}
}
return num_versions;
}
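  // Illustrative example: a copy block at stage 0 writing a shared buffer read by a compute
  // block at stage 1 gives def = 0 and use = 1, so up to use - def + 1 = 2 versions (double
  // buffering) are kept when the writer precedes the reader in order and their access regions
  // overlap; otherwise the count is reduced back to 1.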
/*!
* \brief Rewrite buffer allocation to keep multiple versions of original buffer for pipelined
* accesses.
* \param buffer The buffer to be resized.
* \param num_versions The number of versions to keep.
* \return The resized buffer.
*/
Buffer RewriteAllocBuffer(const Buffer& buffer, int num_versions) {
ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
if (new_buffer->strides.size()) {
ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1];
new_buffer->strides.insert(new_buffer->strides.begin(), stride_0);
}
return Buffer(new_buffer);
}
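  // Illustrative example: a buffer of shape [64, 64] resized for 2 versions becomes
  // [2, 64, 64]; when strides are present, a new leading stride strides[0] * 64 is prepended.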
// Per-stage states that need to be tracked across pipeline prologue, body, and epilogue.
struct AsyncStateGlobal {
// Buffers that this stage asynchronously writes.
std::unordered_set<const BufferNode*> dst_buffers;
// An imaginary index that the latest async operation associated with this stage has written
// into. Only valid if all associated predicates are true, so that we can count the number of
// async invocations exactly. When it is valid, it is the "sum of extents of loops that have
// been executed" - 1, e.g. for epilogue it is prologue extent + body extent - 1. This
// is only needed to compute wait count for epilogue without async producers.
PrimExpr producer_head;
std::vector<std::vector<int>> commit_groups;
std::unordered_map<const BufferNode*, int> buffer_to_commit_group_;
bool writes(Buffer buf) const { return dst_buffers.count(buf.get()) > 0; }
};
// Per-stage states that are local to each of pipeline prologue, body, and epilogue.
struct AsyncStateLocal {
struct PendingWait {
// The index into a list of blocks, where async_wait_queue should be attached at the
// beginning.
int insert_before;
// in_flight_count would be a more precise name, but the implementation uses wait_count for
// brevity.
PrimExpr wait_count{nullptr};
bool valid() const { return wait_count.defined(); }
};
std::vector<PendingWait> pending_waits;
// A symbolic expression representing the index the latest async operation associated with this
// stage has written into, at the "current" iteration.
Optional<PrimExpr> producer_head;
};
/*! Structure holding intermediate information for pipeline loop rewriting. */
struct RewrittenBlockInfo {
int stage;
int order;
PrimExpr predicate;
Block block;
PrimExpr access_index;
bool is_async;
};
void PopulateWaitCounts(const std::vector<RewrittenBlockInfo>& new_blocks,
std::map<int, AsyncStateLocal>* async_states_local) {
for (size_t i = 0; i < new_blocks.size(); ++i) {
int producer_stage_idx = -1;
for (auto read_region : new_blocks[i].block->reads) {
for (const auto& [stage, state] : async_states) {
if (stage <= new_blocks[i].stage && state.writes(read_region->buffer)) {
// Found an earlier stage where read_region->buffer was asynchronously written
ICHECK(producer_stage_idx == -1 || producer_stage_idx == stage)
<< "A dependency on multiple async stages is not supported";
producer_stage_idx = stage;
}
}
}
if (producer_stage_idx == -1) continue;
const auto& state = async_states[producer_stage_idx];
auto& dep_local_state = (*async_states_local)[producer_stage_idx];
PrimExpr in_flight_cnt = 0;
for (const auto& group : state.commit_groups) {
PrimExpr consumer_head = new_blocks[i].access_index;
PrimExpr producer_head;
if (dep_local_state.producer_head.defined()) {
producer_head = dep_local_state.producer_head.value();
          // if the commit group comes after the wait point in program order, subtract 1
if (group.front() > new_blocks[i].order) producer_head -= 1;
} else {
producer_head = state.producer_head;
}
in_flight_cnt += producer_head - consumer_head;
}
      // We can relax the in-flight count by the number of trailing commit groups that the read
      // does not depend on.
std::unordered_set<int> dependent_groups;
for (const auto& read_region : new_blocks[i].block->reads) {
if (state.buffer_to_commit_group_.count(read_region->buffer.get()))
dependent_groups.insert(state.buffer_to_commit_group_.at(read_region->buffer.get()));
}
      for (int g = static_cast<int>(state.commit_groups.size()) - 1; g >= 0; g--) {
        if (dependent_groups.count(g) == 0)
in_flight_cnt += 1;
else
break; // stop relaxing
}
in_flight_cnt = analyzer_.Simplify(in_flight_cnt);
dep_local_state.pending_waits.push_back({static_cast<int>(i), in_flight_cnt});
}
}
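  // Rough intuition (a sketch, not a precise invariant): for a single async commit group at
  // stage 0 feeding a consumer at stage 1, `in_flight_cnt = producer_head - consumer_head`
  // allows that many later copies to remain in flight while guaranteeing the copy the read
  // depends on has completed; the loop above then relaxes the count by one for each trailing
  // commit group the read does not depend on.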
// Given pipelined blocks and async-related information, generate final loop statements with async
// scopes (if any).
Array<Stmt> CompletePipelineLoopStatements(
const std::vector<RewrittenBlockInfo>& blocks,
const std::map<int, AsyncStateLocal>& async_states_local) const {
std::vector<RewrittenBlockInfo> new_blocks = blocks;
for (const auto& [stage_id, state] : async_states_local) {
for (const auto& pw : state.pending_waits) {
auto& block = new_blocks[pw.insert_before].block;
BlockNode* n = block.CopyOnWrite();
auto zero = make_zero(DataType::Int(32));
n->body =
AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
AttrStmt(zero, tir::attr::async_wait_inflight_count, pw.wait_count, n->body));
}
}
// mark the last async stmt as commit
std::unordered_set<int> commit_group_indices;
for (const auto& [stage_id, state] : async_states) {
for (size_t i = 0; i < state.commit_groups.size(); ++i) {
commit_group_indices.insert(state.commit_groups[i].back());
}
}
Array<Stmt> stmts;
for (size_t i = 0; i < new_blocks.size(); i++) {
Block block = new_blocks[i].block;
if (commit_group_indices.count(new_blocks[i].order)) {
auto commit_queue_scope =
AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_commit_queue_scope,
new_blocks[i].stage, block->body);
block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
}
stmts.push_back(BlockRealize({}, new_blocks[i].predicate, block));
}
return stmts;
}
/*!
* \brief Emit the pipeline loop in the given range.
* \param start The start of the range
* \param end The end of the range
   * \param unroll_loop Whether the loop should be unrolled.
   * \param need_bound_check Whether bound-check predicates must be emitted for the blocks.
   * \return The result loop.
*/
Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop, bool need_bound_check) {
PrimExpr new_loop_var;
PrimExpr extent = end - start;
auto make_nop = []() { return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {})); };
bool is_unit_loop = analyzer_.CanProveEqual(extent, 1);
if (is_unit_loop) {
new_loop_var = start; // use constants as the loop var for unit loops
} else {
new_loop_var = pipeline_loop_->loop_var.copy_with_suffix("");
analyzer_.Bind(Downcast<Var>(new_loop_var), Range(start, end));
}
std::vector<RewrittenBlockInfo> new_blocks;
// Async related
std::map<int, AsyncStateLocal> async_states_local;
for (const Block& block : ordered_stmts_) {
int stage = pipeline_info_.at(block).stage;
int order = pipeline_info_.at(block).order;
PrimExpr inbound = Bool(true);
PrimExpr skewed_loop_var = new_loop_var - stage;
if (need_bound_check)
inbound = analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) &&
(skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent);
if (analyzer_.CanProve(!inbound)) {
continue;
}
Block new_block = Downcast<Block>(PipelineBodyRewriter(
buffer_data_to_buffer_, buffer_remap_, pipeline_loop_, max_stage_ != 1)(block));
PrimExpr delta = start - pipeline_loop_->min;
// This variable corresponds to
// - "producer_head" if this stage is an async producer
// - "consumer_head" if this stage reads from asynchronously written buffers.
PrimExpr normalized_access_index = is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;
// Adjust the block predicate and the body according to the final loop bound
// [pipeline_loop_->min, extent).
if (!is_unit_loop) {
Var loop_iter = Downcast<Var>(new_loop_var);
inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}});
}
new_block = Downcast<Block>(
Substitute(new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
if (pipeline_info_[block].async) {
auto& local_state = async_states_local[stage];
local_state.producer_head = normalized_access_index;
BlockNode* n = new_block.CopyOnWrite();
n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope, 1, n->body);
}
new_blocks.push_back(
{stage, order, inbound, new_block, normalized_access_index, pipeline_info_[block].async});
}
PopulateWaitCounts(new_blocks, &async_states_local);
auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local);
Stmt new_loop{nullptr};
if (stmts.empty()) {
return make_nop();
}
if (stmts.size() == 1) {
new_loop = stmts[0];
} else {
new_loop = SeqStmt(stmts);
}
if (!is_unit_loop) {
Map<String, ObjectRef> preserved_annotations;
for (const auto& kv : pipeline_loop_->annotations) {
const String& key = kv.first;
if (kv.first != tir::attr::software_pipeline_stage &&
kv.first != tir::attr::software_pipeline_order &&
kv.first != tir::attr::software_pipeline_async_stages) {
preserved_annotations.Set(key, kv.second);
}
}
new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, std::move(new_loop),
NullOpt, preserved_annotations);
}
// Update producer heads in the global async states.
for (const auto& [stage_id, state] : async_states_local) {
async_states[stage_id].producer_head += extent;
}
return BlockRealize({}, Bool(true), MakeBlock(std::move(new_loop), buffer_data_to_buffer_));
}
arith::Analyzer analyzer_;
Map<Var, Buffer> buffer_data_to_buffer_;
Array<Buffer> pipeline_allocs_;
For pipeline_loop_;
PipelineInfo pipeline_info_;
int max_stage_ = -1;
Map<Buffer, Buffer> buffer_remap_;
Array<Block> ordered_stmts_;
std::map<int, AsyncStateGlobal> async_states;
};
/*!
 * \brief Build the dependency graph among an array of blocks.
* \param[in] blocks The array of blocks.
* \param[out] dep_src2dst Optional, a map to store dependency edges from the source to the
* destination.
* \param[out] dep_dst2src Optional, a map to store dependency edges from the
* destination to the source.
*/
void BuildDependencyGraph(
const Array<Block>& blocks,
std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>* dep_src2dst,
std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>* dep_dst2src) {
std::unordered_map<Var, Array<Block>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;
for (const Block& block : blocks) {
for (const BufferRegion& read : block->reads) {
auto it = buffer_writers.find(read->buffer->data);
if (it != buffer_writers.end()) {
for (const Block& writer : it->second) {
if (dep_src2dst != nullptr) {
(*dep_src2dst)[writer].push_back(block);
}
if (dep_dst2src != nullptr) {
(*dep_dst2src)[block].push_back(writer);
}
}
}
}
for (const BufferRegion& write : block->writes) {
buffer_writers[write->buffer->data].push_back(block);
}
}
}
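// Illustrative example: for blocks {copy_A, copy_B, gemm} where gemm reads buffers written by
// both copy blocks, the graph contains the edges copy_A -> gemm and copy_B -> gemm (plus the
// reversed entries in dep_dst2src when it is requested).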
class PipelineInjector : private StmtExprMutator {
public:
static Stmt Inject(const PrimFunc& func) {
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
PipelineInjector injector(global_symbol);
for (const auto& kv : func->buffer_map) {
const Buffer& buffer = kv.second;
injector.buffer_data_to_buffer_.Set(buffer->data, buffer);
}
return injector(func->body);
}
private:
explicit PipelineInjector(Optional<String> global_symbol) : global_symbol_(global_symbol) {}
/*!
* \brief Check the pipeline satisfies the following conditions:
* 1. No conflicting order: The order of each statement should be unique.
* 2. Reordering of statements doesn't break buffer access dependencies. Specifically, for
* dependency (e.g. read-after-write) from statement A to statement B, it requires:
* case 1: stage(A) < stage(B)
* case 2: stage(A) == stage(B) and order(A) < order(B)
*/
void ValidatePipelineBody(const PipelineInfo& pipeline_info, const Array<Block>& original_order) {
std::unordered_set<int> used_orders;
std::unordered_map<int, int> stage_max_order;
std::unordered_map<int, const Block*> order_to_block;
std::unordered_map<const Block*, int> block_to_stage;
for (const Block& block : original_order) {
const auto& stmt_info = pipeline_info.at(block);
int order = stmt_info.order;
CHECK(!used_orders.count(order))
<< "ValueError: Two statements in the software pipeline cannot have the same order";
used_orders.insert(order);
}
std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual> dep_src2dst;
BuildDependencyGraph(original_order, &dep_src2dst, nullptr);
for (const auto& pair : dep_src2dst) {
const Block& src = pair.first;
const auto& src_info = pipeline_info.at(src);
const Array<Block>& dsts = pair.second;
for (const Block& dst : dsts) {
const auto& dst_info = pipeline_info.at(dst);
CHECK_LE(src_info.stage, dst_info.stage)
<< "ValueError: statement " << dst << " in stage " << dst_info.stage
<< " cannot depends on statement " << src << " in a later stage " << src_info.stage;
if (src_info.stage == dst_info.stage) {
CHECK_LT(src_info.order, dst_info.order) << "ValueError: two statements with buffer "
"access dependency in the same stage of the "
"software pipeline cannot be reordered";
}
}
}
}
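  // Illustrative rejection: if block A (stage 0, order 1) writes a buffer that block B
  // (stage 0, order 0) reads, the check above fails because dependent statements within the
  // same stage must keep their original relative order.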
Stmt VisitStmt_(const ForNode* op) final {
// Step 1: Recursively rewrite the children first.
For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
if (!HasPipelineAnnotation(op)) {
return std::move(for_node);
}
    // Step 2: Find the body and buffer allocations of the pipeline. The body can be a direct child
    // of the for-loop. If the for-loop has a BlockRealize as its child, the pipeline body is the
    // body of that block.
Stmt pipeline_body{nullptr};
Array<Buffer> pipeline_allocs;
if (const auto* realize = for_node->body.as<BlockRealizeNode>()) {
const auto& block = realize->block;
for (const auto& buffer : block->alloc_buffers) {
ICHECK(buffer->IsInstance<BufferNode>());
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
pipeline_body = block->body;
pipeline_allocs = block->alloc_buffers;
} else {
pipeline_body = for_node->body;
}
const SeqStmtNode* pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
CHECK(pipeline_body_seq)
<< "ValueError: The body of the software pipeline should be SeqStmt, got "
<< pipeline_body->GetTypeKey();
// Step 3: Blockize the components of the pipeline. Each child of the pipelined loop will be
// converted into a block.
PipelineInfo pipeline_info;
Array<Block> original_order; // pipeline body blocks in the original order
auto f_add_child = [&](const Stmt& child) {
original_order.push_back(MakeBlock(child, buffer_data_to_buffer_));
};
for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) {
const auto* nested_block_realize = pipeline_body_seq->seq[i].as<BlockRealizeNode>();
if (nested_block_realize && is_one(nested_block_realize->predicate) &&
nested_block_realize->block->body->IsInstance<SeqStmtNode>()) {
const Block& nested_pipeline_block = nested_block_realize->block;
ICHECK(
nested_pipeline_block->match_buffers.empty()); // match_buffer should have been lowered
for (const auto& buffer : nested_pipeline_block->alloc_buffers) {
pipeline_allocs.push_back(buffer);
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
const auto* nested_seq = nested_pipeline_block->body.as<SeqStmtNode>();
for (size_t j = 0; j < nested_seq->seq.size(); j++) {
f_add_child(nested_seq->seq[j]);
}
} else {
f_add_child(pipeline_body_seq->seq[i]);
}
}
auto pipeline_stages =
Downcast<Array<Integer>>(op->annotations.at(tir::attr::software_pipeline_stage));
auto pipeline_orders =
Downcast<Array<Integer>>(op->annotations.at(tir::attr::software_pipeline_order));
CHECK_EQ(pipeline_stages.size(), original_order.size())
<< "PrimFunc " << global_symbol_ << " has original order "
<< original_order.Map([](const auto& block) { return block->name_hint; })
<< ", but pipeline annotation is " << pipeline_stages << " with different size";
CHECK_EQ(pipeline_orders.size(), original_order.size())
<< "PrimFunc " << global_symbol_ << " has original order "
<< original_order.Map([](const auto& block) { return block->name_hint; })
<< ", but pipeline annotation is " << pipeline_orders << " with different size";
std::unordered_set<int> pipeline_async_stages;
if (auto annot = op->annotations.Get(tir::attr::software_pipeline_async_stages)) {
for (auto s : Downcast<Array<Integer>>(annot)) {
pipeline_async_stages.insert(s->value);
}
}
for (size_t i = 0; i < pipeline_stages.size(); i++) {
int stage = static_cast<int>(pipeline_stages[i]->value);
bool is_async = pipeline_async_stages.find(stage) != pipeline_async_stages.end();
PipelineAnnotation stage_order{stage,
/*order=*/static_cast<int>(pipeline_orders[i]->value),
is_async};
pipeline_info.emplace(original_order[i], stage_order);
}
ValidatePipelineBody(pipeline_info, original_order);
// Step 4: Rewrite the pipeline body.
Stmt pipeline =
PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs, GetRef<For>(op), pipeline_info)
.BuildPipeline();
if (const auto* realize = op->body.as<BlockRealizeNode>()) {
const auto& block = realize->block;
for (const auto& buffer : block->alloc_buffers) {
buffer_data_to_buffer_.erase(buffer->data);
}
}
return pipeline;
}
Stmt VisitStmt_(const BlockNode* op) final {
for (const auto& buffer : op->alloc_buffers) {
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
for (const auto& buffer : op->alloc_buffers) {
buffer_data_to_buffer_.erase(buffer->data);
}
return std::move(block);
}
bool HasPipelineAnnotation(const ForNode* op) const {
auto it1 = op->annotations.find(tir::attr::software_pipeline_stage);
auto it2 = op->annotations.find(tir::attr::software_pipeline_order);
bool has_stage = it1 != op->annotations.end();
bool has_order = it2 != op->annotations.end();
if (has_stage && has_order) {
return true;
}
if (has_stage) {
LOG(FATAL) << "ValueError: Order of the software pipeline is not defined.";
}
if (has_order) {
LOG(FATAL) << "ValueError: Stage of the software pipeline is not defined.";
}
return false;
}
Map<Var, Buffer> buffer_data_to_buffer_;
Optional<String> global_symbol_;
};
/*!
* \brief Transform annotated loops into pipelined one that parallelize producers and consumers.
* \return The IR transform pass.
*/
tir::transform::Pass InjectSoftwarePipeline() {
using namespace tir::transform;
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* fptr = f.CopyOnWrite();
fptr->body = PipelineInjector::Inject(f);
fptr->body = ConvertSSA(std::move(fptr->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.InjectSoftwarePipeline", {});
}
TVM_REGISTER_GLOBAL("tl.transform.InjectSoftwarePipeline")
.set_body_typed(InjectSoftwarePipeline);
} // namespace tl
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file layout_inference.cc
* \brief infer the fragment/shared memory layout
*/
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>
#include <queue>
#include "arith/ir_mutator_with_analyzer.h"
#include "../op/parallel.h"
#include "loop_partition.h"
#include "loop_vectorize.h"
#include "common/loop_fusion_utils.h"
namespace tvm {
namespace tl {
using namespace tir;
using arith::IRMutatorWithAnalyzer;
struct LayoutInferenceResult {
Map<Buffer, Layout> layout_map;
Map<For, Fragment> for_map;
Map<For, PrimExpr> predicate_map;
};
class BufferUseDefCollector : public StmtExprVisitor {
public:
BufferUseDefCollector() = default;
LayoutInferenceResult Run() {
Map<Buffer, Layout> layout_map = annotated_layout_map_;
int num_infer = infer_list_.size();
// maintain a bfs queue and infer common layout
std::queue<int> q;
std::vector<bool> in_queue(num_infer, true);
for (int i = 0; i < num_infer; i++) q.push(i);
auto run_infer_step = [&](int cur_infer_id, InferLevel level, bool update_queue) {
auto& next = infer_list_[cur_infer_id];
auto iter_var = thread_var_vec_[cur_infer_id];
auto updates = next->InferLayout(
LayoutInferArgs{target_, static_cast<size_t>(*as_const_int(iter_var->dom->extent)),
layout_map},
level);
for (const auto& [buffer, layout] : updates) {
if (layout_map.count(buffer)) {
ICHECK(StructuralEqual()(layout, layout_map[buffer]))
<< "Get different layout for " << buffer;
} else {
layout_map.Set(buffer, layout);
if (!update_queue) continue;
for (int idx : use_list_[buffer]) {
if (!in_queue[idx] && idx != cur_infer_id) {
in_queue[idx] = true;
q.push(idx);
}
}
}
}
};
auto finish_infer_queue = [&]() {
while (!q.empty()) {
int cur_infer_id = q.front();
q.pop();
in_queue[cur_infer_id] = false;
run_infer_step(cur_infer_id, InferLevel::kCommon, true);
}
};
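    // Illustrative propagation: once a fragment's layout is fixed for one operator, every other
    // operator recorded in use_list_ for that buffer is pushed back into the queue so its own
    // loop/fragment layouts can be re-derived from the now-known layout.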
// step 1, infer strict layout
for (int i = 0; i < num_infer; i++) {
run_infer_step(i, InferLevel::kStrict, false);
}
// step2, infer common layout with bfs
finish_infer_queue();
// step 3, relax the infer constraint to free and rerun.
for (int i = 0; i < num_infer; i++) {
run_infer_step(i, InferLevel::kFree, true);
finish_infer_queue();
}
// Check that all fragments have been inferred
for (const auto& [buffer, _] : use_list_) {
if (buffer.scope() == "local.fragment" && layout_map.count(buffer) == 0)
LOG_ERROR << "The layout for fragment " << buffer << " can not be inferred correctly.";
}
// Collect the layout for for nodes
Map<For, Fragment> for_map;
Map<For, PrimExpr> predicate_map;
for (auto& base_infer : infer_list_) {
if (auto for_infer = dynamic_cast<ParallelOp*>(base_infer.get())) {
ICHECK(for_infer->GetLoopLayout().defined())
<< "The Layout for Parallel for can not be inferred correctly : \n"
<< for_infer->GetRoot();
for_map.Set(for_infer->GetRoot(), for_infer->GetLoopLayout());
if (auto predicate = for_infer->GetPredicate(thread_var_->var))
predicate_map.Set(for_infer->GetRoot(), predicate.value());
}
}
return {layout_map, for_map, predicate_map};
}
void Collect(const PrimFunc& f) {
for (const auto& [_, buffer] : f->buffer_map) {
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target.defined()) << "Layout_Inference: Require the target attribute";
target_ = target.value();
this->operator()(f->body);
}
private:
void VisitExpr_(const CallNode* op) final {
StmtExprVisitor::VisitExpr_(op);
    // Do not analyze calls to global functions.
if (op->op.as<GlobalVarNode>()) return;
auto p = ParseOperator(GetRef<Call>(op), buffer_data_to_buffer_);
if (p != nullptr) {
for (const auto& arg : op->args) {
if (auto buffer = getBufferFromAccessPtr(arg)) {
addToUseList(buffer.value());
}
}
infer_list_.push_back(std::move(p));
thread_var_vec_.push_back(thread_var_);
}
}
Optional<Buffer> getBufferFromAccessPtr(const PrimExpr& expr) {
auto call = expr.as<CallNode>();
if (call && call->op.same_as(builtin::tvm_access_ptr())) {
auto var = call->args[1].as<Var>().value();
return buffer_data_to_buffer_[var];
}
return NullOpt;
}
void addToUseList(const Buffer& buffer) {
int infer_idx = infer_list_.size();
if (use_list_.find(buffer) == use_list_.end()) {
use_list_[buffer] = {};
}
use_list_[buffer].push_back(infer_idx);
}
void VisitStmt_(const ForNode* op) final {
if (op->kind == ForKind::kParallel) {
auto infer = std::make_unique<ParallelOp>(GetRef<For>(op));
for (const auto& [buffer, _] : infer->GetIndiceMap()) {
addToUseList(buffer);
}
infer_list_.push_back(std::move(infer));
thread_var_vec_.push_back(thread_var_);
} else {
StmtExprVisitor::VisitStmt(op->body);
}
}
void VisitStmt_(const BlockNode* op) final {
for (auto buffer : op->alloc_buffers) {
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
if (op->annotations.count(attr::kLayoutMap)) {
auto map = op->annotations.Get(attr::kLayoutMap).as<Map<Var, Layout>>().value();
for (const auto& [var, layout] : map) {
auto buffer = buffer_data_to_buffer_[var];
ICHECK(StructuralEqual()(layout->InputShape(), buffer->shape));
annotated_layout_map_.Set(buffer, layout);
}
}
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
ICHECK(iv->dom->extent.as<IntImmNode>());
thread_var_ = iv;
}
}
StmtExprVisitor::VisitStmt_(op);
}
Map<Var, Buffer> buffer_data_to_buffer_;
std::vector<std::unique_ptr<Operator>> infer_list_;
std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual> use_list_;
IterVar thread_var_;
std::vector<IterVar> thread_var_vec_;
Target target_;
LayoutMap annotated_layout_map_;
};
class LayoutInferencer : public IRMutatorWithAnalyzer {
public:
static PrimFunc Substitute(PrimFunc f) {
arith::Analyzer analyzer;
PrimFuncNode* fptr = f.CopyOnWrite();
fptr->body = ParallelLoopFuser::Fuse(f->body);
BufferUseDefCollector collector;
collector.Collect(f);
auto result = collector.Run();
LayoutInferencer substituter(result, &analyzer);
fptr->body = substituter.VisitStmt(f->body);
return f;
}
private:
LayoutInferencer(const LayoutInferenceResult result, arith::Analyzer* analyzer)
: arith::IRMutatorWithAnalyzer(analyzer), result_(result) {};
Stmt VisitStmt_(const BlockNode* op) final {
Block block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));
for (auto buffer : block->alloc_buffers) {
if (buffer.scope() == "local.framgent") {
ICHECK(result_.layout_map.count(buffer))
<< "Cannot inference fragment layout for " << buffer;
}
}
auto block_ptr = block.CopyOnWrite();
block_ptr->annotations.Set(attr::kLayoutMap, result_.layout_map);
return block;
}
Stmt VisitStmt_(const ForNode* op) final {
For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
if (result_.for_map.count(GetRef<For>(op))) {
auto loop_layout = result_.for_map[GetRef<For>(op)];
for_node = PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
for_node = VectorizeLoop(for_node);
if (result_.predicate_map.count(GetRef<For>(op))) {
return IfThenElse(result_.predicate_map[GetRef<For>(op)], for_node);
} else {
return for_node;
}
}
return for_node;
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
ICHECK_NE(iv->thread_tag.length(), 0U);
if (iv->thread_tag == "threadIdx.x") {
thread_var_ = iv;
}
}
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
private:
const LayoutInferenceResult result_;
IterVar thread_var_;
};
tvm::transform::Pass LayoutInference() {
using namespace tir::transform;
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return LayoutInferencer::Substitute(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {});
}
TVM_REGISTER_GLOBAL("tl.transform.LayoutInference")
.set_body_typed(LayoutInference);
} // namespace tl
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
 * \file legalize_safe_memory_access.cc
 * \brief Legalize potentially out-of-bound global memory accesses by guarding them
*/
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>
#include <queue>
#include "arith/ir_mutator_with_analyzer.h"
#include "../op/parallel.h"
#include "loop_partition.h"
#include "loop_vectorize.h"
namespace tvm {
namespace tl {
using namespace tir;
using arith::IRMutatorWithAnalyzer;
// Helper class to find leaf For nodes in a given IR
class LeafForFinder : public StmtVisitor {
public:
std::vector<For> leaf_for_nodes;
private:
  void VisitStmt_(const ForNode* op) final {
    // Check whether this loop's body contains another For.
    has_child_for_ = false;
    StmtVisitor::VisitStmt(op->body);
    if (!has_child_for_) {
      leaf_for_nodes.push_back(GetRef<For>(op));
    }
    // Signal to the enclosing loop (if any) that it contains a nested For,
    // so it is not reported as a leaf.
    has_child_for_ = true;
  }
  bool has_child_for_ = false;
};
// We will create a visitor to check BufferLoad and BufferStore nodes
// within this loop body. This visitor will:
// 1. Identify BufferLoad and BufferStore nodes.
// 2. Check if the buffer is in global scope.
// 3. For each index, compare against the buffer's shape.
// If the index might exceed the shape (upper bound too large),
// log a warning or handle accordingly.
struct GlobalMemChecker : public StmtExprVisitor {
arith::Analyzer* analyzer;
explicit GlobalMemChecker(arith::Analyzer* analyzer) : analyzer(analyzer) {}
void VisitExpr_(const BufferLoadNode* op) final {
// Check if the buffer is in global scope
if (IsGlobalBuffer(op->buffer)) {
CheckBufferIndices(op->buffer, op->indices, /*is_load=*/true);
}
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode* op) final {
// Check if the buffer is in global scope
if (IsGlobalBuffer(op->buffer)) {
CheckBufferIndices(op->buffer, op->indices, /*is_load=*/false);
}
StmtExprVisitor::VisitStmt_(op);
}
// Helper function to determine if a buffer is global
bool IsGlobalBuffer(const Buffer& buffer) {
    // In TVM IR the storage scope is carried by the buffer itself; global buffers have
    // scope "global", which Buffer::scope() returns directly.
String scope = buffer.scope();
return scope == "global";
}
// Check each index against the buffer shape dimensions
void CheckBufferIndices(const Buffer& buffer, const Array<PrimExpr>& indices, bool is_load) {
// Ensure indices count matches buffer dimension
if (indices.size() != buffer->shape.size()) {
LOG(WARNING) << "Buffer access dimension mismatch: indices size (" << indices.size()
<< ") vs. shape size (" << buffer->shape.size() << ")";
return;
}
for (size_t i = 0; i < indices.size(); i++) {
PrimExpr index = indices[i];
PrimExpr shape_dim = buffer->shape[i];
// We want to check if index < shape_dim can be proven.
// If analyzer->CanProve(index < shape_dim) returns false,
// it means we cannot prove the access is within bounds.
PrimExpr cond = index < shape_dim;
if (!analyzer->CanProve(cond)) {
_conditions.push_back(cond);
}
}
}
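  // Illustrative example: for a global buffer A of shape [M, N] accessed as A[i, j], if the
  // analyzer cannot prove i < M, the condition `i < M` is recorded and later used by the
  // rewriter to guard the access.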
Array<PrimExpr> GetConditions() { return _conditions; }
private:
Array<PrimExpr> _conditions;
};
class SafeMemoryRewriter : public StmtExprMutator {
arith::Analyzer* analyzer_;
public:
  explicit SafeMemoryRewriter(arith::Analyzer* analyzer) : analyzer_(analyzer) {}
private:
Stmt VisitStmt_(const BufferStoreNode* op) final {
// Check if the buffer is in global scope
auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
GlobalMemChecker checker(analyzer_);
checker(store);
Array<PrimExpr> conditions = checker.GetConditions();
if (conditions.size() == 0) {
return store;
}
    if (IsGlobalBuffer(store->buffer)) {
Stmt store_with_conditions = store;
for (auto cond : conditions) {
store_with_conditions = IfThenElse(cond, store_with_conditions);
}
return store_with_conditions;
} else if (isSharedBuffer(store->buffer)) {
PrimExpr value = store->value;
for (auto cond : conditions) {
value = if_then_else(cond, value, make_zero(value->dtype));
}
store.CopyOnWrite()->value = value;
return store;
}
return store;
}
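  // Illustrative effect for a single unprovable condition `i < M` (a sketch):
  //   global destination:  C[i] = v       ==>   if (i < M) { C[i] = v }
  //   shared destination:  S[j] = A[i]    ==>   S[j] = i < M ? A[i] : 0
  // so out-of-bound global accesses are either skipped or replaced by zero.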
// Handle Call Nodes
// For example
// T.call_extern("handle", "atomicAddx2", T.address_of(C), T.address_of(C_shared))
Stmt VisitStmt_(const EvaluateNode* op) final {
auto evaluate = Downcast<Evaluate>(StmtExprMutator::VisitStmt_(op));
    // Downcast would fail on non-call values (e.g. Evaluate(0)), so use a checked cast instead.
    const CallNode* call = evaluate->value.as<CallNode>();
    if (call != nullptr && call->op.same_as(builtin::call_extern())) {
      GlobalMemChecker checker(analyzer_);
      checker(evaluate->value);
Array<PrimExpr> conditions = checker.GetConditions();
if (conditions.size() == 0) {
return evaluate;
}
Stmt evaluate_with_conditions = evaluate;
for (auto cond : conditions) {
evaluate_with_conditions = IfThenElse(cond, evaluate_with_conditions);
}
return evaluate_with_conditions;
}
return evaluate;
}
bool isSharedBuffer(const Buffer& buffer) {
String scope = buffer.scope();
return scope == "shared" || scope == "shared.dyn";
}
bool IsGlobalBuffer(const Buffer& buffer) {
String scope = buffer.scope();
return scope == "global";
}
};
// Class to legalize safe memory access by transforming them appropriately
class SafeMemoryLegalizer : IRMutatorWithAnalyzer {
public:
// Static method to substitute and transform the given PrimFunc
static PrimFunc Substitute(PrimFunc f) {
arith::Analyzer analyzer;
// Create an instance of the legalizer with the analyzer
SafeMemoryLegalizer substituter(&analyzer);
// Get a mutable copy of the function node
PrimFuncNode* fptr = f.CopyOnWrite();
// Apply the legalizer to the function body
fptr->body = substituter.VisitStmt(f->body);
return f;
}
private:
// Constructor initializing the base class with the analyzer
SafeMemoryLegalizer(arith::Analyzer* analyzer) : arith::IRMutatorWithAnalyzer(analyzer) {}
// Override the VisitStmt_ method to handle ForNode (loop statements)
Stmt VisitStmt_(const ForNode* op) final {
// Visit and potentially modify the loop node
For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
auto has_inner_loop = HasInnerLoop(for_node->body);
if (!has_inner_loop) {
      SafeMemoryRewriter rewriter(analyzer_);
for_node.CopyOnWrite()->body = rewriter(for_node->body);
return std::move(for_node);
}
// Visit a For Node
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
static bool HasInnerLoop(const Stmt& stmt) {
LeafForFinder finder;
finder(stmt);
return finder.leaf_for_nodes.size() > 0;
}
};
// Create a pass that guards potentially out-of-bound memory accesses in the IRModule
tvm::transform::Pass LegalizeSafeMemoryAccess() {
using namespace tir::transform;
// Define the transformation function to be applied
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return SafeMemoryLegalizer::Substitute(std::move(f));
};
// Create and return a PrimFunc pass with the transformation function
return CreatePrimFuncPass(pass_func, 0, "tl.LegalizeSafeMemoryAccess", {});
}
// Register the pass globally so it can be used in the compilation pipeline
TVM_REGISTER_GLOBAL("tl.transform.LegalizeSafeMemoryAccess")
.set_body_typed(LegalizeSafeMemoryAccess);
} // namespace tl
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
 * \file legalize_vectorized_loop.cc
 * \brief Legalize frontend-annotated vectorized loops
*/
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>
#include <queue>
#include "arith/ir_mutator_with_analyzer.h"
#include "../op/parallel.h"
#include "loop_partition.h"
#include "loop_vectorize.h"
namespace tvm {
namespace tl {
using namespace tir;
using arith::IRMutatorWithAnalyzer;
// Class to legalize vectorized loops by transforming them appropriately
class LoopVectorizedLegalizer : IRMutatorWithAnalyzer {
public:
// Static method to substitute and transform the given PrimFunc
static PrimFunc Substitute(PrimFunc f) {
arith::Analyzer analyzer;
// Create an instance of the legalizer with the analyzer
LoopVectorizedLegalizer substituter(&analyzer);
// Get a mutable copy of the function node
PrimFuncNode* fptr = f.CopyOnWrite();
// Apply the legalizer to the function body
fptr->body = substituter.VisitStmt(f->body);
return f;
}
private:
// Constructor initializing the base class with the analyzer
LoopVectorizedLegalizer(arith::Analyzer* analyzer) : arith::IRMutatorWithAnalyzer(analyzer) {}
// Override the VisitStmt_ method to handle ForNode (loop statements)
Stmt VisitStmt_(const ForNode* op) final {
// Visit and potentially modify the loop node
For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
// If the loop is not vectorized, keep the already-visited node unchanged
if (for_node->kind != ForKind::kVectorized) {
return std::move(for_node);
}
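// Treat the frontend's kVectorized annotation as a request rather than a
// guarantee: demote the loop to serial and let VectorizeLoop re-plan a width
// that is actually legal for the accesses inside (it may split the loop or
// fall back to width 1).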
// Change the loop kind from vectorized to serial
for_node.CopyOnWrite()->kind = ForKind::kSerial;
// Apply vectorization transformation to the loop
return VectorizeLoop(std::move(for_node));
}
};
// Create a pass that legalizes vectorized loops in the IRModule
tvm::transform::Pass LegalizeVectorizedLoop() {
using namespace tir::transform;
// Define the transformation function to be applied
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return LoopVectorizedLegalizer::Substitute(std::move(f));
};
// Create and return a PrimFunc pass with the transformation function
return CreatePrimFuncPass(pass_func, 0, "tl.LegalizeVectorizedLoop", {});
}
// Register the pass globally so it can be used in the compilation pipeline
TVM_REGISTER_GLOBAL("tl.transform.LegalizeVectorizedLoop")
.set_body_typed(LegalizeVectorizedLoop);
} // namespace tl
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file loop_partition.cc
* \brief Partition parallel loops onto threads
*/
#include "loop_partition.h"
#include <tvm/tir/stmt_functor.h>
namespace tvm {
namespace tl {
using namespace tir;
class BufferIndiceSimplify : public StmtExprMutator {
public:
BufferIndiceSimplify(arith::Analyzer* analyzer) : analyzer_(analyzer) {}
private:
PrimExpr VisitExpr_(const BufferLoadNode* node) final {
auto visited = StmtExprMutator::VisitExpr_(node);
auto n = visited.as<BufferLoad>().value();
auto nptr = n.CopyOnWrite();
nptr->indices = nptr->indices.Map([&](const auto& e) { return analyzer_->Simplify(e); });
return n;
}
Stmt VisitStmt_(const BufferStoreNode* node) final {
auto visited = StmtExprMutator::VisitStmt_(node);
auto n = visited.as<BufferStore>().value();
auto nptr = n.CopyOnWrite();
nptr->indices = nptr->indices.Map([&](const auto& e) { return analyzer_->Simplify(e); });
return n;
}
arith::Analyzer* analyzer_;
};
// Rewrite a parallel loop nest into serial loops whose iterations are mapped onto threads
// according to loop_layout
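// Illustrative sketch (hypothetical shapes): for a 2-D parallel loop over
// (i, j) and a Fragment whose forward map sends (i, j) to (i0, thread), the
// inverse layout reconstructs (i, j) from (i0, thread); the old loop vars are
// substituted accordingly and only a serial loop over i0 is re-emitted, while
// the thread index stays bound to thread_var.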
For PartitionLoop(For op, Var thread_var, arith::Analyzer* analyzer, Fragment loop_layout) {
ICHECK(loop_layout.defined());
ICHECK(thread_var.defined());
int old_loop_depth = loop_layout->InputDim();
int new_loop_depth = loop_layout->OutputDim();
// Create the new loop iter var
Array<Var> vars;
for (int i = 0; i < new_loop_depth; i++) {
Var var = Var(std::string{char('i' + i)});
vars.push_back(var);
}
vars.push_back(thread_var);
// Create the substitution map and peel the old loop nest down to its body
Map<Var, PrimExpr> vmap;
Stmt body = op;
auto inv_loop = loop_layout->Inverse();
auto indices = inv_loop->Forward(vars.Map([](const Var& v) { return PrimExpr(v); }));
for (int i = 0; i < old_loop_depth; i++) {
ICHECK(body.as<For>().defined());
For loop = body.as<For>().value();
vmap.Set(loop->loop_var, indices[i]);
body = loop->body;
}
// substitute and re-construct the serial loop
body = Substitute(body, vmap);
for (int i = new_loop_depth - 1; i >= 0; i--) {
body =
For(vars[i], make_zero(vars[i]->dtype), inv_loop->InputShape()[i], ForKind::kSerial, body);
analyzer->Bind(vars[i], Range(0, inv_loop->InputShape()[i]));
}
body = BufferIndiceSimplify(analyzer)(body);
auto for_node = LoopPragmaUnroll(Downcast<For>(body));
return for_node;
}
class LoopPragmaUnroller : public StmtExprMutator {
public:
LoopPragmaUnroller() = default;
private:
Stmt VisitStmt_(const ForNode* node) final {
if (node->kind == ForKind::kSerial) {
For new_for = GetRef<For>(node);
auto for_ptr = new_for.CopyOnWrite();
for_ptr->annotations.Set(tir::attr::pragma_unroll_explicit, Bool(false));
for_ptr->kind = ForKind::kUnrolled;
return new_for;
}
return StmtExprMutator::VisitStmt_(node);
}
};
class LoopPartitioner : public StmtExprVisitor {
public:
LoopPartitioner() = default;
Fragment Partition(For op, int num_thread, int vectorize_size) {
this->VisitStmt(op);
int loop_size_full = 1;
PrimExpr flattened = 0;
for (size_t i = 0; i < loop_vars_.size(); i++) {
auto ext_ptr = as_const_int(loop_vars_[i]->dom->extent);
ICHECK(ext_ptr);
int extent = *ext_ptr;
loop_size_full *= extent;
flattened = flattened * extent + loop_vars_[i]->var;
}
ICHECK(loop_size_full % vectorize_size == 0);
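// Thread mapping sketch: with vectorize_size = v and num_thread = T, the
// flattened index is grouped into v-wide accesses; access k is assigned to
// thread k % T, and within a thread consecutive accesses stay contiguous.
// E.g. for a flattened extent of 1024, v = 4 and T = 128, each thread owns
// 256 / 128 = 2 accesses, i.e. 8 elements.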
PrimExpr access_idx = FloorDiv(flattened, vectorize_size);
PrimExpr thd = FloorMod(access_idx, num_thread);
PrimExpr idx =
FloorDiv(access_idx, num_thread) * vectorize_size + FloorMod(flattened, vectorize_size);
return Fragment(loop_vars_, {idx}, {thd}, {});
}
private:
void VisitStmt_(const ForNode* node) final {
if (node->kind == ForKind::kParallel) {
body_ = node->body;
loop_vars_.push_back(IterVar(Range::FromMinExtent(node->min, node->extent), node->loop_var,
IterVarType::kDataPar));
}
StmtExprVisitor::VisitStmt_(node);
}
Stmt body_;
Array<IterVar> loop_vars_;
};
Fragment PlanLoopPartition(For op, size_t num_thread, int vectorize_size) {
LoopPartitioner partitioner;
return partitioner.Partition(op, num_thread, vectorize_size);
}
For LoopPragmaUnroll(For stmt) {
LoopPragmaUnroller unroller;
For unrolled = Downcast<For>(unroller(stmt));
return unrolled;
}
} // namespace tl
} // namespace tvm
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file loop_partition.h
* \brief Partition parallel loops onto threads
*/
#ifndef TVM_TL_LOOP_PARTITION_H_
#define TVM_TL_LOOP_PARTITION_H_
#include <tvm/tir/op.h>
#include "../layout/layout.h"
namespace tvm {
namespace tl {
using namespace tir;
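// Brief summaries (from loop_partition.cc): PartitionLoop rewrites a parallel
// loop nest into serial loops plus a thread binding described by loop_layout;
// PlanLoopPartition derives such a Fragment for num_thread threads with
// vectorize_size-wide accesses; LoopPragmaUnroll marks serial loops for
// explicit unrolling.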
For PartitionLoop(For op, Var thread_var, arith::Analyzer* analyzer, Fragment loop_layout);
Fragment PlanLoopPartition(For op, size_t num_thread, int vectorize_size);
For LoopPragmaUnroll(For stmt);
} // namespace tl
} // namespace tvm
#endif // TVM_TL_LOOP_PARTITION_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file loop_vectorize.cc
* \brief A tool to automatically vectorize a for loop
*/
#include "loop_vectorize.h"
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <numeric>
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "../layout/layout.h"
#include "../layout/utils.h"
#include "common/loop_vectorization_utils.h"
namespace tvm {
namespace tl {
using namespace tir;
struct VectorizePlanResult {
int vector_size;
bool dynamic;
PrimExpr condition;
};
class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
public:
VectorizePlanner() = default;
int Plan(const For& node) {
this->operator()(node);
// Vectorization is currently always enabled; the non-local memory access
// check below is intentionally disabled.
// if (!has_nonlocal_memory_access_) return 1;
return vector_size_;
}
bool GetDynamic() { return dynamic_; }
PrimExpr GetCondition() { return condition_; }
private:
void VisitStmt_(const ForNode* node) final {
inner_for_ = node;
iter_map_.Set(node->loop_var, Range(node->min, node->extent));
arith::IRVisitorWithAnalyzer::VisitStmt_(node);
}
void VisitExpr_(const BufferLoadNode* node) final {
if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
node->buffer.scope() == "shared.dyn")
has_nonlocal_memory_access_ = true;
// Skip single-element buffers: tl uses such constant buffers as a stand-in
// for a local register (TODO(lei): improve this representation).
if (node->buffer->shape.size() == 1) {
const auto* single_extent = node->buffer->shape[0].as<IntImmNode>();
if (single_extent && single_extent->value == 1) {
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
}
}
UpdateVectorSize(node->indices, node->buffer);
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
}
void VisitStmt_(const BufferStoreNode* node) final {
if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
node->buffer.scope() == "shared.dyn")
has_nonlocal_memory_access_ = true;
UpdateVectorSize(node->indices, node->buffer);
return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
}
void VisitStmt_(const IfThenElseNode* node) final {
CheckConditionVectorized(node->condition);
return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
}
void VisitExpr_(const CallNode* node) final {
if (node->op == builtin::if_then_else()) {
CheckConditionVectorized(node->args[0]);
} else if (node->op == builtin::call_extern()) {
// do not vectorize extern calls
vector_size_ = 1;
}
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
}
void CheckConditionVectorized(const PrimExpr& cond) {
// TODO: perform some checks here
}
void UpdateVectorSize(const Array<PrimExpr>& indices, const Buffer& buffer) {
if (!inner_for_) return;
auto extent_ptr = inner_for_->extent.as<IntImmNode>();
if (!extent_ptr) return;
const DataType& access_type = buffer->dtype;
// NOTE: the GCD-based bound below is conservative; index patterns such as
// i // 2 or i % 8 could still be vectorized with a larger factor (e.g. 16),
// so this restriction may eventually be lifted.
int max_vector_size = vector_load_bits_max_ / access_type.bits();
max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
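// Example: for a float16 buffer (16-bit elements) the widest 128-bit load
// covers 128 / 16 = 8 elements, and with a loop extent of 12 the GCD limits
// the candidate width to gcd(8, 12) = 4.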
auto last_dim = buffer->shape.back();
auto mod_set = analyzer_.modular_set(last_dim);
// For a dynamic last dimension such as [m, k], the modular set degenerates to
// coeff = 1, base = 0, so the GCD analysis would block vectorization entirely;
// that case is instead handled in the else branch below with a runtime
// alignment condition (conditional tail vectorization).
if (buffer->shape.back().as<IntImmNode>()) {
max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base);
// If gcd_base is equal to the last dimension,
// we should analyze the second-to-last dimension
// in relation to the last dimension.
if (gcd_base < Downcast<IntImm>(last_dim)->value) {
max_vector_size = gcd_base;
}
vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);
PrimExpr elem_offset = 0;
PrimExpr stride = 1;
for (int i = indices.size() - 1; i >= 0; --i) {
elem_offset = elem_offset + indices[i] * stride;
stride = stride * buffer->shape[i];
}
while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var, inner_for_->extent,
vector_size_, &analyzer_)) {
vector_size_ /= 2;
}
} else if (vector_size_ <= vector_load_bits_max_ / buffer->dtype.bits()) {
// dynamic shape load: get the vectorization condition
dynamic_ = true;
PrimExpr offset = buffer.OffsetOf(indices).back();
condition_ = (FloorMod(offset, vector_size_) == 0);
}
}
static const int vector_load_bits_max_ = 128;
const ForNode* inner_for_ = nullptr;
Map<Var, Range> iter_map_;
bool has_nonlocal_memory_access_ = false;
int vector_size_ = 128;
// conditionally vectorize
bool dynamic_ = false;
PrimExpr condition_;
};
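// Used by VectorizeRewriter on the dynamic-shape path: inside the vectorized
// inner loop the predicate of tir.if_then_else is rewritten by substituting
// inner_var = 0, i.e. the guard is evaluated once for the whole vector,
// presumably relying on the alignment IfThenElse added around the loop.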
class VectorizeDynamicCallRemover : public StmtExprMutator {
public:
VectorizeDynamicCallRemover(Var inner_var, int vector_size)
: inner_var_(inner_var), vector_size_(vector_size) {}
private:
PrimExpr VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::if_then_else())) {
PrimExpr cond = this->VisitExpr(op->args[0]);
Map<Var, PrimExpr> vmap;
// Evaluate the guard at lane 0 (inner_var = 0); the per-lane upper-bound
// check is currently dropped.
vmap.Set(inner_var_, 0);
cond = Substitute(cond, vmap);
Array<PrimExpr> new_args{cond, op->args[1], op->args[2]};
return Call(op->dtype, op->op, new_args, op->span);
} else {
// TODO: For other calls
return GetRef<PrimExpr>(op);
}
}
Var inner_var_;
int vector_size_;
};
class VectorizeRewriter : public StmtExprMutator {
public:
VectorizeRewriter(VectorizePlanResult plan)
: vector_size_(plan.vector_size), condition_(plan.condition), dynamic_(plan.dynamic) {}
private:
Stmt VisitStmt_(const ForNode* node) final {
inner_for_ = node;
auto ret = StmtExprMutator::VisitStmt_(node);
if (inner_for_ == node) { // rewrite the innermost loop
For fnode = ret.as<For>().value();
auto old_var = fnode->loop_var;
auto extent_ptr = as_const_int(fnode->extent);
ICHECK(extent_ptr) << fnode->extent;
int extent = *extent_ptr;
ICHECK(extent % vector_size_ == 0)
<< "extent: " << extent << " vector_size_: " << vector_size_;
ICHECK(is_zero(fnode->min));
if (!dynamic_) { // check dynamic shape
if (extent == vector_size_) {
fnode.CopyOnWrite()->kind = ForKind::kVectorized;
return fnode;
} else {
Var inner_var = Var("vec");
Var outer_var = Var(old_var->name_hint);
Map<Var, PrimExpr> vmap;
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
Stmt body = Substitute(fnode->body, vmap);
body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body, fnode->thread_binding,
fnode->annotations, fnode->span);
return body;
}
} else {
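// Dynamic-shape path: split into an outer loop of extent / vector_size_ and an
// inner loop of vector_size_, then branch on the planner's alignment
// condition: the vectorized inner loop is taken only when the access offset is
// divisible by vector_size_, otherwise the same body runs in a serial inner
// loop.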
Var inner_var = Var("vec");
Var outer_var = Var(old_var->name_hint);
Map<Var, PrimExpr> vmap;
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
Stmt body = Substitute(fnode->body, vmap);
// add condition ifthenelse here
Map<Var, PrimExpr> vmap_condition;
vmap_condition.Set(fnode->loop_var, outer_var * vector_size_);
PrimExpr condition = Substitute(condition_, vmap_condition);
VectorizeDynamicCallRemover remover(inner_var, vector_size_);
body = remover(body);
For vectorize_for = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
For serial_for = For(inner_var, 0, vector_size_, ForKind::kSerial, body);
body = IfThenElse(condition, vectorize_for, serial_for);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body, fnode->thread_binding,
fnode->annotations, fnode->span);
return body;
}
} else {
return ret;
}
}
const ForNode* inner_for_ = nullptr;
const int vector_size_;
const PrimExpr condition_;
const bool dynamic_;
};
int GetVectorizeSize(const For& loop) { return VectorizePlanner().Plan(loop); }
VectorizePlanResult GetVectorizePlanResult(const For& loop) {
VectorizePlanner planner;
int vector_size = planner.Plan(loop);
bool dynamic = planner.GetDynamic();
PrimExpr condition = planner.GetCondition();
return {vector_size, dynamic, condition};
}
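// An index expression can be vectorized by target_vectorized_size iff, after
// substituting var = v0 + v1 * size (v0 being the lane index), the vectorized
// form is either a scalar broadcast or a Ramp with stride 1 (contiguous
// lanes). E.g. expr = var + c passes for any size dividing the extent, while
// expr = 4 * var + c only passes for size 1.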
bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size, int target_vectorized_size,
arith::Analyzer* analyzer) {
ICHECK(target_vectorized_size >= 1);
if (target_vectorized_size == 1) return true;
if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_vectorized_size), 0)) return false;
Var v0("v0"), v1("v1");
analyzer->Bind(v0, Range(0, target_vectorized_size));
analyzer->Bind(v1, Range(0, FloorDiv(iter_var_size, target_vectorized_size)));
PrimExpr expr_transformed =
analyzer->Simplify(Substitute(expr, {{var, v0 + v1 * target_vectorized_size}}));
Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size));
PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed);
auto ramp_node = expr_vectorized.as<RampNode>();
if (!ramp_node) {
// Broadcast value
if (expr_vectorized.dtype().lanes() == 1)
return true;
else
return false;
} else {
return is_one(ramp_node->stride);
}
}
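// Entry point used by the passes above: a non-positive vectorize_hint (the
// header default is -1) triggers automatic planning via VectorizePlanner,
// while a resulting width of 1 leaves the loop untouched.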
For VectorizeLoop(const For& loop, int vectorize_hint) {
VectorizePlanResult res{128, false, 0};
if (vectorize_hint <= 0) {
res = GetVectorizePlanResult(loop);
vectorize_hint = res.vector_size;
}
if (vectorize_hint == 1) return loop;
auto rewriter = VectorizeRewriter(res);
return Downcast<For>(rewriter(loop));
}
} // namespace tl
} // namespace tvm
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file loop_vectorize.h
* \brief A tool to automatically vectorize a for loop
*/
#ifndef TVM_TL_LOOP_VECTORIZE_H_
#define TVM_TL_LOOP_VECTORIZE_H_
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
namespace tvm {
namespace tl {
using namespace tir;
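// Brief summaries (from loop_vectorize.cc): GetVectorizeSize plans the largest
// legal vector width for a loop; VectorizeLoop rewrites the loop so its
// innermost iterations become a kVectorized loop (a non-positive hint requests
// automatic planning); IndiceCanVectorize checks that expr, as a function of
// var, stays contiguous across target_vectorized_size consecutive iterations.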
int GetVectorizeSize(const For& loop);
For VectorizeLoop(const For& loop, int vectorize_hint = -1);
bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size, int target_vectorized_size,
arith::Analyzer* analyzer);
} // namespace tl
} // namespace tvm
#endif // TVM_TL_LOOP_VECTORIZE_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file lower_hopper_intrin.cc
* \brief Lower Hopper (sm90+) CUDA intrinsics
*/
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
#include "../op/bulk_copy.h"
#include "../runtime/runtime.h"
namespace tvm {
namespace tl {
using namespace tir;
class LowerHopperIntrin : public StmtExprMutator {
public:
static PrimFunc Substitute(PrimFunc& f) {
PrimFuncNode* fptr = f.CopyOnWrite();
LowerHopperIntrin substituter;
fptr->body = substituter.VisitStmt(f->body);
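// Each TMA descriptor creation call collected below is hoisted out of the
// kernel body: stack storage is allocated for the descriptor, a packed
// runtime call fills it in, and the descriptor variable is bound around the
// body with a LetStmt.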
for (auto [call, var] : substituter.desc_map_) {
// Allocate stack storage for the TensorMap descriptor: 16 TVMValue slots,
// i.e. 128 bytes.
Call alloc_desc =
Call(DataType::Handle(), builtin::tvm_stack_alloca(), {StringImm("arg_value"), 16});
Array<PrimExpr> init_desc_args;
if (call->op.same_as(CreateTMADescriptorOp())) {
init_desc_args.push_back(StringImm(tvm_tensormap_create_tiled));
} else if (call->op.same_as(CreateTMAIm2ColDescriptorOp())) {
init_desc_args.push_back(StringImm(tvm_tensormap_create_im2col));
} else {
CHECK(0) << call->op;
}
init_desc_args.push_back(var);
init_desc_args.insert(init_desc_args.end(), call->args.begin(), call->args.end());
Call init_desc = Call(DataType::Handle(), builtin::tvm_call_packed(), init_desc_args);
fptr->body = LetStmt(var, alloc_desc, SeqStmt({Evaluate(init_desc), fptr->body}));
}
return f;
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
// Insert the TMA-descriptor prefetch and mbarrier initialization at the beginning of the kernel
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
auto body = StmtExprMutator::VisitStmt(op->body);
if (prefetch_calls_.empty() && init_mbarrier_calls_.empty()) {
return AttrStmt(op->node, op->attr_key, op->value, body);
} else {
Array<Stmt> stmt_seq;
if (!init_mbarrier_calls_.empty()) {
auto alloc_mbarrier = Evaluate(Call(DataType::Handle(), builtin::create_barriers(),
{static_cast<int>(init_mbarrier_calls_.size())}));
stmt_seq.push_back(alloc_mbarrier);
}
auto stmts = prefetch_calls_;
stmts.insert(stmts.end(), init_mbarrier_calls_.begin(), init_mbarrier_calls_.end());
auto init_stmt = IfThenElse(EQ(iv->var, 0), stmts.size() > 1 ? SeqStmt(stmts) : stmts[0]);
stmt_seq.push_back(init_stmt);
if (!init_mbarrier_calls_.empty()) {
Stmt mem_sync = Evaluate(
Call(DataType::Handle(), builtin::tvm_storage_sync(), {StringImm("shared")}));
stmt_seq.push_back(mem_sync);
}
stmt_seq.push_back(body);
prefetch_calls_.clear();
init_mbarrier_calls_.clear();
return AttrStmt(op->node, op->attr_key, op->value, SeqStmt(stmt_seq));
}
}
}
return StmtExprMutator::VisitStmt_(op);
}
PrimExpr VisitExpr_(const CallNode* call) final {
if (call->op.same_as(CreateTMADescriptorOp()) ||
call->op.same_as(CreateTMAIm2ColDescriptorOp())) {
Var var;
auto iter = desc_map_.find(GetRef<Call>(call));
if (iter != desc_map_.end()) {
var = iter->second;
} else {
String name = call->args[2].as<Var>().value()->name_hint;
var = Var(name + "_desc", PointerType(PrimType(cuTensorMapType()), "grid_constant"));
desc_map_[GetRef<Call>(call)] = var;
prefetch_calls_.push_back(Evaluate(Call(DataType::Handle(), builtin::call_extern(),
{StringImm("tl::prefetch_tma_descriptor"), var})));
}
return var;
} else if (call->op.same_as(CreateListofMBarrierOp())) {
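// Replace the barrier-list creation with a placeholder value and record one
// ptx_init_barrier_thread_count call per barrier (call->args[i] gives the
// expected arrival count of barrier i); the recorded calls are emitted once at
// kernel entry by the thread_extent handler above.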
ICHECK(init_mbarrier_calls_.empty());
int num_barriers = static_cast<int>(call->args.size());
for (int i = 0; i < num_barriers; i++) {
PrimExpr mbarrier = Call(DataType::Handle(), GetMBarrierOp(), {i});
init_mbarrier_calls_.push_back(
Evaluate(Call(DataType::Handle(), builtin::ptx_init_barrier_thread_count(),
{mbarrier, call->args[i]})));
}
return 0;
} else if (call->op.same_as(SyncThreadsPartialOp())) {
int barrier_id = static_cast<int>(init_mbarrier_calls_.size());
PrimExpr mbarrier = Call(DataType::Handle(), GetMBarrierOp(), {barrier_id});
init_mbarrier_calls_.push_back(
Evaluate(Call(DataType::Handle(), builtin::ptx_init_barrier_thread_count(),
{mbarrier, call->args[0]})));
return Call(DataType::Handle(), SyncThreadsPartialOp(), {mbarrier});
} else {
return StmtExprMutator::VisitExpr_(call);
}
}
private:
Array<Stmt> prefetch_calls_;
Array<Stmt> init_mbarrier_calls_;
std::unordered_map<Call, Var, StructuralHash, ExprDeepEqual> desc_map_;
LowerHopperIntrin() = default;
};
using namespace tir::transform;
tvm::transform::Pass LowerHopperIntrin() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return LowerHopperIntrin::Substitute(f);
};
return CreatePrimFuncPass(pass_func, 0, "tl.LowerHopperIntrin", {});
}
TVM_REGISTER_GLOBAL("tl.transform.LowerHopperIntrin")
.set_body_typed(LowerHopperIntrin);
} // namespace tl
} // namespace tvm