Commit 01a10755 authored by yuguo-Jack's avatar yuguo-Jack
Browse files

2.5.2-dtk24.04

parent 63eb0da5
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <string>
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace py = pybind11;
namespace cinn::pybind {
// Exposes cinn::ir::IRSchedule and its scheduling primitives (fuse, split,
// compute_at, cache_read/write, bind, sampling, ...) to Python as the
// `IRSchedule` class. Keyword names and defaults mirror the C++ signatures.
void BindSchedule(py::module *m) {
  py::class_<ir::IRSchedule> ir_schedule(*m, "IRSchedule");
  ir_schedule
      .def(py::init<const ir::ModuleExpr &,
                    utils::LinearRandomEngine::StateType,
                    bool,
                    utils::ErrorMessageLevel>(),
           py::arg("modexpr"),
           py::arg("rand_seed") = -1,
           py::arg("debug_flag") = false,
           py::arg("err_msg_level") = utils::ErrorMessageLevel::kGeneral)
      .def_static(
          "make",
          [](ir::LoweredFunc &ir_func) {
            // Build the ModuleExpr as a local: IRSchedule takes it by const
            // reference and keeps its own copy, so the previous
            // `new ir::ModuleExpr(...)` heap allocation was never freed
            // (memory leak on every call).
            ir::ModuleExpr module_expr({ir_func->body});
            return std::make_unique<ir::IRSchedule>(module_expr);
          })
      .def("fuse",
           py::overload_cast<const std::vector<Expr> &>(&ir::IRSchedule::Fuse))
      .def("split",
           py::overload_cast<const Expr &, const std::vector<int> &>(
               &ir::IRSchedule::Split),
           py::arg("loop"),
           py::arg("factors"))
      .def("compute_at",
           py::overload_cast<const Expr &, const Expr &, bool>(
               &ir::IRSchedule::ComputeAt),
           py::arg("block"),
           py::arg("loop"),
           py::arg("keep_unit_loops") = false)
      .def("simple_compute_at",
           py::overload_cast<const Expr &, const Expr &>(
               &ir::IRSchedule::SimpleComputeAt),
           py::arg("block"),
           py::arg("loop"))
      .def("reverse_compute_at",
           py::overload_cast<const Expr &, const Expr &, bool>(
               &ir::IRSchedule::ReverseComputeAt),
           py::arg("block"),
           py::arg("loop"),
           py::arg("keep_unit_loops") = false)
      .def("cache_read",
           py::overload_cast<const Expr &, int, const std::string &>(
               &ir::IRSchedule::CacheRead))
      .def("cache_write",
           py::overload_cast<const Expr &, int, const std::string &>(
               &ir::IRSchedule::CacheWrite))
      .def("sync_threads",
           py::overload_cast<const Expr &, bool>(&ir::IRSchedule::SyncThreads),
           py::arg("ir_node"),
           py::arg("after_node") = true)
      .def("set_buffer",
           py::overload_cast<Expr &, const std::string &, bool>(
               &ir::IRSchedule::SetBuffer),
           py::arg("block"),
           py::arg("memory_type"),
           py::arg("fixed") = false)
      .def("reorder",
           py::overload_cast<const std::vector<Expr> &>(
               &ir::IRSchedule::Reorder))
      .def("parallel",
           py::overload_cast<const Expr &>(&ir::IRSchedule::Parallel))
      .def("vectorize",
           py::overload_cast<const Expr &, int>(&ir::IRSchedule::Vectorize))
      .def("unroll", py::overload_cast<const Expr &>(&ir::IRSchedule::Unroll))
      .def("compute_inline",
           py::overload_cast<const Expr &>(&ir::IRSchedule::ComputeInline))
      .def("reverse_compute_inline",
           py::overload_cast<const Expr &>(
               &ir::IRSchedule::ReverseComputeInline))
      .def("bind", &ir::IRSchedule::Bind)
      .def("copy_transform_and_loop_info",
           py::overload_cast<const Expr &, const Expr &>(
               &ir::IRSchedule::CopyTransformAndLoopInfo))
      .def("rfactor",
           py::overload_cast<const Expr &, int>(&ir::IRSchedule::Rfactor))
      .def("annotate",
           py::overload_cast<const Expr &,
                             const std::string &,
                             const ir::attr_t &>(&ir::IRSchedule::Annotate))
      .def("unannotate",
           py::overload_cast<Expr &, const std::string &>(
               &ir::IRSchedule::Unannotate))
      .def("flatten_loops",
           py::overload_cast<const std::vector<Expr> &, const bool>(
               &ir::IRSchedule::FlattenLoops),
           py::arg("loops"),
           py::arg("force_flat") = false)
      .def("sample_perfect_tile",
           py::overload_cast<const Expr &, int, int, const std::vector<int> &>(
               &ir::IRSchedule::SamplePerfectTile),
           py::arg("loop"),
           py::arg("n"),
           py::arg("max_innermost_factor"),
           py::arg("decision") = std::vector<int>())
      .def("sample_categorical",
           py::overload_cast<const std::vector<int> &,
                             const std::vector<float> &,
                             const std::vector<int> &>(
               &ir::IRSchedule::SampleCategorical),
           py::arg("candidates"),
           py::arg("probs"),
           py::arg("decision") = std::vector<int>())
      .def("get_module",
           py::overload_cast<>(&ir::IRSchedule::GetModule, py::const_))
      .def("get_root_block", &ir::IRSchedule::GetRootBlock)
      .def("get_block",
           py::overload_cast<const std::string &>(&ir::IRSchedule::GetBlock,
                                                  py::const_))
      .def("get_all_blocks",
           py::overload_cast<>(&ir::IRSchedule::GetAllBlocks, py::const_))
      .def("get_loops",
           py::overload_cast<const std::string &>(&ir::IRSchedule::GetLoops,
                                                  py::const_))
      .def("get_name2loops_dict",
           [](const ir::IRSchedule &self, const std::string &block_name) {
             // Map each loop's induction-variable name to its loop Expr so
             // Python callers can address loops of a block by name.
             std::vector<ir::Expr> loops = self.GetLoops(block_name);
             std::map<std::string, ir::Expr> name2loops;
             for (const ir::Expr &loop : loops) {
               name2loops[loop.As<ir::For>()->loop_var->name] = loop;
             }
             return name2loops;
           });
}
} // namespace cinn::pybind
......@@ -13,7 +13,9 @@
// limitations under the License.
#include "paddle/cinn/pybind/bind.h"
#include "paddle/cinn/utils/error.h"
#include "paddle/cinn/utils/profiler.h"
#include "paddle/cinn/utils/random_engine.h"
namespace py = pybind11;
......@@ -69,6 +71,9 @@ void BindUtils(py::module *m) {
"type",
[](HostEvent &self) -> const EventType & { return self.type_; },
[](HostEvent &self, const EventType &v) { self.type_ = v; });
py::class_<utils::LinearRandomEngine>(*m, "LinearRandomEngine");
py::class_<utils::ErrorMessageLevel>(*m, "ErrorMessageLevel");
}
} // namespace pybind
......
......@@ -128,7 +128,8 @@ typedef enum cinn_device_kind_t {
cinn_unk_device = -1, // Undefined device.
cinn_x86_device = 0, // X86 device
cinn_opencl_device = 1, // OpenCL device
cinn_arm_device = 2 // ARM device
cinn_arm_device = 2, // ARM device
cinn_nvgpu_device = 3 // NVIDIA GPU device
} cinn_device_kind_t;
//! Help to tell where the buffer locates.
......
......@@ -474,11 +474,11 @@ __device__ inline bool cinn_any(const bool left, const bool right) { return left
tmp_val = __shfl_sync(mask, tmp_val, 0, 32); \
return tmp_val; \
} else { \
tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_down_sync(mask, tmp_val, 16, 32)); \
tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_down_sync(mask, tmp_val, 8, 32)); \
tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_down_sync(mask, tmp_val, 4, 32)); \
tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_down_sync(mask, tmp_val, 2, 32)); \
tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_down_sync(mask, tmp_val, 1, 32)); \
tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 16, 32)); \
tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 8, 32)); \
tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 4, 32)); \
tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 2, 32)); \
tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 1, 32)); \
return tmp_val; \
} \
}
......@@ -530,25 +530,22 @@ __device__ inline float cinn_warp_reduce_avg_fp32(const float *buf, int offset,
// Block-wide reduction using a macro-local __shared__ staging array.
//   1. Each warp reduces its lanes via cinn_warp_shuffle_internal.
//   2. Blocks of <= 32 threads are done after step 1.
//   3. Otherwise lane 0 of each warp publishes its partial into tmp[], and
//      warp 0 reduces the per-warp partials; tmp[0] holds the result.
// NOTE: expands as the body of a __device__ function and ends with a
// return statement, which is why it is a statement macro.
#define CINN_BLOCK_REDUCE_INTERNAL_IMPL(TYPE, value, init_value, cinn_warp_shuffle_internal) \
  int warp_id = threadIdx.x / 32;                                                            \
  TYPE tmp_val = cinn_warp_shuffle_internal(value);                                          \
  if (blockDim.x <= 32) {                                                                    \
    return tmp_val;                                                                          \
  }                                                                                          \
  __shared__ TYPE tmp[32];                                                                   \
  if (warp_id == 0) {                                                                        \
    tmp[threadIdx.x] = init_value;                                                           \
  }                                                                                          \
  __syncthreads();                                                                           \
  if ((threadIdx.x & 31) == 0) {                                                             \
    tmp[warp_id] = tmp_val;                                                                  \
  }                                                                                          \
  __syncthreads();                                                                           \
  if (warp_id == 0) {                                                                        \
    tmp_val = tmp[threadIdx.x];                                                              \
    tmp[threadIdx.x] = cinn_warp_shuffle_internal(tmp_val);                                  \
  }                                                                                          \
  __syncthreads();                                                                           \
  return tmp[0];
......@@ -575,13 +572,57 @@ EXPAND_REDUCE_FP16_MACRO(CINN_BLOCK_REDUCE_INTERNAL_MACRO)
#undef CINN_BLOCK_REDUCE_INTERNAL_IMPL
#undef CINN_BLOCK_REDUCE_INTERNAL_MACRO
// Block-wide reduction that stages per-warp partials in a caller-provided
// buffer `shm` (must hold at least 32 elements -- one slot per warp)
// instead of a macro-local __shared__ array.
//   1. Each warp reduces its lanes via cinn_warp_shuffle_internal.
//   2. Blocks of <= 32 threads are done after step 1.
//   3. Otherwise lane 0 of each warp publishes its partial into shm, and
//      warp 0 reduces the per-warp partials; shm[0] holds the result.
// NOTE: expands as the body of a __device__ function and ends with a
// return statement, which is why it is a statement macro rather than a
// callable helper.
#define CINN_BLOCK_REDUCE_INTERNAL_SHM_IMPL(TYPE, value, init_value, cinn_warp_shuffle_internal) \
  int warp_id = threadIdx.x / 32; \
  TYPE tmp_val = cinn_warp_shuffle_internal(value); \
  if (blockDim.x <= 32) { \
    return tmp_val; \
  } \
  if (warp_id == 0) { \
    shm[threadIdx.x] = init_value; \
  } \
  __syncthreads(); \
  if ((threadIdx.x & 31) == 0) { \
    shm[warp_id] = tmp_val; \
  } \
  __syncthreads(); \
  if (warp_id == 0) { \
    tmp_val = shm[threadIdx.x]; \
    shm[threadIdx.x] = cinn_warp_shuffle_internal(tmp_val); \
  } \
  __syncthreads(); \
  return shm[0];

// Stamps out cinn_block_reduce_<REDUCE_TYPE>_internal_shm for one DTYPE,
// delegating to the statement macro above; `shm` is supplied by the caller.
#define CINN_BLOCK_REDUCE_INTERNAL_SHM_MACRO(REDUCE_TYPE, INITIAL_VALUE, DTYPE) \
  __device__ inline DTYPE cinn_block_reduce_##REDUCE_TYPE##_internal_shm(const DTYPE value, DTYPE* shm) { \
    CINN_BLOCK_REDUCE_INTERNAL_SHM_IMPL(DTYPE, value, (DTYPE)(INITIAL_VALUE), cinn_warp_shuffle_##REDUCE_TYPE##_internal); \
  }
EXPAND_REDUCE_INT32_MARCO(CINN_BLOCK_REDUCE_INTERNAL_SHM_MACRO)
EXPAND_REDUCE_INT64_MARCO(CINN_BLOCK_REDUCE_INTERNAL_SHM_MACRO)
EXPAND_REDUCE_FP32_MACRO(CINN_BLOCK_REDUCE_INTERNAL_SHM_MACRO)
EXPAND_REDUCE_FP64_MACRO(CINN_BLOCK_REDUCE_INTERNAL_SHM_MACRO)
EXPAND_REDUCE_BOOL_MACRO(CINN_BLOCK_REDUCE_INTERNAL_SHM_MACRO)
#ifdef CINN_CUDA_BF16
EXPAND_REDUCE_BF16_MACRO(CINN_BLOCK_REDUCE_INTERNAL_SHM_MACRO)
#endif
#ifdef CINN_CUDA_FP16
EXPAND_REDUCE_FP16_MACRO(CINN_BLOCK_REDUCE_INTERNAL_SHM_MACRO)
#endif
#undef CINN_BLOCK_REDUCE_INTERNAL_SHM_IMPL
#undef CINN_BLOCK_REDUCE_INTERNAL_SHM_MACRO
// Reduces buf[offset, offset + extend) across the whole thread block.
// Each thread strides over the buffer accumulating a partial value, then
// the partials are combined block-wide by the shared-memory reducer; `shm`
// provides the 32-slot cross-warp staging area it requires.
#define CINN_BLOCK_REDUCE_IMPL(REDUCE_TYPE, INITIAL_VALUE, DTYPE)                                     \
  __device__ inline DTYPE cinn_block_reduce_##REDUCE_TYPE(const DTYPE *buf, int offset, int extend) { \
    __shared__ DTYPE shm[32];                                                                         \
    DTYPE tmp_val = (DTYPE)(INITIAL_VALUE);                                                           \
    for (int i = threadIdx.x; i < extend; i += blockDim.x) {                                          \
      tmp_val = cinn_##REDUCE_TYPE(tmp_val, buf[offset + i]);                                         \
    }                                                                                                 \
    return cinn_block_reduce_##REDUCE_TYPE##_internal_shm(tmp_val, shm);                              \
  }
EXPAND_REDUCE_INT32_MARCO(CINN_BLOCK_REDUCE_IMPL)
......
......@@ -70,6 +70,27 @@ inline cublasStatus_t cublasGemm(cudaDataType_t dtype,
reinterpret_cast<double *>(C),
ldc);
} else if (dtype == CUDA_R_16F) {
#if CUDA_VERSION >= 11000
return cublasGemmEx(handle,
transa,
transb,
m,
n,
k,
&alpha,
A,
CUDA_R_16F,
lda,
B,
CUDA_R_16F,
ldb,
&beta,
C,
CUDA_R_16F,
ldc,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#else
common::float16 alpha_fp16{alpha};
common::float16 beta_fp16{beta};
return cublasHgemm(handle,
......@@ -86,6 +107,7 @@ inline cublasStatus_t cublasGemm(cudaDataType_t dtype,
reinterpret_cast<const __half *>(&beta_fp16),
reinterpret_cast<__half *>(C),
ldc);
#endif
} else if (dtype == CUDA_R_16BF) {
#if CUDA_VERSION >= 11000
return cublasGemmEx(handle,
......@@ -174,6 +196,31 @@ inline cublasStatus_t cublasGemmStridedBatched(cudaDataType_t dtype,
strideC,
batchCount);
} else if (dtype == CUDA_R_16F) {
#if CUDA_VERSION >= 11000
return cublasGemmStridedBatchedEx(handle,
transa,
transb,
m,
n,
k,
&alpha,
A,
CUDA_R_16F,
lda,
strideA,
B,
CUDA_R_16F,
ldb,
strideB,
&beta,
C,
CUDA_R_16F,
ldc,
strideC,
batchCount,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#else
common::float16 alpha_fp16{alpha};
common::float16 beta_fp16{beta};
return cublasHgemmStridedBatched(
......@@ -195,6 +242,7 @@ inline cublasStatus_t cublasGemmStridedBatched(cudaDataType_t dtype,
ldc,
strideC,
batchCount);
#endif
} else if (dtype == CUDA_R_16BF) {
#if CUDA_VERSION >= 11000
return cublasGemmStridedBatchedEx(handle,
......@@ -279,6 +327,28 @@ inline cublasStatus_t cublasGemmBatched(cudaDataType_t dtype,
ldc,
batchCount);
} else if (dtype == CUDA_R_16F) {
#if CUDA_VERSION >= 11000
return cublasGemmBatchedEx(handle,
transa,
transb,
m,
n,
k,
&alpha,
A,
CUDA_R_16F,
lda,
B,
CUDA_R_16F,
ldb,
&beta,
C,
CUDA_R_16F,
ldc,
batchCount,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#else
__half alpha_fp16{alpha};
__half beta_fp16{beta};
return cublasHgemmBatched(handle,
......@@ -296,6 +366,7 @@ inline cublasStatus_t cublasGemmBatched(cudaDataType_t dtype,
reinterpret_cast<__half **>(C),
ldc,
batchCount);
#endif
} else if (dtype == CUDA_R_16BF) {
#if CUDA_VERSION >= 11000
return cublasGemmBatchedEx(handle,
......
......@@ -110,6 +110,24 @@ CINN_REGISTER_HELPER(cuda_intrinsics_reduce) {
#undef REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL
#define REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL(REDUCE_TYPE, DTYPE) \
REGISTER_FACKED_EXTERN_FUNC_HELPER( \
cinn_block_reduce_##REDUCE_TYPE##_internal_shm, target) \
.SetRetType<DTYPE>() \
.AddInputType<DTYPE>() \
.AddInputType<cinn_buffer_t *>() \
.End();
EXPAND_REDUCE_INT32_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_INT64_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_BF16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_FP16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_FP32_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_FP64_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_BOOL_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
#undef REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL
#define REGISTER_BLOCK_REDUCE_FUNC_IMPL(REDUCE_TYPE, DTYPE) \
REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_block_reduce_##REDUCE_TYPE, target) \
.SetRetType<DTYPE>() \
......
......@@ -162,6 +162,8 @@ void cinn_call_cublas(void *v_args,
int n = trans_o ? (trans_b ? b3 : b4) : (trans_a ? a4 : a3);
int k = trans_a ? a3 : a4;
VLOG(3) << "m: " << m << ", n: " << n << ", k: " << k;
cublasOperation_t trans_op_l = trans_o
? (trans_a ? CUBLAS_OP_N : CUBLAS_OP_T)
: (trans_b ? CUBLAS_OP_T : CUBLAS_OP_N);
......@@ -245,7 +247,7 @@ void cinn_call_cublas(void *v_args,
int batch = std::max(a2, b2);
VLOG(3) << "call cublasGemmStridedBatched with a1*b1 = 1, stride_l = "
<< stride_l << ", stride_r = " << stride_r
<< ", batch = " << batch;
<< ", batch = " << batch << ", dtype = " << cuda_dtype;
cinn::utils::RecordEvent record_run("Call cublasGemmStridedBatched",
cinn::utils::EventType::kInstruction);
CUBLAS_CALL(cublasGemmStridedBatched(cuda_dtype,
......
......@@ -597,9 +597,9 @@ __host__ __device__ inline bool(isfinite)(const float16& a) {
// Absolute value for float16. Uses the native half-precision __habs
// intrinsic on CUDA devices with compute capability >= 5.3; otherwise
// computes fabsf in fp32 and converts back.
__host__ __device__ inline float16(abs)(const float16& a) {
#if defined(CINN_CUDA_FP16) && (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
  return static_cast<float16>(__habs(a.to_half()));
#else
  return static_cast<float16>(fabsf(static_cast<float>(a)));
#endif
}
......
......@@ -22,7 +22,7 @@
#include "paddle/cinn/runtime/flags.h"
#include "paddle/cinn/utils/string.h"
DECLARE_string(cinn_check_fusion_accuracy_pass);
PD_DECLARE_string(cinn_check_fusion_accuracy_pass);
namespace cinn {
namespace runtime {
......
......@@ -14,7 +14,6 @@
#include "paddle/cinn/runtime/flags.h"
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <sys/stat.h>
#include <sys/types.h>
......@@ -23,169 +22,214 @@
#include <unordered_set>
#include "paddle/cinn/common/target.h"
#include "paddle/utils/flags.h"
// CINN runtime flags, defined via paddle::flags (PD_DEFINE_*). Each flag's
// default can be overridden through the matching FLAGS_* environment
// variable using the *FromEnv helpers. Each flag must be defined exactly
// once; keeping the legacy gflags DEFINE_* forms alongside these would be a
// duplicate-flag registration error.
#ifdef CINN_WITH_CUDNN
PD_DEFINE_bool(
    cinn_cudnn_deterministic,
    false,
    "Whether allow using an autotuning algorithm for convolution "
    "operator. The autotuning algorithm may be non-deterministic. If "
    "true, the algorithm is deterministic.");
#endif

using ::paddle::flags::BoolFromEnv;
using ::paddle::flags::DoubleFromEnv;
using ::paddle::flags::Int32FromEnv;
using ::paddle::flags::Int64FromEnv;
using ::paddle::flags::StringFromEnv;

PD_DEFINE_string(cinn_x86_builtin_code_root,
                 StringFromEnv("FLAGS_cinn_x86_builtin_code_root", ""),
                 "");

PD_DEFINE_string(cinn_nvcc_cmd_path,
                 StringFromEnv("FLAGS_cinn_nvcc_cmd_path",
                               "/usr/local/cuda/bin"),
                 "Setting nvcc default path!");

PD_DEFINE_int32(cinn_parallel_compile_thread,
                Int32FromEnv("FLAGS_cinn_parallel_compile_thread",
                             (std::thread::hardware_concurrency() >> 1)),
                "How much thread the parallel compile used.");

PD_DEFINE_bool(cinn_use_op_fusion,
               BoolFromEnv("FLAGS_cinn_use_op_fusion", true),
               "Whether to use op fusion pass.");

PD_DEFINE_bool(general_fusion_merge_pass,
               BoolFromEnv("FLAGS_general_fusion_merge_pass", true),
               "Whether to use general fusion_merge pass.");

PD_DEFINE_bool(cinn_new_group_scheduler,
               BoolFromEnv("FLAGS_cinn_new_group_scheduler", false),
               "Whether to use new group scheduler.");

PD_DEFINE_bool(cinn_bucket_compile,
               BoolFromEnv("FLAGS_cinn_bucket_compile", false),
               "Whether to enable bucket compile for dynamic shape.");

PD_DEFINE_bool(cinn_use_common_subexpression_elimination,
               BoolFromEnv("FLAGS_cinn_use_common_subexpression_elimination",
                           false),
               "Whether to use common subexpression elimination pass.");

PD_DEFINE_string(
    cinn_custom_call_deny_ops,
    StringFromEnv("FLAGS_cinn_custom_call_deny_ops", ""),
    "a blacklist of op are denied by MarkCustomCallOps pass, separated by ;");

PD_DEFINE_bool(cinn_enable_map_expr,
               BoolFromEnv("FLAGS_cinn_enable_map_expr", false),
               "It controls whether to use cinn with map_expr");

PD_DEFINE_bool(cinn_enable_map_expr_schedule,
               BoolFromEnv("FLAGS_cinn_enable_map_expr_schedule", false),
               "It controls whether to schedule by map_expr");

PD_DEFINE_bool(cinn_enable_map_expr_inline,
               BoolFromEnv("FLAGS_cinn_enable_map_expr_inline", false),
               "It controls whether to inline by map_expr");

PD_DEFINE_bool(
    cinn_use_custom_call,
    BoolFromEnv("FLAGS_cinn_use_custom_call", true),
    "Whether to use custom_call for ops with external_api registered");

PD_DEFINE_bool(cinn_use_fill_constant_folding,
               BoolFromEnv("FLAGS_cinn_use_fill_constant_folding", false),
               "Whether use the FillConstantFolding pass.");

PD_DEFINE_string(cinn_check_fusion_accuracy_pass,
                 StringFromEnv("FLAGS_cinn_check_fusion_accuracy_pass", ""),
                 "Check the correct of fusion kernels, if the results not "
                 "satisfied 'allclose(rtol=1e-05f, atol=1e-08f)', "
                 "report error and exited.");

PD_DEFINE_bool(cinn_use_cuda_vectorize,
               BoolFromEnv("FLAGS_cinn_use_cuda_vectorize", false),
               "Whether use cuda vectroize on schedule config");

PD_DEFINE_bool(use_reduce_split_pass,
               BoolFromEnv("FLAGS_use_reduce_split_pass", false),
               "Whether use reduce split pass.");

PD_DEFINE_bool(cinn_use_dense_merge_pass,
               BoolFromEnv("FLAGS_cinn_use_dense_merge_pass", false),
               "Whether use dense merge pass.");

PD_DEFINE_bool(
    nvrtc_compile_to_cubin,
    BoolFromEnv("FLAGS_nvrtc_compile_to_cubin", false),
    "Whether nvrtc compile cuda source into cubin instead of ptx (only "
    "works after cuda-11.1).");

PD_DEFINE_bool(cinn_compile_with_nvrtc,
               BoolFromEnv("FLAGS_cinn_compile_with_nvrtc", true),
               "Whether nvrtc compile cuda source with nvrtc(default nvcc).");

PD_DEFINE_bool(
    cinn_nvrtc_cubin_with_fmad,
    BoolFromEnv("FLAGS_cinn_nvrtc_cubin_with_fmad", true),
    "Whether nvrtc enables fmad when compile to cubin. This flag only works "
    "when FLAGS_nvrtc_compile_to_cubin=true. Fmad is the cuda speed up "
    "technique which contract fp mulitplication and addition/subtraction into "
    "multiply-add operation. It may result in different fp precision.");

// FLAGS for performance analysis and accuracy debug
PD_DEFINE_bool(cinn_sync_run,
               BoolFromEnv("FLAGS_cinn_sync_run", false),
               "Whether sync all devices after each instruction run, which is "
               "used for debug.");

PD_DEFINE_string(
    cinn_self_check_accuracy,
    StringFromEnv("FLAGS_cinn_self_check_accuracy", ""),
    "Whether self-check accuracy after each instruction run, which "
    "is used for debug.");

PD_DEFINE_int64(
    cinn_self_check_accuracy_num,
    Int64FromEnv("FLAGS_cinn_self_check_accuracy_num", 0L),
    "Set self-check accuracy print numel, which is used for debug.");

PD_DEFINE_string(
    cinn_fusion_groups_graphviz_dir,
    StringFromEnv("FLAGS_cinn_fusion_groups_graphviz_dir", ""),
    "Specify the directory path of dot file of graph, which is used "
    "for debug.");

PD_DEFINE_string(
    cinn_source_code_save_path,
    StringFromEnv("FLAGS_cinn_source_code_save_path", ""),
    "Specify the directory path of generated source code, which is "
    "used for debug.");

PD_DEFINE_string(
    cinn_dump_group_lowered_func,
    StringFromEnv("FLAGS_cinn_dump_group_lowered_func", ""),
    "Specify the path for dump lowered functions by group, which is "
    "used for debug.");

PD_DEFINE_string(
    cinn_dump_group_source_code,
    StringFromEnv("FLAGS_cinn_dump_group_source_code", ""),
    "Specify the path for dump source code by group, which is used for debug.");

PD_DEFINE_string(
    cinn_dump_group_ptx,
    StringFromEnv("FLAGS_cinn_dump_group_ptx", ""),
    "Specify the path for dump ptx by group, which is used for debug.");

PD_DEFINE_string(
    cinn_dump_group_instruction,
    StringFromEnv("FLAGS_cinn_dump_group_instruction", ""),
    "Specify the path for dump instruction by group, which is used for debug.");

PD_DEFINE_string(cinn_pass_visualize_dir,
                 StringFromEnv("FLAGS_cinn_pass_visualize_dir", ""),
                 "Specify the directory path of pass visualize file of graph, "
                 "which is used for debug.");

PD_DEFINE_bool(enable_auto_tuner,
               BoolFromEnv("FLAGS_enable_auto_tuner", false),
               "Whether enable auto tuner.");

PD_DEFINE_bool(auto_schedule_use_cost_model,
               BoolFromEnv("FLAGS_auto_schedule_use_cost_model", true),
               "Whether to use cost model in auto schedule, this is an "
               "on-developing flag and it will be removed when "
               "cost model is stable.");

PD_DEFINE_bool(
    enhance_vertical_fusion_with_recompute,
    BoolFromEnv("FLAGS_enhance_vertical_fusion_with_recompute", true),
    "Whether to enhance check logic on vertical fusion with recompute");

PD_DEFINE_bool(verbose_function_register,
               BoolFromEnv("FLAGS_verbose_function_register", false),
               "Whether to verbose function regist log. This will only work if "
               "CINN build with flag -DWITH_DEBUG=ON.");

PD_DEFINE_int32(
    cinn_profiler_state,
    Int32FromEnv("FLAGS_cinn_profiler_state", -1),
    "Specify the ProfilerState by Int in CINN, 0 for kDisabled, 1 for "
    "kCPU, 2 for kCUDA, 3 for kAll, default 0.");

PD_DEFINE_int32(cinn_error_message_level,
                Int32FromEnv("FLAGS_cinn_error_message_level", 0),
                "Specify the level of printing error message in the schedule."
                "0 means short, 1 means detailed.");

PD_DEFINE_double(cinn_infer_model_version,
                 DoubleFromEnv("FLAGS_cinn_infer_model_version", 2.0),
                 "Paddle has different model format in inference model. We use "
                 "a flag to load different versions.");

PD_DEFINE_bool(cinn_use_cutlass,
               BoolFromEnv("FLAGS_cinn_use_cutlass", false),
               "Whether to use cutlass kernels");
namespace cinn {
namespace runtime {
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include "paddle/cinn/common/type.h"
#include "paddle/cinn/utils/type_defs.h"
#include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/phi/common/data_type.h"
namespace cinn {
namespace utils {
using NewIR_AttributeMap = std::unordered_map<std::string, ::ir::Attribute>;
// Converts a single new-IR attribute into CINN's Attribute variant.
// Supported kinds: bool/float/int32/int64/double scalars, strings,
// IntArray (narrowed to std::vector<int>), and DataType (converted to its
// string name). Any other attribute kind is a fatal error.
// `inline` is required: this function is defined in a #pragma once header
// and would otherwise violate the ODR when the header is included from
// multiple translation units.
inline Attribute ConvertAttribute(const ::ir::Attribute& src_attr) {
  Attribute dst_attr;
  if (src_attr.isa<::ir::BoolAttribute>()) {
    dst_attr = src_attr.dyn_cast<::ir::BoolAttribute>().data();
  } else if (src_attr.isa<::ir::FloatAttribute>()) {
    dst_attr = src_attr.dyn_cast<::ir::FloatAttribute>().data();
  } else if (src_attr.isa<::ir::Int32Attribute>()) {
    dst_attr = src_attr.dyn_cast<::ir::Int32Attribute>().data();
  } else if (src_attr.isa<::ir::StrAttribute>()) {
    dst_attr = src_attr.dyn_cast<::ir::StrAttribute>().AsString();
  } else if (src_attr.isa<::ir::Int64Attribute>()) {
    dst_attr = src_attr.dyn_cast<::ir::Int64Attribute>().data();
  } else if (src_attr.isa<::ir::DoubleAttribute>()) {
    dst_attr = src_attr.dyn_cast<::ir::DoubleAttribute>().data();
  } else if (src_attr.isa<paddle::dialect::IntArrayAttribute>()) {
    // NOTE: int64 elements are narrowed to int here — presumably callers
    // only pass shapes/axes that fit in 32 bits; TODO confirm.
    auto& arr = src_attr.dyn_cast<paddle::dialect::IntArrayAttribute>()
                    .data()
                    .GetData();
    std::vector<int> val(arr.begin(), arr.end());
    dst_attr = val;
  } else if (src_attr.isa<paddle::dialect::DataTypeAttribute>()) {
    auto dtype = src_attr.dyn_cast<paddle::dialect::DataTypeAttribute>().data();
    dst_attr = phi::DataTypeToString(dtype);
  } else {
    LOG(FATAL) << "unknown Attribute: " << src_attr;
  }
  return dst_attr;
}
// Converts a map of new-IR attributes into CINN's AttributeMap.
// PlaceAttribute is special-cased: it is translated into the boolean
// "force_cpu" entry (true iff the place is CPUPlace) instead of being
// converted verbatim; everything else goes through ConvertAttribute.
// `inline` avoids ODR violations, as this is a #pragma once header.
inline AttributeMap ConvertAttributes(const NewIR_AttributeMap& src_attrs) {
  AttributeMap dst_attrs;
  for (const auto& item : src_attrs) {
    VLOG(4) << "deal with " << item.first;
    if (item.second.isa<paddle::dialect::PlaceAttribute>()) {
      auto is_cpu =
          item.second.dyn_cast<paddle::dialect::PlaceAttribute>().data() ==
          phi::CPUPlace();
      dst_attrs["force_cpu"] = is_cpu;
    } else {
      // ConvertAttribute returns a prvalue, so the former std::move here
      // was redundant (clang-tidy: performance-move-const-arg).
      dst_attrs[item.first] = ConvertAttribute(item.second);
    }
  }
  VLOG(4) << "dst_attrs.size(): " << dst_attrs.size();
  return dst_attrs;
}
// Expands to an `else if` branch mapping ::ir::<src> to common::<dst>();
// relies on the preceding `if` in ConvertIRType to chain correctly.
#define CASE_TYPE(src, dst) \
  else if (type.isa<::ir::src>()) return common::dst();
// Maps a new-IR scalar type to the corresponding CINN common::Type.
// IndexType is treated as I32 and BoolType as UI1. An unsupported type is
// a fatal error (LOG(FATAL) aborts, so no fallthrough return is needed).
// `inline` avoids ODR violations, as this is a #pragma once header.
inline common::Type ConvertIRType(::ir::Type type) {
  if (type.isa<::ir::BFloat16Type>()) return common::BF16();
  CASE_TYPE(Float16Type, F16)
  CASE_TYPE(Float32Type, F32)
  CASE_TYPE(Float64Type, F64)
  CASE_TYPE(Int8Type, I8)
  CASE_TYPE(UInt8Type, UI8)
  CASE_TYPE(Int16Type, I16)
  CASE_TYPE(Int32Type, I32)
  CASE_TYPE(Int64Type, I64)
  CASE_TYPE(IndexType, I32)
  CASE_TYPE(BoolType, UI1)
  LOG(FATAL) << "unknown ir::Type " << type;
}
// Don't leak the helper macro out of this header.
#undef CASE_TYPE
} // namespace utils
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace cinn {
namespace utils {
#include <array>
#include <string>
#include <string_view>
#include <utility>
// Extracts the unqualified name of enum value V from __PRETTY_FUNCTION__
// (GCC/Clang-specific; not portable to MSVC). The `sizeof(...) - 2` length
// trims the trailing "]" and NUL, the last space-separated token is kept,
// and a token starting with '(' means V is not a named enumerator (e.g. a
// cast of an out-of-range integer) — an empty view is returned for it.
template <typename E, E V>
constexpr auto PrettyName() {
  std::string_view name{__PRETTY_FUNCTION__, sizeof(__PRETTY_FUNCTION__) - 2};
  name.remove_prefix(name.find_last_of(" ") + 1);
  if (name.front() == '(') name.remove_prefix(name.size());
  return name;
}
// True iff V is an actual enumerator of E, i.e. the compiler can print a
// symbolic name for it rather than a numeric cast.
template <typename E, E V>
constexpr bool IsValidEnum() {
  constexpr auto printed = PrettyName<E, V>();
  return printed.size() != 0;
}
// Identity transform over an integer_sequence; kept as an explicit hook so
// the candidate-value sampling can be changed in one place later.
template <int... Seq>
constexpr auto MakeIntegerSequence(std::integer_sequence<int, Seq...>) {
  return std::integer_sequence<int, (Seq)...>();
}
// Candidate enum values probed by the reflection helpers below: 0..31.
// `inline` (C++17) so the header defines a single entity instead of one
// internal-linkage copy per translation unit.
inline constexpr auto NormalIntegerSequence =
    MakeIntegerSequence(std::make_integer_sequence<int, 32>());
// Counts how many of the probed candidate values Seq... are real
// enumerators of E. The result is a compile-time constant.
template <typename E, int... Seq>
constexpr size_t GetEnumSize(std::integer_sequence<int, Seq...>) {
  // One flag per candidate value, all evaluated at compile time.
  constexpr std::array<bool, sizeof...(Seq)> valid{
      IsValidEnum<E, static_cast<E>(Seq)>()...};
  // Immediately-invoked constexpr lambda (note the trailing `(valid)` call):
  // counting inside a lambda keeps `count` a constant expression.
  constexpr std::size_t count =
      [](decltype((valid)) v) constexpr noexcept->std::size_t {
    auto cnt = std::size_t{0};
    for (auto b : v) {
      if (b) {
        ++cnt;
      }
    }
    return cnt;
  }
  (valid);
  return count;
}
// Returns the integer values of all real enumerators of E among Seq...,
// in probe order, packed into an array sized exactly to the number of valid
// enumerators (so it pairs index-for-index with GetAllValidNames).
template <typename E, int... Seq>
constexpr auto GetAllValidValues(std::integer_sequence<int, Seq...>) {
  constexpr std::size_t count = sizeof...(Seq);
  // valid[i] tells whether candidate seq[i] names a real enumerator.
  constexpr std::array<bool, count> valid{
      IsValidEnum<E, static_cast<E>(Seq)>()...};
  constexpr std::array<int, count> seq{Seq...};
  // Compact the valid candidates into the front of the result array.
  std::array<int, GetEnumSize<E>(NormalIntegerSequence)> values{};
  for (std::size_t i = 0, v = 0; i < count; ++i) {
    if (valid[i]) {
      values[v++] = seq[i];
    }
  }
  return values;
}
// Returns the symbolic names of all real enumerators of E among Seq...,
// in probe order, sized exactly to the number of valid enumerators
// (index-for-index parallel to GetAllValidValues).
template <typename E, int... Seq>
constexpr auto GetAllValidNames(std::integer_sequence<int, Seq...>) {
  // An empty view means the candidate is not a real enumerator.
  constexpr std::array<std::string_view, sizeof...(Seq)> names{
      PrettyName<E, static_cast<E>(Seq)>()...};
  // Compact the non-empty names into the front of the result array.
  std::array<std::string_view, GetEnumSize<E>(NormalIntegerSequence)>
      valid_names{};
  for (std::size_t i = 0, v = 0; i < names.size(); ++i) {
    if (!names[i].empty()) {
      valid_names[v++] = names[i];
    }
  }
  return valid_names;
}
// Returns the symbolic name of enum value V (e.g. "kGeneral"), looked up in
// compile-time tables built from the candidate range 0..31.
//
// Bug fix: the previous fallback returned
// std::to_string(static_cast<int>(V)), i.e. a std::string_view into a
// temporary std::string destroyed on return — any use of the result was
// undefined behavior. Values outside the known enumerators now map to the
// static literal "Unknown", which never dangles.
template <typename E>
constexpr std::string_view Enum2String(E V) {
  constexpr auto names = GetAllValidNames<E>(NormalIntegerSequence);
  constexpr auto values = GetAllValidValues<E>(NormalIntegerSequence);
  constexpr auto size = GetEnumSize<E>(NormalIntegerSequence);
  for (size_t i = 0; i < size; ++i) {
    if (static_cast<int>(V) == values[i]) {
      return names[i];
    }
  }
  // String literals have static storage duration, so this view stays valid.
  return "Unknown";
}
} // namespace utils
} // namespace cinn
......@@ -24,7 +24,7 @@ std::string ErrorHandler::FormatErrorMessage(
? DetailedErrorMessage()
: GeneralErrorMessage();
os << "[Error info] " << err_msg;
os << err_msg;
return os.str();
}
......
......@@ -113,9 +113,6 @@ struct EnforceNotMet : public std::exception {
std::string err_str_;
};
#ifdef PADDLE_THROW
#define CINN_THROW PADDLE_THROW
#else
#define CINN_THROW(...) \
do { \
try { \
......@@ -125,7 +122,6 @@ struct EnforceNotMet : public std::exception {
throw; \
} \
} while (0)
#endif
} // namespace enforce
/**
......@@ -165,6 +161,9 @@ class ErrorHandler {
* \brief Format the error message.
*/
std::string FormatErrorMessage(const ErrorMessageLevel& err_msg_level) const;
protected:
const std::string indent_str_{" "};
};
} // namespace utils
......
......@@ -14,7 +14,7 @@
#include "paddle/cinn/utils/profiler.h"
#include <gflags/gflags.h>
#include "paddle/utils/flags.h"
#ifdef CINN_WITH_NVTX
#include <nvToolsExt.h>
......@@ -27,7 +27,7 @@
#endif
#include <chrono>
DECLARE_int32(cinn_profiler_state);
PD_DECLARE_int32(cinn_profiler_state);
namespace cinn {
namespace utils {
......
# Build the `common` shared library from every C++ source in this directory.
file(GLOB common_srcs "*.cc")
# Platform-dependent file name of the produced shared library.
if(WIN32)
  set(COMMON_NAME
      common.dll
      CACHE INTERNAL "" FORCE)
elseif(APPLE)
  set(COMMON_NAME
      libcommon.dylib
      CACHE INTERNAL "" FORCE)
else()
  set(COMMON_NAME
      libcommon.so
      CACHE INTERNAL "" FORCE)
endif()
# Absolute path to the built artifact, exported for other CMake modules.
set(COMMON_LIB
    "${CMAKE_CURRENT_BINARY_DIR}/${COMMON_NAME}"
    CACHE FILEPATH "COMMON Library" FORCE)
# Always build `common` as a shared library.
set(COMMON_BUILD_TYPE
    SHARED
    CACHE INTERNAL "" FORCE)
cc_library(common ${COMMON_BUILD_TYPE} SRCS ${common_srcs})
# On Windows, export all symbols so no explicit dllexport annotations are
# required.
if(WIN32)
  set_property(TARGET common PROPERTY WINDOWS_EXPORT_ALL_SYMBOLS ON)
endif()
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include "paddle/common/enforce.h"
#include "paddle/common/unroll_array_ops.h"
namespace common {
// Fixed-size array usable from both host and device code (HOSTDEVICE).
// Element assignment, fill, and comparison go through the Unroll* helpers so
// they are unrolled at compile time.
template <typename T, size_t N>
class Array {
 public:
  // Number of elements, exposed as a compile-time constant.
  static constexpr size_t kSize = N;
  HOSTDEVICE inline Array() {}
  // Constructs from exactly N values (enforced by the static_assert).
  template <typename... Args>
  HOSTDEVICE inline explicit Array(const T &val, Args... args) {
    static_assert(N == sizeof...(Args) + 1, "Invalid argument");
    UnrollVarArgsAssign<T>::Run(data_, val, args...);
  }
  // Assigns `val` to every element.
  HOSTDEVICE inline void Fill(const T &val) {
    UnrollFillConstant<N>::Run(data_, val);
  }
  // Raw pointers to the underlying storage.
  HOSTDEVICE inline const T *Get() const { return data_; }
  HOSTDEVICE inline T *GetMutable() { return data_; }
  // Unchecked element access; see at() for the bounds-checked variant.
  HOSTDEVICE inline T &operator[](size_t i) { return *advance(data_, i); }
  // Writing "return data_[i]" would cause compilation warning/error:
  // "array subscript is above array bound" in Python 35 CI.
  // It seems that it is a false warning of GCC if we do not check the bounds
  // of array index. But for better performance, we do not check in operator[]
  // like what is in STL. If users want to check the bounds, use at() instead
  HOSTDEVICE inline const T &operator[](size_t i) const {
    return *advance(data_, i);
  }
  // Bounds-checked access; the check is compiled out in CUDA/HIP device code.
  HOSTDEVICE inline T &at(size_t i) {
#if !defined(__CUDA_ARCH__) && !defined(__HIPCC__)
    COMMON_ENFORCE_LT(
        i, N, common::errors::OutOfRange("Array index out of bounds."));
#endif
    return (*this)[i];
  }
  HOSTDEVICE inline const T &at(size_t i) const {
#if !defined(__CUDA_ARCH__) && !defined(__HIPCC__)
    COMMON_ENFORCE_LT(
        i, N, common::errors::OutOfRange("Array index out of bounds."));
#endif
    return (*this)[i];
  }
  HOSTDEVICE constexpr size_t size() const { return N; }
  // Element-wise equality through the unrolled comparator.
  HOSTDEVICE inline bool operator==(const Array<T, N> &other) const {
    return UnrollCompare<N>::Run(data_, other.data_);
  }
  HOSTDEVICE inline bool operator!=(const Array<T, N> &other) const {
    return !(*this == other);
  }
 private:
  // Indirection through advance() sidesteps the GCC bounds warning noted on
  // operator[] above.
  template <typename U>
  HOSTDEVICE static inline U *advance(U *ptr, size_t i) {
    return ptr + i;
  }
  // Value-initialized storage.
  T data_[N] = {};
};
// Specialization for the empty array: no storage, and any element access is
// an error (host) or a dummy static object (device, where throwing is not
// available).
template <typename T>
class Array<T, 0> {
 public:
  static constexpr size_t kSize = 0;
  HOSTDEVICE inline Array() {}
  // No elements, so filling is a no-op.
  HOSTDEVICE inline void Fill(const T &val) {}
  HOSTDEVICE inline constexpr T *Get() const { return nullptr; }
  // Add constexpr to GetMutable() cause warning in MAC
  HOSTDEVICE inline T *GetMutable() { return nullptr; }
  HOSTDEVICE inline T &operator[](size_t) {
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
    // HIP and CUDA will have compile error, if use "obj()"
    // function declared in block scope cannot have 'static' storage class
    static T obj{};
    return obj;
#else
    COMMON_THROW(common::errors::Unavailable("Array<T, 0> has no element."));
#endif
  }
  HOSTDEVICE inline const T &operator[](size_t) const {
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
    // HIP and CUDA will have compile error, if use "obj()"
    // function declared in block scope cannot have 'static' storage class
    static const T obj{};
    return obj;
#else
    COMMON_THROW(common::errors::Unavailable("Array<T, 0> has no element."));
#endif
  }
  // at() has no extra bounds check here: every index is out of range, and
  // operator[] already reports that.
  HOSTDEVICE inline T &at(size_t i) { return (*this)[i]; }
  HOSTDEVICE inline const T &at(size_t i) const { return (*this)[i]; }
  HOSTDEVICE constexpr size_t size() const { return 0; }
  // All empty arrays are equal.
  HOSTDEVICE constexpr bool operator==(const Array<T, 0> &other) const {
    return true;
  }
  HOSTDEVICE constexpr bool operator!=(const Array<T, 0> &other) const {
    return false;
  }
};
} // namespace common
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/common/ddim.h"
#include <set>
namespace common {
// Builds a DDim from a braced list of 64-bit dimensions.
DDim make_ddim(std::initializer_list<int64_t> dims) {
  return DDim(dims.begin(), static_cast<int>(dims.size()));
}
// Builds a DDim from a vector of 64-bit dimensions.
DDim make_ddim(const std::vector<int64_t>& dims) {
  return DDim(dims.data(), static_cast<int>(dims.size()));
}
// Builds a DDim from a vector of 32-bit dimensions.
DDim make_ddim(const std::vector<int>& dims) {
  return DDim(dims.data(), static_cast<int>(dims.size()));
}
// Visitor comparing a rank-specific Dim<D> against a raw int64_t array `d_`
// using the compile-time-unrolled element-wise comparator.
struct DDimEqualityVisitor {
  explicit DDimEqualityVisitor(const int64_t* d) : d_(d) {}
  template <int D>
  inline bool operator()(const Dim<D>& self) const {
    return UnrollCompare<D>::Run(self.Get(), d_);
  }
  // Borrowed pointer; must stay valid for the visitor's lifetime.
  const int64_t* d_;
};
// Two DDims are equal when both are unset (rank -1), or when they have the
// same rank and identical dimension values.
bool DDim::operator==(const DDim& d) const {
  const bool lhs_unset = (size() == -1);
  const bool rhs_unset = (d.size() == -1);
  if (lhs_unset || rhs_unset) {
    // Equal only if both sides are unset.
    return lhs_unset && rhs_unset;
  }
  return size() == d.size() &&
         this->apply_visitor(DDimEqualityVisitor(d.Get()));
}
// Inequality is defined as the negation of equality.
bool DDim::operator!=(const DDim& d) const { return !(*this == d); }
// Renders the dimensions as "[d0, d1, ...]"; an unset or rank-0 DDim renders
// as "[]".
std::string DDim::to_str() const {
  std::stringstream out;
  out << '[';
  for (int axis = 0; axis < rank_; ++axis) {
    if (axis > 0) out << ", ";
    out << dim_[axis];
  }
  out << ']';
  return out.str();
}
// Visitor forwarding to the rank-specific product() of a Dim<D>.
struct ProductVisitor {
  template <int D>
  inline int64_t operator()(const Dim<D>& dim) {
    return product(dim);
  }
};
// Product of all dimensions; an unset DDim (rank -1) yields 0.
int64_t product(const DDim& ddim) {
  return ddim.size() == -1 ? 0 : ddim.apply_visitor(ProductVisitor());
}
// True when any dimension is negative, i.e. still unknown / to be inferred.
bool contain_unknown_dim(const DDim& ddim) {
  const int rank = ddim.size();
  for (int axis = 0; axis < rank; ++axis) {
    if (ddim[axis] < 0) return true;
  }
  return false;
}
// Returns the sub-DDim covering axes [begin, end) of `dim`.
// Enforces 0 <= begin and end <= dim.size(); the validity of end - begin is
// left to the DDim constructor (see note below).
DDim slice_ddim(const DDim& dim, int begin, int end) {
  COMMON_ENFORCE_EQ(
      (begin >= 0 && end <= dim.size()),
      true,
      common::errors::InvalidArgument(
          "[begin(%d), end(%d)) must be inside [0, %d) in ddim slice.",
          begin,
          end,
          dim.size()));
  // Constructor of DDim would check whether end - begin is valid
  return DDim(dim.Get() + begin, end - begin);
}
// Rank of the DDim (-1 for an unset DDim).
int arity(const DDim& d) { return d.size(); }
// Visitor streaming a rank-specific Dim<D> to an ostream.
struct DDimPrinter {
  std::ostream& os;
  explicit DDimPrinter(std::ostream& os_) : os(os_) {}
  template <int D>
  void operator()(const Dim<D>& t) {
    os << t;
  }
};
// Streams the DDim; an unset DDim (rank -1) prints nothing.
std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
  if (ddim.size() == -1) {
    return os;
  }
  ddim.apply_visitor(DDimPrinter(os));
  return os;
}
// Collapses `src` into a rank-3 DDim:
//   [ prod(dims[0, num_row_dims)),
//     prod(dims[num_row_dims, num_col_dims)),
//     prod(dims[num_col_dims, rank)) ]
// Preconditions: rank >= 3, 1 <= num_row_dims < rank,
// 2 <= num_col_dims <= rank, and num_row_dims <= num_col_dims.
DDim flatten_to_3d(const DDim& src, int num_row_dims, int num_col_dims) {
  COMMON_ENFORCE_GE(src.size(),
                    3,
                    common::errors::InvalidArgument(
                        "The rank of src dim should be at least 3 "
                        "in flatten_to_3d, but received %d.",
                        src.size()));
  COMMON_ENFORCE_EQ((num_row_dims >= 1 && num_row_dims < src.size()),
                    true,
                    common::errors::InvalidArgument(
                        "The num_row_dims should be inside [1, %d] "
                        "in flatten_to_3d, but received %d.",
                        src.size() - 1,
                        num_row_dims));
  COMMON_ENFORCE_EQ((num_col_dims >= 2 && num_col_dims <= src.size()),
                    true,
                    common::errors::InvalidArgument(
                        "The num_col_dims should be inside [2, %d] "
                        "in flatten_to_3d, but received %d.",
                        src.size(),
                        num_col_dims));
  // NOTE(review): the check permits num_row_dims == num_col_dims although the
  // message says "less than" — confirm which is intended.
  COMMON_ENFORCE_GE(
      num_col_dims,
      num_row_dims,
      common::errors::InvalidArgument(
          "The num_row_dims should be less than num_col_dims in flatten_to_3d,"
          "but received num_row_dims = %d, num_col_dims = %d.",
          num_row_dims,
          num_col_dims));
  return DDim({product(slice_ddim(src, 0, num_row_dims)),
               product(slice_ddim(src, num_row_dims, num_col_dims)),
               product(slice_ddim(src, num_col_dims, src.size()))});
}
// Collapses `src` into [prod(dims before num_col_dims), prod(remaining)].
DDim flatten_to_2d(const DDim& src, int num_col_dims) {
  const int64_t rows = product(slice_ddim(src, 0, num_col_dims));
  const int64_t cols = product(slice_ddim(src, num_col_dims, src.size()));
  return DDim({rows, cols});
}
// Collapses `src` into a single dimension holding the total element count.
DDim flatten_to_1d(const DDim& src) {
  const int64_t numel = product(src);
  return DDim({numel});
}
// Row-major strides: stride[rank-1] = 1, stride[i] = stride[i+1] * dim[i+1].
DDim stride(const DDim& ddim) {
  const int rank = ddim.size();
  DDim strides;
  strides.rank_ = rank;
  if (rank > 0) {
    strides[rank - 1] = 1;
    for (int axis = rank - 2; axis >= 0; --axis) {
      strides[axis] = strides[axis + 1] * ddim[axis + 1];
    }
  }
  return strides;
}
// Suffix element counts: stride[i] = dim[i] * dim[i+1] * ... * dim[rank-1].
DDim stride_numel(const DDim& ddim) {
  const int rank = ddim.size();
  DDim strides;
  strides.rank_ = rank;
  if (rank > 0) {
    strides[rank - 1] = ddim[rank - 1];
    for (int axis = rank - 2; axis >= 0; --axis) {
      strides[axis] = strides[axis + 1] * ddim[axis];
    }
  }
  return strides;
}
// Resolves a reshape target against this DDim and returns the concrete
// shape. A 0 entry copies the input dimension at the same axis; a single -1
// entry is inferred so the total element count is preserved.
// NOTE: `shape` is taken by non-const reference and mutated in place (0 and
// -1 entries are overwritten), so callers observe the resolved values.
DDim DDim::reshape(std::vector<int>& shape) const {
  const DDim& in_dims = *this;
  // Replace each 0 with the corresponding input dimension (bounds-checked by
  // at()).
  for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
    if (shape[i] == 0) {
      shape[i] = static_cast<int>(in_dims.at(i));
    }
  }
  // Dim marked as "-1" must be inferred
  auto it = std::find(shape.begin(), shape.end(), -1);
  if (it != shape.end()) {
    int index = static_cast<int>(std::distance(shape.begin(), it));
    // Starting the accumulate at -1 cancels the -1 sentinel still inside
    // `shape`, so reshape_out_product is the product of the known dims.
    // NOTE(review): assumes exactly one -1 and a non-zero product; multiple
    // -1s or a zero dim would divide wrongly — confirm callers validate this.
    int reshape_out_product =
        std::accumulate(shape.begin(), shape.end(), -1, std::multiplies<int>());
    shape[index] = static_cast<int>(product(in_dims)) / reshape_out_product;
  }
  return common::make_ddim(shape);
}
// Permutes dimensions: out[i] = in[axis[i]] for i < axis.size(); axes beyond
// that keep the values copied from the input.
DDim DDim::transpose(const std::vector<int>& axis) const {
  DDim out_dims(*this);
  const int num_axes = static_cast<int>(axis.size());
  for (int i = 0; i < num_axes; i++) {
    out_dims[i] = (*this)[axis[i]];
  }
  return out_dims;
}
} // namespace common
namespace std {
// Boost-style hash_combine folded over the dimension values, seeded with the
// rank.
std::size_t hash<common::DDim>::operator()(common::DDim const& ddim) const {
  const int rank = ddim.size();
  const int64_t* dims = ddim.Get();
  std::size_t combined = rank;
  for (int axis = 0; axis < rank; ++axis) {
    combined ^=
        dims[axis] + 0x9e3779b9 + (combined << 6) + (combined >> 2);
  }
  return combined;
}
}  // namespace std
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment