Commit 01a10755 authored by yuguo-Jack's avatar yuguo-Jack
Browse files

2.5.2-dtk24.04

parent 63eb0da5
......@@ -85,7 +85,11 @@
#
# go_library(example SHARED)
#
# To build a unit test binary, which is an executable binary with libpaddle.so
# automatically linked:
#
# paddle_test(example SRCS example_test.cc)
#
# including binary directory for generated headers.
include_directories(${CMAKE_CURRENT_BINARY_DIR})
# including io directory for inference lib paddle_api.h
......@@ -118,6 +122,19 @@ function(find_fluid_modules TARGET_NAME)
endif()
endfunction()
# NOTE(Aurelius84): NOT_INFER_MODULES is used to tag
# and not considered as DEPS for inference libs.
set_property(GLOBAL PROPERTY NOT_INFER_MODULES "")
# Tag TARGET_NAME as "not for inference": append it (deduplicated) to the
# global NOT_INFER_MODULES property, which marks modules that must not be
# considered as DEPS for inference libraries.
function(ignore_infer_modules TARGET_NAME)
  get_property(not_infer_modules GLOBAL PROPERTY NOT_INFER_MODULES)
  # Fix: search for the *value* of TARGET_NAME. The previous bare
  # `TARGET_NAME` searched for that literal string, so the dedup check
  # always failed and every call appended the target again.
  list(FIND not_infer_modules ${TARGET_NAME} is_found)
  if(is_found EQUAL -1) # NOT FOUND
    list(APPEND not_infer_modules ${TARGET_NAME})
    set_property(GLOBAL PROPERTY NOT_INFER_MODULES "${not_infer_modules}")
  endif()
endfunction()
set_property(GLOBAL PROPERTY PHI_MODULES "")
# find all phi modules is used for paddle static library
# for building inference libs
......@@ -335,7 +352,15 @@ function(check_coverage_opt TARGET_NAME SRCS)
endfunction()
function(cc_library TARGET_NAME)
set(options STATIC static SHARED shared INTERFACE interface)
set(options
STATIC
static
SHARED
shared
INTERFACE
interface
NOT_FOR_INFER
not_for_infer)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(cc_library "${options}" "${oneValueArgs}"
......@@ -347,6 +372,9 @@ function(cc_library TARGET_NAME)
CACHE STRING "output library name for target ${TARGET_NAME}")
endif()
if(cc_library_SRCS)
if(cc_library_NOT_FOR_INFER OR cc_library_not_for_infer)
ignore_infer_modules(${TARGET_NAME})
endif()
if(cc_library_SHARED OR cc_library_shared) # build *.so
add_library(${TARGET_NAME} SHARED ${cc_library_SRCS})
elseif(cc_library_INTERFACE OR cc_library_interface)
......@@ -442,6 +470,7 @@ function(cc_test_build TARGET_NAME)
list(REMOVE_ITEM cc_test_DEPS python)
target_link_libraries(${TARGET_NAME} ${PYTHON_LIBRARIES})
endif()
target_compile_definitions(${TARGET_NAME} PUBLIC STATIC_PADDLE)
endif()
get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
target_link_libraries(${TARGET_NAME} ${cc_test_DEPS}
......@@ -469,12 +498,15 @@ function(cc_test_run TARGET_NAME)
NAME ${TARGET_NAME}
COMMAND ${cc_test_COMMAND} ${cc_test_ARGS}
WORKING_DIRECTORY ${cc_test_DIR})
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT
FLAGS_cpu_deterministic=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT
FLAGS_init_allocated_mem=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT
FLAGS_cudnn_deterministic=true)
set_property(
TEST ${TARGET_NAME}
PROPERTY
ENVIRONMENT
FLAGS_cpu_deterministic=true
FLAGS_init_allocated_mem=true
FLAGS_cudnn_deterministic=true
"LD_LIBRARY_PATH=$ENV{LD_LIBRARY_PATH}:${PADDLE_BINARY_DIR}/python/paddle/libs:${PADDLE_BINARY_DIR}/python/paddle/base"
)
# No unit test should exceed 2 minutes.
if(WIN32)
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 150)
......@@ -554,6 +586,62 @@ function(cc_test_old TARGET_NAME)
endif()
endfunction()
# Build (without registering) a unit-test executable that links against the
# unified libpaddle shared library, plus optional shared phi/pir/cinnapi
# libraries and paddle_gtest_main_new.
#
# Usage:
#   paddle_test_build(target SRCS test.cc [DEPS dep1 dep2 ...])
function(paddle_test_build TARGET_NAME)
  if(WITH_TESTING)
    # Declare `options` explicitly. The previous code expanded an undefined
    # variable, which could silently pick up a value leaked from the
    # caller's scope and corrupt argument parsing.
    set(options "")
    set(oneValueArgs "")
    set(multiValueArgs SRCS DEPS)
    cmake_parse_arguments(paddle_test "${options}" "${oneValueArgs}"
                          "${multiValueArgs}" ${ARGN})
    add_executable(${TARGET_NAME} ${paddle_test_SRCS})
    # The unified paddle library name is recorded in a global property.
    get_property(paddle_lib GLOBAL PROPERTY PADDLE_LIB_NAME)
    target_link_libraries(${TARGET_NAME} $<TARGET_LINKER_FILE:${paddle_lib}>
                          ${paddle_test_DEPS} paddle_gtest_main_new)
    add_dependencies(${TARGET_NAME} ${paddle_lib} ${paddle_test_DEPS}
                     paddle_gtest_main_new)
    if(WITH_SHARED_PHI)
      target_link_libraries(${TARGET_NAME} $<TARGET_LINKER_FILE:phi>)
      add_dependencies(${TARGET_NAME} phi)
    endif()
    if(WITH_SHARED_IR)
      target_link_libraries(${TARGET_NAME} $<TARGET_LINKER_FILE:pir>)
      add_dependencies(${TARGET_NAME} pir)
    endif()
    # Python libraries are required unless this is an inference-only build
    # without python support.
    if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
      target_link_libraries(${TARGET_NAME} ${PYTHON_LIBRARIES})
    endif()
    if(WITH_CINN AND NOT CINN_ONLY)
      target_link_libraries(${TARGET_NAME} $<TARGET_LINKER_FILE:cinnapi>)
      add_dependencies(${TARGET_NAME} cinnapi)
    endif()
    if(WITH_XPU)
      target_link_libraries(${TARGET_NAME} xpulib)
    endif()
    if(WITH_ROCM)
      target_link_libraries(${TARGET_NAME} ${ROCM_HIPRTC_LIB})
    endif()
    if(APPLE)
      # macOS: embed rpaths to the shared-library build directories so the
      # test binary can be run from the build tree.
      target_link_libraries(
        ${TARGET_NAME}
        "-Wl,-rpath,$<TARGET_FILE_DIR:${paddle_lib}> -Wl,-rpath,$<TARGET_FILE_DIR:phi> -Wl,-rpath,$<TARGET_FILE_DIR:pir>"
      )
    endif()
    common_link(${TARGET_NAME})
    check_coverage_opt(${TARGET_NAME} ${paddle_test_SRCS})
  endif()
endfunction()
# Build and register a unit test that is automatically linked against
# libpaddle.so:
#   paddle_test(example SRCS example_test.cc [DEPS ...] [ARGS ...])
function(paddle_test TARGET_NAME)
  if(WITH_TESTING)
    # Declared explicitly instead of relying on an undefined `options`
    # variable possibly inherited from the caller's scope.
    set(options "")
    set(oneValueArgs "")
    set(multiValueArgs SRCS DEPS ARGS)
    cmake_parse_arguments(paddle_test "${options}" "${oneValueArgs}"
                          "${multiValueArgs}" ${ARGN})
    paddle_test_build(${TARGET_NAME} SRCS ${paddle_test_SRCS} DEPS
                      ${paddle_test_DEPS})
    cc_test_run(${TARGET_NAME} COMMAND ${TARGET_NAME} ARGS ${paddle_test_ARGS})
  endif()
endfunction()
function(nv_library TARGET_NAME)
if(WITH_GPU)
set(options STATIC static SHARED shared)
......@@ -640,6 +728,7 @@ function(nv_test TARGET_NAME)
# 2. cuda_add_executable does not support ccache.
# Reference: https://cmake.org/cmake/help/v3.10/module/FindCUDA.html
add_executable(${TARGET_NAME} ${nv_test_SRCS})
target_compile_definitions(${TARGET_NAME} PUBLIC STATIC_PADDLE)
get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
target_link_libraries(${TARGET_NAME} ${nv_test_DEPS}
${os_dependency_modules} paddle_gtest_main phi)
......@@ -770,7 +859,7 @@ function(hip_test TARGET_NAME)
TEST ${TARGET_NAME}
PROPERTY
ENVIRONMENT
"LD_LIBRARY_PATH=${CMAKE_BINARY_DIR}/python/paddle/libs:$LD_LIBRARY_PATH"
"LD_LIBRARY_PATH=${CMAKE_BINARY_DIR}/python/paddle/libs:$ENV{LD_LIBRARY_PATH}"
)
endif()
endfunction()
......@@ -1064,7 +1153,21 @@ function(py_proto_compile TARGET_NAME)
"${multiValueArgs}" ${ARGN})
set(py_srcs)
protobuf_generate_python(py_srcs ${py_proto_compile_SRCS})
add_custom_target(${TARGET_NAME} ALL DEPENDS ${py_srcs} protobuf)
add_custom_target(${TARGET_NAME}_replace DEPENDS ${py_srcs})
foreach(py_src ${py_srcs})
add_custom_command(
TARGET ${TARGET_NAME}_replace
POST_BUILD
COMMAND ${PYTHON_EXECUTABLE} ${PADDLE_SOURCE_DIR}/cmake/replace_string.py
${py_src}
COMMENT
"Replacing 'paddle.fluid' with 'paddle.base' generated by protobuf"
COMMENT "Replace ${py_src}")
endforeach()
add_custom_target(${TARGET_NAME} ALL DEPENDS protobuf ${TARGET_NAME}_replace)
endfunction()
function(py_test TARGET_NAME)
......
......@@ -85,8 +85,11 @@ find_package_and_include(rocsparse)
find_package_and_include(rocfft)
# set CXX flags for HIP
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -D__HIP_PLATFORM_HCC__")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__HIP_PLATFORM_HCC__")
set(CMAKE_C_FLAGS
"${CMAKE_C_FLAGS} -D__HIP_PLATFORM_HCC__ -DROCM_NO_WRAPPER_HEADER_WARNING")
set(CMAKE_CXX_FLAGS
"${CMAKE_CXX_FLAGS} -D__HIP_PLATFORM_HCC__ -DROCM_NO_WRAPPER_HEADER_WARNING"
)
set(CMAKE_CXX_FLAGS
"${CMAKE_CXX_FLAGS} -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP")
set(THRUST_DEVICE_SYSTEM THRUST_DEVICE_SYSTEM_HIP)
......@@ -96,7 +99,7 @@ list(APPEND HIP_CXX_FLAGS -fPIC)
list(APPEND HIP_CXX_FLAGS -D__HIP_PLATFORM_HCC__=1)
# Note(qili93): HIP has compile conflicts of float16.h as platform::float16 overload std::is_floating_point and std::is_integer
list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF_CONVERSIONS__=1)
list(APPEND HIP_CXX_FLAGS -mllvm -amdgpu-enable-flat-scratch=false)
list(APPEND HIP_CXX_FLAGS -DROCM_NO_WRAPPER_HEADER_WARNING)
list(APPEND HIP_CXX_FLAGS -Wno-macro-redefined)
list(APPEND HIP_CXX_FLAGS -Wno-inconsistent-missing-override)
list(APPEND HIP_CXX_FLAGS -Wno-exceptions)
......@@ -115,6 +118,11 @@ list(APPEND HIP_CXX_FLAGS -Wno-unused-value)
list(APPEND HIP_CXX_FLAGS -Wno-braced-scalar-init)
list(APPEND HIP_CXX_FLAGS -Wno-return-type)
list(APPEND HIP_CXX_FLAGS -Wno-pragma-once-outside-header)
list(APPEND HIP_CXX_FLAGS -Wno-deprecated-builtins)
list(APPEND HIP_CXX_FLAGS -Wno-switch)
list(APPEND HIP_CXX_FLAGS -Wno-literal-conversion)
list(APPEND HIP_CXX_FLAGS -Wno-constant-conversion)
list(APPEND HIP_CXX_FLAGS -Wno-defaulted-function-deleted)
if(WITH_CINN)
list(APPEND HIP_CXX_FLAGS -std=c++14)
......@@ -135,10 +143,10 @@ set(HIP_CLANG_FLAGS ${HIP_CXX_FLAGS})
# host linker to link.
list(APPEND HIP_HCC_FLAGS -fno-gpu-rdc)
list(APPEND HIP_HCC_FLAGS --offload-arch=gfx906)
list(APPEND HIP_HCC_FLAGS --offload-arch=gfx926) # gfx926 for DCU
list(APPEND HIP_HCC_FLAGS --offload-arch=gfx926)
list(APPEND HIP_CLANG_FLAGS -fno-gpu-rdc)
list(APPEND HIP_CLANG_FLAGS --offload-arch=gfx906)
list(APPEND HIP_CLANG_FLAGS --offload-arch=gfx926) # gfx926 for DCU
list(APPEND HIP_CLANG_FLAGS --offload-arch=gfx926)
if(HIP_COMPILER STREQUAL clang)
set(hip_library_name amdhip64)
......
......@@ -268,13 +268,11 @@ else()
SRCS ${paddle_phi_lib}
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/lib)
endif()
if(WITH_SHARED_IR)
set(paddle_ir_lib ${PADDLE_BINARY_DIR}/paddle/ir/libir.*)
copy(
inference_lib_dist
SRCS ${paddle_ir_lib}
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/lib)
endif()
set(paddle_common_lib ${PADDLE_BINARY_DIR}/paddle/common/libcommon.*)
copy(
inference_lib_dist
SRCS ${paddle_common_lib}
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/lib)
endif()
copy(
......@@ -336,11 +334,26 @@ copy(
inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/utils/flat_hash_map.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/utils/)
copy(
inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/utils/flags.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/utils/)
copy(
inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/utils/test_macros.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/utils/)
copy(
inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/extension.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/)
if(NOT WITH_GFLAGS)
copy(
inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/utils/flags_native.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/utils/)
endif()
# the include path of phi needs to be changed to adapt to inference api path
add_custom_command(
TARGET inference_lib_dist
......
......@@ -684,6 +684,9 @@ function(prune_pybind_h)
list(APPEND op_list "load_combine")
list(APPEND op_list "tensorrt_engine")
# TODO(ming1753): conditional_block_infer is temporarily reserved here to avoid link errors in functions of standalone_executor
list(APPEND op_list "conditional_block_infer")
# add fused_op in op_list
list(APPEND op_list "fc")
list(APPEND op_list "conv2d_fusion")
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
def main():
    """Rewrite a protobuf-generated Python file in place, replacing the
    legacy ``paddle.fluid`` package path with ``paddle.base``.

    Expects the target file path as the first command-line argument
    (``sys.argv[1]``).
    """
    src = sys.argv[1]
    # Use an explicit encoding: generated sources are UTF-8, while the
    # platform default (e.g. cp936 on Windows) may differ and would raise
    # UnicodeDecodeError or corrupt non-ASCII content.
    with open(src, 'r', encoding='utf-8') as file:
        content = file.read()
    new_content = content.replace('paddle.fluid', 'paddle.base')
    with open(src, 'w', encoding='utf-8') as file:
        file.write(new_content)


if __name__ == "__main__":
    main()
......@@ -247,6 +247,14 @@ if(NOT DEFINED WITH_MKLDNN)
endif()
endif()
if(WIN32)
if(MSVC)
if(MSVC_VERSION LESS 1920)
set(WITH_MKLDNN OFF)
endif()
endif()
endif()
if(WIN32
OR APPLE
OR NOT WITH_GPU
......@@ -264,6 +272,17 @@ include(external/gflags) # download, build, install gflags
include(external/glog) # download, build, install glog
########################### include third_party according to flags ###############################
if(WITH_GPU
AND NOT WITH_ARM
AND NOT WIN32
AND NOT APPLE)
if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.0)
include(external/cutlass) # download, build, install cusparselt
list(APPEND third_party_deps extern_cutlass)
set(WITH_CUTLASS ON)
endif()
endif()
if(WITH_CINN)
if(WITH_MKL)
add_definitions(-DCINN_WITH_MKL_CBLAS)
......@@ -375,6 +394,10 @@ if(WITH_GPU)
if(${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0)
include(external/cub) # download cub
list(APPEND third_party_deps extern_cub)
elseif(${CMAKE_CUDA_COMPILER_VERSION} EQUAL 12.0
OR ${CMAKE_CUDA_COMPILER_VERSION} GREATER 12.0)
include(external/cccl)
add_definitions(-DPADDLE_WITH_CCCL)
endif()
set(URL
"https://paddlepaddledeps.bj.bcebos.com/externalErrorMsg_20210928.tar.gz"
......@@ -539,15 +562,16 @@ if(WITH_CUSPARSELT)
list(APPEND third_party_deps extern_cusparselt)
endif()
if(WITH_ROCM)
include(external/flashattn)
list(APPEND third_party_deps extern_flashattn)
set(WITH_FLASHATTN ON)
endif()
if(WITH_GPU
AND NOT WITH_ARM
AND NOT WIN32
AND NOT APPLE)
if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.0)
include(external/cutlass) # download, build, install cusparselt
list(APPEND third_party_deps extern_cutlass)
set(WITH_CUTLASS ON)
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.4)
foreach(arch ${NVCC_ARCH_BIN})
if(${arch} GREATER_EQUAL 80)
......
set(PYTHON_TESTS_DIR
${PADDLE_BINARY_DIR}/python/paddle/fluid/tests
${PADDLE_BINARY_DIR}/python/paddle/base/tests
CACHE INTERNAL "python tests directory")
add_subdirectory(utils)
add_subdirectory(ir)
add_subdirectory(common)
add_subdirectory(pir)
add_subdirectory(scripts)
add_subdirectory(testing)
add_subdirectory(phi)
......
if(WITH_TESTING)
cinn_cc_library(cinn_gtest_main SRCS gtest_main.cc DEPS gtest gflags)
cinn_cc_library(cinn_gtest_main SRCS gtest_main.cc DEPS gtest ${flags_dep})
endif()
add_subdirectory(adt)
add_subdirectory(api)
add_subdirectory(ast_gen_ius)
add_subdirectory(auto_schedule)
add_subdirectory(common)
add_subdirectory(utils)
......
```
___ ___ ___
/\__\ /\ \ /\ \
/:/ / ___ \:\ \ \:\ \
/:/ / /\__\ \:\ \ \:\ \
/:/ / ___ /:/__/ _____\:\ \ _____\:\ \
/:/__/ /\__\/::\ \ /::::::::\__\/::::::::\__\
\:\ \ /:/ /\/\:\ \__\:\~~\~~\/__/\:\~~\~~\/__/
\:\ /:/ / \:\/\__\\:\ \ \:\ \
\:\/:/ / \::/ / \:\ \ \:\ \
\::/ / /:/ / \:\__\ \:\__\
\/__/ \/__/ \/__/ \/__/
```
# CINN : Compiler Infrastructure for Neural Networks
The project CINN is a machine learning compiler and executor for multiple hardware backends.
It is designed to provide multiple layers of APIs to make tensor computation easier to define, faster to execute, and more convenient to extend with hardware backends.
Currently, it targets x86 CPUs and Nvidia GPUs.
This project is under active development.
## How it works
CINN lowers a traditional DNN model into a two-level intermediate representation (IR): the high-level IR (HLIR) and CINN IR.
The HLIR helps to define some domain-specific computation and perform some overall optimization on the IR-graph;
the CINN IR helps to represent some computation semantic and finally lower to a hardware backend.
Both levels of IR have similar SSA graphs, analysis passes, and optimization facilities.
The schedule transform is applied on the CINN IR to do optimizations.
For more details, you can refer to:
https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/cinn
## Getting Started
### Compile
Clone PaddlePaddle first.
```
git clone https://github.com/PaddlePaddle/Paddle.git
cd Paddle
mkdir build
cd build
```
Build paddle with cinn:
```
cmake .. -DCINN_ONLY=OFF -DWITH_CINN=ON -DWITH_GPU=ON
```
Build cinn only:
```
cmake .. -DCINN_ONLY=ON -DWITH_CINN=ON -DWITH_GPU=ON
```
And then
```
make -j
```
### Install
Install paddle with cinn:
```
pip install python/dist/paddlepaddle_gpu-xxx.whl
```
Install cinn only:
```
pip install python/dist/cinn_gpu-xxx.whl
```
Then you can import paddle in the python environment and check if a paddle version with CINN is installed.
```
import paddle
paddle.is_compiled_with_cinn()
```
### Concepts
There are two levels of APIs in CINN, the higher level is HLIR and the lower level is CINN IR, both contain some concepts.
In HLIR
- `frontend::Program`, the program helps to define a machine learning computation,
- `hlir::framework::Tensor`, a multi-dimensional array that helps to manage a memory buffer.
- `hlir::framework::Program`, the final executable program in runtime. It holds many basic executable elements.
- `hlir::framework::Graph`, the graph that represents the structure of a model. Each node in the graph represents an operator (conv2d, relu, mul, etc.).
- `hlir::framework::GraphCompiler`, the compiler that transforms the graph representation(hlir::framework::Graph) of a model into an executable program(hlir::framework::Program).
In CINN IR
- `Compute`, the method to define a computation,
- `Lower`, the method to lower a computation to the corresponding IR,
- `LoweredFunc`, the function defined in CINN IR,
- `Var`, a scalar variable,
- `Expr`, an expression represents any CINN IR node(no specified Statement node),
## License
CINN is licensed under the [Apache 2.0 license](LICENSE).
## Acknowledgement
CINN learned a lot from the following projects:
- [Halide](https://github.com/halide/Halide): Referenced the design of most IR nodes,
- [TVM](https://github.com/apache/tvm): We learned many ideas including the semantics of some schedule primitives, TOPI, NNVM, and so on,
- [tiramisu](https://github.com/Tiramisu-Compiler): The isl usage, polyhedral compilation, schedule primitive implementation, and so on,
- [tensorflow/xla](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla): Referenced the semantics of the primitive operations.
# The ADT (algebraic data type) sources are only built as part of the full
# CINN-in-Paddle build; they are skipped for a standalone CINN_ONLY build.
if(NOT CINN_ONLY)
  add_subdirectory(print_utils)

  core_gather_headers()

  # Register the ADT translation units into the global cinnapi source list.
  gather_srcs(
    cinnapi_src
    SRCS
    adapter_tensor.cc
    anchor_sd_equation_context.cc
    equation_function.cc
    equation_solver.cc
    equation_value.cc
    generate_map_expr.cc
    get_sub_reshape_dim_ranges.cc
    igroup.cc
    index_expr_infer_context.cc
    kgroup.cc
    m_ir.cc
    naive_bidirection_equation_generator.cc
    naive_op_equation_context.cc
    partition_op_stmts.cc
    schedule_descriptor.cc
    schedule_dim.cc
    schedule_mesh.cc
    simplify_value.cc
    write_broadcast_disabled_bidirection_equation_generator.cc)

  # Unit tests for the ADT layer.
  cinn_cc_test(equation_value_match_trait_test SRCS
               equation_value_match_trait_test.cc DEPS gtest glog)
  cinn_cc_test(tree_test SRCS tree_test.cc DEPS gtest glog)
  cinn_cc_test(inline_translator_test SRCS inline_translator_test.cc DEPS
               cinncore)
  message(STATUS "ADT srcs: ${cinnapi_src}")
endif()
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/adt/adapter_tensor.h"
#include "glog/logging.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
namespace cinn::adt::adapter {
// Returns the rank (number of dimensions) of the wrapped pir value,
// i.e. the size of the shape reported by CompatibleInfo::ValueShape.
std::size_t Tensor::GetRank() const {
  return cinn::hlir::framework::pir::CompatibleInfo::ValueShape(node_data)
      .size();
}
// Collects the wrapped value's dimension extents, narrowed to int32_t.
// NOTE(review): assumes each extent fits in 32 bits — confirm upstream.
std::vector<int32_t> Tensor::GetShape() const {
  std::vector<int32_t> shape;
  for (int extent :
       cinn::hlir::framework::pir::CompatibleInfo::ValueShape(node_data)) {
    shape.push_back(extent);
  }
  return shape;
}
// Total number of elements: the product of all dimension extents
// (1 for a rank-0 value, since the accumulator starts at 1).
std::size_t Tensor::GetNumel() const {
  std::size_t numel = 1;
  for (int extent :
       cinn::hlir::framework::pir::CompatibleInfo::ValueShape(node_data)) {
    numel *= extent;
  }
  return numel;
}
} // namespace cinn::adt::adapter
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/adt/adt.h"
#include "paddle/pir/core/value.h"
namespace cinn::adt::adapter {
// Lightweight adapter that wraps a ::pir::Value so the ADT layer can treat
// it as a tensor. Identity, equality, and hashing all delegate to the
// wrapped value handle.
struct Tensor final {
  ::pir::Value node_data;  // the wrapped pir value

  bool operator==(const Tensor& other) const {
    return this->node_data == other.node_data;
  }

  // Shape queries; implemented in adapter_tensor.cc via
  // CompatibleInfo::ValueShape(node_data).
  std::size_t GetRank() const;
  std::vector<int32_t> GetShape() const;
  std::size_t GetNumel() const;
};

// Hash hook picked up by the ADT union hashing machinery
// (OVERRIDE_UNION_GET_HASH_VALUE dispatches to GetHashValueImpl overloads).
inline std::size_t GetHashValueImpl(const Tensor& tensor) {
  return std::hash<::pir::Value>()(tensor.node_data);
}
} // namespace cinn::adt::adapter
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <tuple>
#include <variant>
#include <vector>
#include "glog/logging.h"
namespace cinn {
namespace adt {
// Overload-set helper: aggregates several lambdas into a single callable
// exposing all their operator()s (the classic std::visit "overloaded" idiom).
template <class... Ts>
struct match : Ts... {
  using Ts::operator()...;
};
// Deduction guide so `match{f, g}` deduces match<F, G>.
template <class... Ts>
match(Ts...) -> match<Ts...>;

// Pattern-match syntax: `variant >> match{...}` visits the variant with the
// combined overload set.
template <typename... Ts, typename... Fs>
constexpr decltype(auto) operator>>(std::variant<Ts...> const& v,
                                    match<Fs...> const& match) {
  return std::visit(match, v);
}

// Thin wrapper over std::variant that supports the `>> match{...}` syntax.
template <typename... Ts>
class Union {
 public:
  Union(const Union&) = default;
  Union(Union&&) = default;

  // Construct from any alternative; the enable_if keeps the copy/move
  // constructors preferred when the argument is itself a Union.
  template <
      typename Arg,
      std::enable_if_t<!std::is_same_v<std::decay_t<Arg>, Union>, bool> = true>
  explicit Union(Arg&& arg) : variant_(std::forward<Arg>(arg)) {}

  // Visit this union with a match{...} overload set.
  template <typename... Fs>
  auto operator>>(match<Fs...> const& match) const {
    return variant_ >> match;
  }

  const std::variant<Ts...>& variant() const { return variant_; }

 private:
  std::variant<Ts...> variant_;
};
// Tuple with shared-ownership semantics: copies share the same underlying
// std::tuple storage (shallow copy, cheap to pass by value).
template <typename... Ts>
class Tuple {
 public:
  Tuple(const Tuple&) = default;
  Tuple(Tuple&&) = default;
  Tuple& operator=(const Tuple&) = default;
  Tuple& operator=(Tuple&&) = default;

  // Construct the shared tuple from the given field values.
  template <typename... Args>
  explicit Tuple(Args&&... args)
      : tuple_(
            std::make_shared<std::tuple<Ts...>>(std::forward<Args>(args)...)) {}

  const std::tuple<Ts...>& tuple() const { return *tuple_; }

  // Access the I-th field.
  template <std::size_t I>
  const auto& Get() const {
    return std::get<I>(*tuple_);
  }

 protected:
  std::shared_ptr<std::tuple<Ts...>> tuple_;
};

// Structural equality for Tuple-like types, with a fast path when both
// sides share the same underlying storage.
template <typename T>
bool TupleEqual(const T& lhs, const T& rhs) {
  if (&lhs.tuple() == &rhs.tuple()) {
    return true;
  }
  return lhs.tuple() == rhs.tuple();
}
// List with shared-ownership semantics: copies share the same underlying
// std::vector. NOTE: operator*/operator-> expose mutable access even on a
// const List, because constness does not propagate through shared_ptr.
template <typename T>
class List final {
 public:
  List(const List&) = default;
  List(List&&) = default;
  List& operator=(const List&) = default;
  List& operator=(List&&) = default;

  using value_type = T;

  // Empty list.
  explicit List() : vector_(std::make_shared<std::vector<T>>()) {}

  // Single-element list; enable_if keeps copy/move constructors preferred.
  template <
      typename Arg,
      std::enable_if_t<!std::is_same_v<std::decay_t<Arg>, List>, bool> = true>
  explicit List(Arg&& arg)
      : vector_(std::make_shared<std::vector<T>>(
            std::vector<T>{std::forward<Arg>(arg)})) {}

  // Two-or-more-element list.
  template <typename Arg0, typename Arg1, typename... Args>
  List(Arg0&& arg0, Arg1&& arg1, Args&&... args)
      : vector_(std::make_shared<std::vector<T>>(
            std::vector<T>{std::forward<Arg0>(arg0),
                           std::forward<Arg1>(arg1),
                           std::forward<Args>(args)...})) {}

  // Structural equality, with an identity fast path.
  bool operator==(const List& other) const {
    if (&vector() == &other.vector()) {
      return true;
    }
    return vector() == other.vector();
  }

  bool operator!=(const List& other) const { return !(*this == other); }

  std::vector<T>& operator*() const { return *vector_; }
  std::vector<T>* operator->() const { return vector_.get(); }
  const std::vector<T>& vector() const { return *vector_; }
  // Bounds-checked element access (throws std::out_of_range on bad idx).
  const auto& Get(std::size_t idx) const { return vector_->at(idx); }

 private:
  std::shared_ptr<std::vector<T>> vector_;
};
// DEFINE_ADT_TAG(TagName) declares a strongly-typed wrapper `TagName<T>`
// around a value of type T. Distinct tags over the same T are distinct
// types, preventing accidental mixing of semantically different values.
// Equality/inequality compare the wrapped values; value() reads the payload.
// (Comments cannot safely appear inside the backslash-continued macro body.)
#define DEFINE_ADT_TAG(TagName)                                       \
  template <typename T>                                               \
  class TagName {                                                     \
   public:                                                            \
    TagName() = default;                                              \
    TagName(const TagName&) = default;                                \
    TagName(TagName&&) = default;                                     \
    TagName& operator=(const TagName&) = default;                     \
    TagName& operator=(TagName&&) = default;                          \
                                                                      \
    bool operator==(const TagName& other) const {                     \
      return value_ == other.value();                                 \
    }                                                                 \
                                                                      \
    bool operator!=(const TagName& other) const {                     \
      return value_ != other.value();                                 \
    }                                                                 \
                                                                      \
    template <typename Arg,                                           \
              std::enable_if_t<!std::is_same_v<std::decay_t<Arg>, TagName>, \
                               bool> = true>                          \
    explicit TagName(Arg&& value) : value_(value) {}                  \
                                                                      \
    const T& value() const { return value_; }                         \
                                                                      \
   private:                                                           \
    T value_;                                                         \
  };
// DEFINE_ADT_UNION(name, ...) declares a sum type `name` over the listed
// alternatives, backed by std::variant. It provides:
//   Get<T>()  — extract alternative T (throws std::bad_variant_access),
//   Has<T>()  — test which alternative is held,
//   Visit(f)  — std::visit with an arbitrary visitor,
//   >> match{...} — pattern-matching visitation.
// The converting constructor is intentionally implicit so alternative
// values convert to the union freely at call sites.
#define DEFINE_ADT_UNION(class_name, ...)                                 \
  class class_name final {                                                \
   public:                                                                \
    class_name(const class_name&) = default;                              \
    class_name(class_name&&) = default;                                   \
    class_name& operator=(const class_name& other) = default;             \
    class_name& operator=(class_name&& other) = default;                  \
                                                                          \
    template <typename Arg,                                               \
              std::enable_if_t<!std::is_same_v<std::decay_t<Arg>, class_name>, \
                               bool> = true>                              \
    class_name(Arg&& arg) : variant_(std::forward<Arg>(arg)) {}           \
                                                                          \
    template <typename T>                                                 \
    const T& Get() const {                                                \
      return std::get<T>(variant_);                                       \
    }                                                                     \
                                                                          \
    template <typename T>                                                 \
    bool Has() const {                                                    \
      return std::holds_alternative<T>(variant_);                         \
    }                                                                     \
                                                                          \
    template <typename T>                                                 \
    auto Visit(const T& visitor) const {                                  \
      return std::visit(visitor, variant_);                               \
    }                                                                     \
                                                                          \
    template <typename... Fs>                                             \
    auto operator>>(match<Fs...> const& match) const {                    \
      return variant_ >> match;                                           \
    }                                                                     \
                                                                          \
    const std::variant<__VA_ARGS__>& variant() const { return variant_; } \
                                                                          \
   private:                                                               \
    std::variant<__VA_ARGS__> variant_;                                   \
  }
// Structural equality for DEFINE_ADT_UNION types: two unions are equal iff
// they hold the same alternative and those alternatives compare equal.
template <typename UnionT>
bool UnionEqual(const UnionT& lhs, const UnionT& rhs) {
  if (&lhs == &rhs) {  // fast path: same object
    return true;
  }
  return std::visit(
      [](auto&& lhs, auto&& rhs) {
        // Only compare when both sides hold the same alternative type.
        if constexpr (std::is_same<std::decay_t<decltype(lhs)>,
                                   std::decay_t<decltype(rhs)>>::value) {
          return lhs == rhs;
        } else {
          return false;  // different alternatives are never equal
        }
      },
      lhs.variant(),
      rhs.variant());
}
// Shorthand for declaring one-field Tuple-based AST nodes.
#define DEFINE_ADT_UNARY(name)        \
  template <typename T>               \
  struct name : public Tuple<T> {     \
    using Tuple<T>::Tuple;            \
  }

// Shorthand for declaring two-field Tuple-based AST nodes.
#define DEFINE_ADT_BINARY(name)         \
  template <typename T0, typename T1>   \
  struct name : public Tuple<T0, T1> {  \
    using Tuple<T0, T1>::Tuple;         \
  }

// Arithmetic expression node templates used by the ADT equation layer.
DEFINE_ADT_UNARY(Neg);
DEFINE_ADT_BINARY(Add);
DEFINE_ADT_BINARY(Mul);
DEFINE_ADT_BINARY(Div);
DEFINE_ADT_BINARY(Mod);

// Generate free operator==/operator!= for `type` in terms of `function`
// (typically TupleEqual or UnionEqual).
#define OVERLOAD_OPERATOR_EQ_NE(type, function)             \
  inline bool operator==(const type& lhs, const type& rhs) { \
    return function(lhs, rhs);                               \
  }                                                          \
  inline bool operator!=(const type& lhs, const type& rhs) { \
    return !(lhs == rhs);                                    \
  }

// Hash a DEFINE_ADT_TAG wrapper by hashing its wrapped value.
template <typename T>
std::size_t TagHashValue(const T& tag) {
  return std::hash<std::decay_t<decltype(tag.value())>>()(tag.value());
}

// Generate GetHashValue(cls) for a tag type.
#define OVERRIDE_TAG_GET_HASH_VALUE(cls) \
  inline std::size_t GetHashValue(const cls& tag) { return TagHashValue(tag); }

// Generate GetHashValue(cls) for a union type by dispatching to the
// per-alternative GetHashValueImpl overloads.
#define OVERRIDE_UNION_GET_HASH_VALUE(cls)                                   \
  inline std::size_t GetHashValue(const cls& union_obj) {                    \
    return std::visit([](const auto& impl) { return GetHashValueImpl(impl); }, \
                      union_obj.variant());                                  \
  }

using Name = std::string;

// Unit type meaning "no value defined". Undefined = {}
struct Undefined final {
  bool operator==(const Undefined&) const { return true; }
  bool operator!=(const Undefined&) const { return false; }
};

// Unit type meaning "success". Ok = {}
struct Ok final {
  bool operator==(const Ok&) const { return true; }
  bool operator!=(const Ok&) const { return false; }
};

// Abort with a fatal log at a not-yet-implemented code path.
#define ADT_TODO() LOG(FATAL) << "TODO"

// Boost-style hash combiner; mutates and returns lhs.
inline std::size_t hash_combine(std::size_t lhs, std::size_t rhs) {
  return lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/adt/anchor_sd_equation_context.h"
namespace cinn::adt::config {
// Forward declaration: the dispatcher below and the per-node Impl overloads
// in the anonymous namespace are mutually recursive.
void GenerateScheduleMeshEquations(
    const ScheduleMesh& sched_mesh,
    const List<Iterator>& tmp_anchor_iterators,
    const List<Iterator>& sd_iterators,
    Equations* equations,
    std::unordered_map<Dim, const Constant>* dim2constant);

namespace {

// Base case: a plain list of schedule dims. Inputs map to outputs 1:1, so
// emit one Equal equation per iterator pair.
void GenerateScheduleMeshEquationsImpl(
    const List<ScheduleDim>& sched_dims,
    const List<Iterator>& input_iterators,
    const List<Iterator>& output_iterators,
    Equations* equations,
    std::unordered_map<Dim, const Constant>* dim2constant) {
  CHECK_EQ(input_iterators->size(), output_iterators->size());
  for (std::size_t i = 0; i < output_iterators->size(); ++i) {
    Equal(input_iterators->at(i), output_iterators->at(i), equations);
  }
}

// Reshape case: relate the middle mesh's flattened index to the output's
// flattened index (same linear index, different dims), record the dim
// extents of both sides in dim2constant, then recurse into the inner mesh.
void GenerateScheduleMeshEquationsImpl(
    const ScheduleMeshReshape<ScheduleMesh>& sched_reshape,
    const List<Iterator>& input_iterators,
    const List<Iterator>& output_iterators,
    Equations* equations,
    std::unordered_map<Dim, const Constant>* dim2constant) {
  const auto& [middle_sched_mesh, shape] = sched_reshape.tuple();
  List<Iterator> middle_iterators =
      MakeIterators(GetOutputRank(middle_sched_mesh));
  List<Dim> middle_dims = MakeDims(GetOutputRank(middle_sched_mesh));
  CHECK_EQ(shape.value()->size(), output_iterators->size());
  List<Dim> output_dims = MakeDims(output_iterators->size());
  {
    // Record extent constants for the freshly created middle/output dims.
    // emplace must succeed: each Dim is new, so a duplicate is a logic bug.
    List<Constant> middle_dim_values = GetOutputDimValues(middle_sched_mesh);
    for (std::size_t i = 0; i < middle_dim_values->size(); ++i) {
      CHECK(dim2constant->emplace(middle_dims->at(i), middle_dim_values->at(i))
                .second);
    }
    List<Constant> output_dim_values =
        GetOutputDimValues(ScheduleMesh{sched_reshape});
    for (std::size_t i = 0; i < output_dims->size(); ++i) {
      CHECK(dim2constant->emplace(output_dims->at(i), output_dim_values->at(i))
                .second);
    }
  }
  // Same linear (dot-product) index on both sides of the reshape.
  const auto& middle_index = MakeDot(middle_iterators, middle_dims, equations);
  const auto& output_index = MakeDot(output_iterators, output_dims, equations);
  Equal(middle_index, output_index, equations);
  GenerateScheduleMeshEquations(middle_sched_mesh,
                                input_iterators,
                                middle_iterators,
                                equations,
                                dim2constant);
}

// Transpose case: output iterator i equals middle iterator perm[i]; then
// recurse into the inner mesh with the middle iterators.
void GenerateScheduleMeshEquationsImpl(
    const ScheduleMeshTranspose<ScheduleMesh>& sched_transpose,
    const List<Iterator>& input_iterators,
    const List<Iterator>& output_iterators,
    Equations* equations,
    std::unordered_map<Dim, const Constant>* dim2constant) {
  const auto& [sched_mesh, perm] = sched_transpose.tuple();
  CHECK_EQ(GetOutputRank(sched_mesh), output_iterators->size());
  List<Iterator> middle_iterators = MakeIterators(output_iterators->size());
  for (std::size_t i = 0; i < perm.value()->size(); ++i) {
    Equal(middle_iterators->at(perm.value()->at(i)),
          output_iterators->at(i),
          equations);
  }
  GenerateScheduleMeshEquations(
      sched_mesh, input_iterators, middle_iterators, equations, dim2constant);
}

// Padding case: padding does not change the iterator mapping, so iterators
// pass through 1:1; only the (ignored) padded extents differ. Recurse into
// the inner mesh afterwards.
void GenerateScheduleMeshEquationsImpl(
    const ScheduleMeshPadding<ScheduleMesh>& sched_padding,
    const List<Iterator>& input_iterators,
    const List<Iterator>& output_iterators,
    Equations* equations,
    std::unordered_map<Dim, const Constant>* dim2constant) {
  const auto& [sched_mesh, _] = sched_padding.tuple();
  CHECK_EQ(GetOutputRank(sched_mesh), output_iterators->size());
  List<Iterator> middle_iterators = MakeIterators(output_iterators->size());
  for (std::size_t i = 0; i < output_iterators->size(); ++i) {
    Equal(middle_iterators->at(i), output_iterators->at(i), equations);
  }
  GenerateScheduleMeshEquations(
      sched_mesh, input_iterators, middle_iterators, equations, dim2constant);
}

}  // namespace
// Dispatches on the concrete schedule-mesh variant and forwards all context
// to the matching GenerateScheduleMeshEquationsImpl overload.
void GenerateScheduleMeshEquations(
    const ScheduleMesh& sched_mesh,
    const List<Iterator>& tmp_anchor_iterators,
    const List<Iterator>& sd_iterators,
    Equations* equations,
    std::unordered_map<Dim, const Constant>* dim2constant) {
  const auto& dispatch = [&](const auto& node) {
    return GenerateScheduleMeshEquationsImpl(
        node, tmp_anchor_iterators, sd_iterators, equations, dim2constant);
  };
  return std::visit(dispatch, sched_mesh.variant());
}
// Records the constant value of every anchor dim and every sd dim into
// dim2constant_. Each dim must be inserted exactly once; a duplicate key
// trips the CHECK.
void AnchorSdEquationContext::InitDim2Constant(const ScheduleMesh& sched_mesh) {
  const auto& RecordDimValues = [&](const List<Dim>& dims,
                                    const List<Constant>& values) {
    CHECK_EQ(dims->size(), values->size());
    for (std::size_t idx = 0; idx < dims->size(); ++idx) {
      CHECK(dim2constant_.emplace(dims->at(idx), values->at(idx)).second);
    }
  };
  // Anchor dims take their values from the mesh's input side; sd dims from
  // the mesh's output side.
  RecordDimValues(anchor_dims_,
                  GetOutputDimValues(GetInputScheduleMesh(sched_mesh)));
  RecordDimValues(sd_dims_, GetOutputDimValues(sched_mesh));
}
// Generates the equations tying anchor_index to the sd iterators: first makes
// temporary anchor iterators, binds their dot-product with anchor_dims_ to
// anchor_index, then derives the remaining equations from the schedule mesh.
void AnchorSdEquationContext::GenerateSdEquation(const ScheduleMesh& sched_mesh,
                                                const Index& anchor_index) {
  const auto& tmp_anchor_iterators = MakeIterators(GetInputRank(sched_mesh));
  const auto& tmp_anchor_index =
      MakeDot(tmp_anchor_iterators, anchor_dims_, &equations_);
  Equal(tmp_anchor_index, anchor_index, &equations_);
  GenerateScheduleMeshEquations(sched_mesh,
                                tmp_anchor_iterators,
                                sd_iterators_,
                                &equations_,
                                &dim2constant_);
}
} // namespace cinn::adt::config
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/adt/equation.h"
#include "paddle/cinn/adt/equation_util.h"
#include "paddle/cinn/adt/schedule_mesh.h"
namespace cinn::adt::config {
using AnchorIndex = Index;
// Holds the dims, iterators, equations and dim->constant map produced for a
// ScheduleMesh / anchor-index pair. All state is built eagerly by the
// converting constructor; the remaining members are accessors.
class AnchorSdEquationContext final {
 public:
  AnchorSdEquationContext(const AnchorSdEquationContext&) = default;
  AnchorSdEquationContext(AnchorSdEquationContext&&) = default;
  AnchorSdEquationContext& operator=(const AnchorSdEquationContext&) = default;
  AnchorSdEquationContext& operator=(AnchorSdEquationContext&&) = default;
  // Sizes sd_dims_/sd_iterators_ by the mesh's output rank and anchor_dims_
  // by its input rank, then fills dim2constant_ and the equation list.
  AnchorSdEquationContext(const ScheduleMesh& sched_mesh,
                          const AnchorIndex& anchor_index)
      : sd_dims_(MakeDims(GetOutputRank(sched_mesh))),
        sd_iterators_(MakeIterators(GetOutputRank(sched_mesh))),
        anchor_dims_(MakeDims(GetInputRank(sched_mesh))) {
    InitDim2Constant(sched_mesh);
    GenerateSdEquation(sched_mesh, anchor_index);
  }
  const List<Dim>& sd_dims() const { return sd_dims_; }
  const List<Dim>& anchor_dims() const { return anchor_dims_; }
  const List<Iterator>& sd_iterators() const { return sd_iterators_; }
  const Equations& equations() const { return equations_; }
  const std::unordered_map<Dim, const Constant>& dim2constant() const {
    return dim2constant_;
  }
 private:
  // Records the constant extent of every anchor dim and sd dim.
  void InitDim2Constant(const ScheduleMesh& sched_mesh);
  // Emits equations relating tensor_index to sd_iterators_ via the mesh.
  void GenerateSdEquation(const ScheduleMesh& sched_mesh,
                          const Index& tensor_index);
  List<Dim> sd_dims_;
  List<Iterator> sd_iterators_;
  List<Dim> anchor_dims_;
  Equations equations_;
  std::unordered_map<Dim, const Constant> dim2constant_;
};
} // namespace cinn::adt::config
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <optional>
#include "paddle/cinn/adt/equation_function.h"
namespace cinn::adt {
class OpStmt;
// Abstract interface for producing direction equations and for mapping
// in-message indexes to out-message indexes. Non-copyable and non-movable;
// concrete generators derive from this and implement the pure virtuals.
class DirectionEquationGenerator {
 public:
  DirectionEquationGenerator(const DirectionEquationGenerator&) = delete;
  DirectionEquationGenerator(DirectionEquationGenerator&&) = delete;
  // Must be virtual: instances are used polymorphically through base-class
  // pointers, and deleting a derived object via a base pointer with a
  // non-virtual destructor is undefined behavior.
  virtual ~DirectionEquationGenerator() = default;
  // Returns the equations describing this generator's direction constraints.
  virtual Equations GetDirectionEquations() const = 0;
  // Returns a getter resolving a FakeOpPlaceHolder back to its OpStmt.
  virtual std::function<const OpStmt*(const FakeOpPlaceHolder&)>
  MakeGetterOpStmt4OpPlaceHolder() const = 0;
  // Maps an in-message index to its out-message index, if one exists.
  virtual std::optional<Index> OutMsgIndex4InMsgIndex(
      const Index& index) const = 0;
 protected:
  DirectionEquationGenerator() = default;
};
} // namespace cinn::adt
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/adt/equation_constant.h"
#include "paddle/cinn/adt/equation_function.h"
#include "paddle/cinn/adt/equation_variable.h"
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/adt/adt.h"
#include "paddle/cinn/adt/logical.h"
#include "paddle/cinn/adt/tags.h"
#include "paddle/cinn/adt/unique_id.h"
namespace cinn::adt {
// Dim = tDim UniqueId
using Dim = tDim<UniqueId>;
// DimTuple = [Dim]
using DimTuple = List<Dim>;
// Constant is a union of a literal int64, a symbolic Dim, or a list of
// Constants (allowing nested/structured constants).
DEFINE_ADT_UNION(Constant, std::int64_t, Dim, List<Constant>);
OVERLOAD_OPERATOR_EQ_NE(Constant, UnionEqual);
// EquationStaticValue = Dim | std::int64_t
DEFINE_ADT_UNION(EquationStaticValue, Dim, std::int64_t);
OVERLOAD_OPERATOR_EQ_NE(EquationStaticValue, UnionEqual);
using EquationStaticLogical = Logical<EquationStaticValue>;
// Forward declaration; the definition is generated by
// OVERRIDE_UNION_GET_HASH_VALUE below and dispatches to the Impl overloads.
inline std::size_t GetHashValue(const Constant& c);
// Literal ints hash to themselves.
inline std::size_t GetHashValueImpl(const std::int64_t& c) { return c; }
// Dims hash by their underlying unique id.
inline std::size_t GetHashValueImpl(const Dim& c) {
  return c.value().unique_id();
}
// Lists hash by combining the element hashes in order.
inline std::size_t GetHashValueImpl(const List<Constant>& c) {
  std::size_t ret = 0;
  for (const auto& c_item : *c) {
    ret = hash_combine(ret, GetHashValue(c_item));
  }
  return ret;
}
OVERRIDE_UNION_GET_HASH_VALUE(Constant);
}  // namespace cinn::adt
namespace std {
// Hashes a Dim by its underlying unique id (matches GetHashValueImpl(Dim)).
template <>
struct hash<::cinn::adt::Dim> final {
  std::size_t operator()(const ::cinn::adt::Dim& dim) const {
    return dim.value().unique_id();
  }
};
// Hashes a Constant by dispatching on its union alternative via GetHashValue.
template <>
struct hash<cinn::adt::Constant> {
  std::size_t operator()(const cinn::adt::Constant& constant) const {
    return GetHashValue(constant);
  }
};
}  // namespace std
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/adt/equation_function.h"
namespace cinn::adt {
// Walks one Function (== Equation) variant and collects the Variables it
// reads (tIn-tagged positions) and writes (tOut-tagged positions).
// Returns {input_variables, output_variables}.
std::pair<std::unordered_set<Variable>, std::unordered_set<Variable>>
CollectInputAndOutputVariables(const Function& function) {
  std::unordered_set<Variable> in_variables;
  std::unordered_set<Variable> out_variables;
  function >>
      match{
          [&](const Identity<tOut<Iterator>, tIn<Iterator>>& identity) {
            const auto& [out_iter, in_iter] = identity.tuple();
            out_variables.emplace(Variable{out_iter.value()});
            in_variables.emplace(Variable{in_iter.value()});
          },
          [&](const Identity<tOut<Index>, tIn<Index>>& identity) {
            const auto& [out_index, in_index] = identity.tuple();
            out_variables.emplace(Variable{out_index.value()});
            in_variables.emplace(Variable{in_index.value()});
          },
          // IndexDot: many input iterators produce one output index.
          [&](const IndexDot<List<Dim>, tOut<Index>, tIn<List<Iterator>>>&
                  dot) {
            const auto& [dims, out_index, in_iterators] = dot.tuple();
            out_variables.emplace(Variable{out_index.value()});
            for (const auto& iterator : *in_iterators.value()) {
              in_variables.emplace(Variable{iterator});
            }
          },
          [&](const GetBroadcastedIterator<Dim, tOut<Iterator>, tIn<Iterator>>&
                  broadcast) {
            const auto& [dim, out_iterator, in_iterator] = broadcast.tuple();
            out_variables.emplace(Variable{out_iterator.value()});
            in_variables.emplace(Variable{in_iterator.value()});
          },
          // IndexUnDot: one input index produces many output iterators.
          [&](const IndexUnDot<List<Dim>, tOut<List<Iterator>>, tIn<Index>>&
                  undot) {
            const auto& [dims, out_iterators, in_index] = undot.tuple();
            for (const auto& iterator : *out_iterators.value()) {
              out_variables.emplace(Variable{iterator});
            }
            in_variables.emplace(Variable{in_index.value()});
          },
          // InMsg2OutMsg: the placeholder and both halves of the out-message
          // arg indexes are outputs; both halves of the in-message arg
          // indexes are inputs. Out-message out-indexes are optional and
          // skipped when absent.
          [&](const InMsg2OutMsg<tOut<FakeOpPlaceHolder>,
                                 tOut<OpArgIndexes<std::optional<Index>>>,
                                 tIn<OpArgIndexes<Index>>>& in_msg2out_msg) {
            const auto& [op_placeholder, out_msg_indexes, in_msg_indexes] =
                in_msg2out_msg.tuple();
            out_variables.emplace(Variable{op_placeholder.value()});
            const auto& [out_msg_in_indexes, out_msg_out_indexes] =
                out_msg_indexes.value().tuple();
            const auto& [in_msg_in_indexes, in_msg_out_indexes] =
                in_msg_indexes.value().tuple();
            for (const auto& index : *out_msg_in_indexes.value()) {
              out_variables.emplace(Variable{index});
            }
            for (const auto& index : *out_msg_out_indexes.value()) {
              if (index.has_value()) {
                out_variables.emplace(Variable{index.value()});
              }
            }
            for (const auto& index : *in_msg_in_indexes.value()) {
              in_variables.emplace(Variable{index});
            }
            for (const auto& index : *in_msg_out_indexes.value()) {
              in_variables.emplace(Variable{index});
            }
          },
          [&](const ConstantFunction<tOut<Iterator>, tIn<Index>>&
                  constant_function) {
            const auto& [out_iterator, in_index, constant] =
                constant_function.tuple();
            out_variables.emplace(Variable{out_iterator.value()});
            in_variables.emplace(Variable{in_index.value()});
          },
      };
  return std::make_pair(in_variables, out_variables);
}
// Returns a human-readable name for the Function's active variant
// (e.g. "IndexDot"); both Identity specializations map to "Identity".
std::string GetFunctionTypeName(const Function& function) {
  return function >>
         match{
             [&](const Identity<tOut<Iterator>, tIn<Iterator>>& identity) {
               return "Identity";
             },
             [&](const Identity<tOut<Index>, tIn<Index>>& identity) {
               return "Identity";
             },
             [&](const IndexDot<List<Dim>, tOut<Index>, tIn<List<Iterator>>>&
                     dot) { return "IndexDot"; },
             [&](const GetBroadcastedIterator<Dim,
                                              tOut<Iterator>,
                                              tIn<Iterator>>& broadcast) {
               return "GetBroadcastedIterator";
             },
             [&](const IndexUnDot<List<Dim>, tOut<List<Iterator>>, tIn<Index>>&
                     undot) { return "IndexUnDot"; },
             [&](const InMsg2OutMsg<tOut<FakeOpPlaceHolder>,
                                    tOut<OpArgIndexes<std::optional<Index>>>,
                                    tIn<OpArgIndexes<Index>>>& in_msg2out_msg) {
               return "InMsg2OutMsg";
             },
             [&](const ConstantFunction<tOut<Iterator>, tIn<Index>>&
                     constant_function) { return "ConstantFunction"; },
         };
}
// Returns the address of the active variant's underlying tuple as an opaque
// pointer. The pointer is only valid while `function` is alive.
const void* GetFunctionDataPtr(const Function& function) {
  return function >>
         match{
             [&](const Identity<tOut<Iterator>, tIn<Iterator>>& identity)
                 -> const void* { return &identity.tuple(); },
             [&](const Identity<tOut<Index>, tIn<Index>>& identity)
                 -> const void* { return &identity.tuple(); },
             [&](const IndexDot<List<Dim>, tOut<Index>, tIn<List<Iterator>>>&
                     dot) -> const void* { return &dot.tuple(); },
             [&](const GetBroadcastedIterator<Dim,
                                              tOut<Iterator>,
                                              tIn<Iterator>>& broadcast)
                 -> const void* { return &broadcast.tuple(); },
             [&](const IndexUnDot<List<Dim>, tOut<List<Iterator>>, tIn<Index>>&
                     undot) -> const void* { return &undot.tuple(); },
             [&](const InMsg2OutMsg<tOut<FakeOpPlaceHolder>,
                                    tOut<OpArgIndexes<std::optional<Index>>>,
                                    tIn<OpArgIndexes<Index>>>& in_msg2out_msg)
                 -> const void* { return &in_msg2out_msg.tuple(); },
             [&](const ConstantFunction<tOut<Iterator>, tIn<Index>>&
                     constant_function) -> const void* {
               return &constant_function.tuple();
             },
         };
}
} // namespace cinn::adt
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <functional>
#include <string>
#include <type_traits>
#include "paddle/cinn/adt/adt.h"
#include "paddle/cinn/adt/equation_constant.h"
#include "paddle/cinn/adt/equation_variable.h"
#include "paddle/cinn/adt/tags.h"
#include "paddle/cinn/common/equation_graph_topo_walker.h"
namespace cinn::adt {
template <typename OutT, typename InT>
struct Identity;
// Identity (tOut Iterator) (tIn Iterator)
// Equation relating one output iterator to one input iterator.
template <>
struct Identity<tOut<Iterator>, tIn<Iterator>>
    : public Tuple<tOut<Iterator>, tIn<Iterator>> {
  using Tuple<tOut<Iterator>, tIn<Iterator>>::Tuple;
};
// Identity (tOut Index) (tIn Index)
// Equation relating one output index to one input index.
template <>
struct Identity<tOut<Index>, tIn<Index>>
    : public Tuple<tOut<Index>, tIn<Index>> {
  using Tuple<tOut<Index>, tIn<Index>>::Tuple;
};
template <typename DimT, typename OutT, typename InT>
struct IndexDot;
// IndexDot [Dim] (tOut Index) (tIn [Iterator])
// Equation producing one output index from a list of input iterators and
// their dims.
template <>
struct IndexDot<List<Dim>, tOut<Index>, tIn<List<Iterator>>>
    : public Tuple<List<Dim>, tOut<Index>, tIn<List<Iterator>>> {
  using Tuple<List<Dim>, tOut<Index>, tIn<List<Iterator>>>::Tuple;
};
template <typename DimT, typename OutT, typename InT>
struct IndexUnDot;
// IndexUnDot [Dim] (tOut [Iterator]) (tIn Index)
// Inverse of IndexDot: produces a list of output iterators from one input
// index and the dims.
template <>
struct IndexUnDot<List<Dim>, tOut<List<Iterator>>, tIn<Index>>
    : public Tuple<List<Dim>, tOut<List<Iterator>>, tIn<Index>> {
  using Tuple<List<Dim>, tOut<List<Iterator>>, tIn<Index>>::Tuple;
};
// OpArgIndexes = (tIn [Index], tOut [Index])
// Pairs an op's input-argument indexes with its output-argument indexes;
// OutIndexT may be std::optional<Index> when an out index can be absent.
template <typename OutIndexT>
struct OpArgIndexes final
    : public Tuple<tIn<List<Index>>, tOut<List<OutIndexT>>> {
  using Tuple<tIn<List<Index>>, tOut<List<OutIndexT>>>::Tuple;
};
template <typename FakeOpT, typename OutT, typename InT>
struct InMsg2OutMsg;
// InMsg2OutMsg (tOut FakeOpPlaceHolder) (tOut (tOutMsg OpArgIndexes))
//              (tIn (tInMsg OpArgIndexes))
template <>
struct InMsg2OutMsg<tOut<FakeOpPlaceHolder>,
                    tOut<OpArgIndexes<std::optional<Index>>>,
                    tIn<OpArgIndexes<Index>>>
    : public Tuple<tOut<FakeOpPlaceHolder>,
                   tOut<OpArgIndexes<std::optional<Index>>>,
                   tIn<OpArgIndexes<Index>>> {
  using Tuple<tOut<FakeOpPlaceHolder>,
              tOut<OpArgIndexes<std::optional<Index>>>,
              tIn<OpArgIndexes<Index>>>::Tuple;
};
template <typename T0, typename T1>
struct ConstantFunction;
// Equation assigning a Constant to the output iterator, keyed by an input
// index.
template <>
struct ConstantFunction<tOut<Iterator>, tIn<Index>> final
    : public Tuple<tOut<Iterator>, tIn<Index>, Constant> {
  using Tuple<tOut<Iterator>, tIn<Index>, Constant>::Tuple;
};
template <typename DimT, typename OutT, typename InT>
struct GetBroadcastedIterator;
// Equation relating an input iterator to its broadcasted output iterator for
// a given Dim.
template <>
struct GetBroadcastedIterator<Dim, tOut<Iterator>, tIn<Iterator>>
    : public Tuple<Dim, tOut<Iterator>, tIn<Iterator>> {
  using Tuple<Dim, tOut<Iterator>, tIn<Iterator>>::Tuple;
};
// Equation is the closed union of all equation node kinds defined above.
// clang-format off
DEFINE_ADT_UNION(Equation,
                 Identity<tOut<Iterator>, tIn<Iterator>>,
                 Identity<tOut<Index>, tIn<Index>>,
                 GetBroadcastedIterator<Dim, tOut<Iterator>, tIn<Iterator>>,
                 IndexDot<List<Dim>, tOut<Index>, tIn<List<Iterator>>>,
                 IndexUnDot<List<Dim>, tOut<List<Iterator>>, tIn<Index>>,
                 InMsg2OutMsg<tOut<FakeOpPlaceHolder>,
                              tOut<OpArgIndexes<std::optional<Index>>>,
                              tIn<OpArgIndexes<Index>>>,
                 ConstantFunction<tOut<Iterator>, tIn<Index>>);
// clang-format on
// Function = Equation
using Function = Equation;
using Equations = List<Equation>;
// Topological walker over the Variable/Equation bipartite graph.
using GraphView = EquationGraphTopoWalker<Variable, const Equation*>;
// Collects the Variables a Function reads (tIn) and writes (tOut).
std::pair<std::unordered_set<Variable> /*input*/,
          std::unordered_set<Variable> /*output*/>
CollectInputAndOutputVariables(const Function& function);
// Returns a human-readable name for the Function's active variant.
std::string GetFunctionTypeName(const Function& function);
// Returns the address of the active variant's tuple as an opaque pointer.
const void* GetFunctionDataPtr(const Function& function);
} // namespace cinn::adt
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment