Commit 992bec46 authored by “yuguo”'s avatar “yuguo”
Browse files

2.5

parent 0259837d
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/info_registry.h"
namespace cinn {
namespace common {} // namespace common
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <absl/container/flat_hash_map.h>
#include <absl/types/any.h>
#include <string>
namespace cinn {
namespace common {
/**
 * A string-keyed registry whose values are type-erased with absl::any.
 *
 * Entries are created lazily: the first Get<T>(key) on an absent key stores a
 * value-initialized T under that key.
 */
class InfoRegistry {
 public:
  /**
   * Returns a mutable reference to the value stored under \p key, inserting a
   * value-initialized T first when the key is absent.
   *
   * The reference is produced by absl::any_cast<T&>, so requesting a type other
   * than the one already stored under \p key fails inside absl::any_cast.
   */
  template <typename T>
  T& Get(const std::string& key);

  //! Number of entries currently stored.
  size_t size() const { return data_.size(); }

  //! Removes every entry.
  void Clear() { data_.clear(); }

 private:
  absl::flat_hash_map<std::string, absl::any> data_;
};

template <typename T>
T& InfoRegistry::Get(const std::string& key) {
  // Hash the key once for the lookup and, only when it is missing, once more
  // for the insertion; the iterator is reused for the final cast instead of
  // going through operator[] again.
  auto it = data_.find(key);
  if (it == data_.end()) {
    it = data_.emplace(key, T()).first;
  }
  return absl::any_cast<T&>(it->second);
}
} // namespace common
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/ir_util.h"
#include <algorithm>
#include <unordered_set>
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/cast_simplify.h"
namespace cinn {
namespace common {
namespace {
// ramp * (scalar or broadcast)
//
// Multiplies both the base and the stride of `ramp` by `other`. When `other`
// is a Broadcast with the same lane count, its scalar value is used as the
// factor. All operands are required to be int32.
Expr RampRelatedMul(ir::Ramp *ramp, Expr other) {
  CHECK_EQ(other.type().ElementOf(), Int(32));
  CHECK_EQ(ramp->base.type(), Int(32));
  CHECK_EQ(ramp->stride.type(), Int(32));
  Expr factor = other;
  if (auto *as_broadcast = other.As<ir::Broadcast>()) {
    CHECK_EQ(ramp->lanes, as_broadcast->lanes);
    factor = as_broadcast->value;
  }
  Expr scaled_base = ramp->base * factor;
  Expr scaled_stride = ramp->stride * factor;
  return ir::Ramp::Make(scaled_base, scaled_stride, ramp->lanes);
}
// broadcast * scalar: multiplies the broadcast value by `other` (which must
// have a single lane) and re-broadcasts it over the same number of lanes.
Expr RampRelatedMul(ir::Broadcast *broadcast, Expr other) {
  CHECK_EQ(other.type().lanes(), 1);
  Expr scaled_value = broadcast->value * other;
  return ir::Broadcast::Make(scaled_value, broadcast->lanes);
}
// ramp * ramp
// Not supported: logs a fatal "Not Implemented" message. The trailing return
// of an undefined Expr only exists so that every path returns a value.
Expr RampRelatedMul(ir::Ramp *ramp, ir::Ramp *other) {
CINN_NOT_IMPLEMENTED
return Expr();
}
// ramp + (scalar or broadcast)
//
// Adds `other` to the base of `ramp`; stride and lane count are unchanged.
// When `other` is a Broadcast with the same lane count, its scalar value is
// the addend.
Expr RampRelatedAdd(ir::Ramp *ramp, Expr other) {
  CHECK_EQ(other.type().ElementOf(), Int(32));
  Expr addend = other;
  if (auto *as_broadcast = other.As<ir::Broadcast>()) {
    CHECK_EQ(ramp->lanes, as_broadcast->lanes);
    addend = as_broadcast->value;
  }
  Expr shifted_base = ramp->base + addend;
  return ir::Ramp::Make(shifted_base, ramp->stride, ramp->lanes);
}
// broadcast + scalar: adds `other` (which must have a single lane) to the
// broadcast value and re-broadcasts it over the same number of lanes.
Expr RampRelatedAdd(ir::Broadcast *broadcast, Expr other) {
  CHECK_EQ(other.type().lanes(), 1);
  Expr shifted_value = broadcast->value + other;
  return ir::Broadcast::Make(shifted_value, broadcast->lanes);
}
// ramp + ramp
//
// Two ramps with the same lane count add element-wise into a ramp whose base
// and stride are the (simplified) sums of the operands' bases and strides.
// Ramps with different lane counts are not supported.
Expr RampRelatedAdd(ir::Ramp *ramp, ir::Ramp *other) {
  CHECK(ramp);
  CHECK(other);
  if (ramp->lanes != other->lanes) {
    CINN_NOT_IMPLEMENTED
    return Expr();
  }
  Expr base_sum = common::AutoSimplify(ramp->base + other->base);
  Expr stride_sum = common::AutoSimplify(ramp->stride + other->stride);
  VLOG(2) << base_sum;
  VLOG(2) << stride_sum;
  return ir::Ramp::Make(base_sum, stride_sum, ramp->lanes);
}
// Adds two index expressions, dispatching on whether each operand is a Ramp,
// a Broadcast, or a plain scalar. Combinations that match none of the cases
// below are reported as "Not Implemented".
Expr RampRelatedAdd(Expr a, Expr b) {
  auto *a_ramp = a.As<ir::Ramp>();
  auto *b_ramp = b.As<ir::Ramp>();
  auto *a_broadcast = a.As<ir::Broadcast>();
  auto *b_broadcast = b.As<ir::Broadcast>();

  // ramp + (scalar | broadcast), in either operand order.
  if (a_ramp && !b_ramp && (b->type().lanes() == 1 || b_broadcast)) {
    return RampRelatedAdd(a_ramp, b);
  }
  if (!a_ramp && b_ramp && (a->type().lanes() == 1 || a_broadcast)) {
    return RampRelatedAdd(b_ramp, a);
  }
  // scalar + scalar.
  if (!a_ramp && !b_ramp && !a->type().is_vector() &&
      !b->type().is_vector()) {
    return a + b;
  }
  // ramp + ramp.
  if (a_ramp && b_ramp) {
    return RampRelatedAdd(a_ramp, b_ramp);
  }
  // broadcast + scalar, in either operand order.
  if (a_broadcast && !b_broadcast) {
    return RampRelatedAdd(a_broadcast, b);
  }
  if (!a_broadcast && b_broadcast) {
    return RampRelatedAdd(b_broadcast, a);
  }
  // broadcast + broadcast.
  if (a_broadcast && b_broadcast) {
    CHECK_EQ(a_broadcast->lanes, b_broadcast->lanes);
    return ir::Broadcast::Make(a_broadcast->value + b_broadcast->value,
                               a_broadcast->lanes);
  }

  CINN_NOT_IMPLEMENTED
  // Mirrors the other overloads: never fall off the end of a non-void
  // function, even if the fatal log were to return.
  return Expr();
}
// Multiplies two index expressions, dispatching on whether each operand is a
// Ramp, a Broadcast, or a plain scalar. Combinations that match none of the
// cases below are reported as "Not Implemented".
Expr RampRelatedMul(Expr a, Expr b) {
  auto *a_ramp = a.As<ir::Ramp>();
  auto *b_ramp = b.As<ir::Ramp>();
  auto *a_broadcast = a.As<ir::Broadcast>();
  auto *b_broadcast = b.As<ir::Broadcast>();

  // ramp * (scalar | broadcast), in either operand order. The second test
  // mirrors the first: the non-ramp operand must be a scalar or a broadcast,
  // so that `scalar * ramp` is handled exactly like `ramp * scalar`.
  if (a_ramp && !b_ramp && (!b->type().is_vector() || b_broadcast)) {
    return RampRelatedMul(a_ramp, b);
  }
  if (!a_ramp && b_ramp && (!a->type().is_vector() || a_broadcast)) {
    return RampRelatedMul(b_ramp, a);
  }
  // scalar * scalar.
  if (!a_ramp && !b_ramp && !a->type().is_vector() &&
      !b->type().is_vector()) {
    return a * b;
  }
  // ramp * ramp.
  if (a_ramp && b_ramp) {
    return RampRelatedMul(a_ramp, b_ramp);
  }
  // broadcast * scalar, in either operand order.
  if (a_broadcast && !b_broadcast) {
    return RampRelatedMul(a_broadcast, b);
  }
  if (!a_broadcast && b_broadcast) {
    return RampRelatedMul(b_broadcast, a);
  }
  // broadcast * broadcast.
  if (a_broadcast && b_broadcast) {
    CHECK_EQ(a_broadcast->lanes, b_broadcast->lanes);
    return ir::Broadcast::Make(a_broadcast->value * b_broadcast->value,
                               a_broadcast->lanes);
  }

  VLOG(3) << "a,b: " << a << " " << b;
  CINN_NOT_IMPLEMENTED
  // Mirrors the other overloads: never fall off the end of a non-void
  // function, even if the fatal log were to return.
  return Expr();
}
} // namespace
// Flattens multi-dimensional `indices` into one offset: each index is
// multiplied by the extents of all later dimensions in `shape`, the per-axis
// terms are summed, and the result is simplified. Every extent must be int32
// and there must be at least as many indices as dimensions.
Expr IndiceToAbsOffset(const std::vector<Expr> &shape,
                       const std::vector<Expr> &indices) {
  VLOG(3) << "Begin IndiceToAbsOffset";
  VLOG(3) << "shape is : " << utils::Join(shape, ",");
  VLOG(3) << "indices is : " << utils::Join(indices, ",");
  CHECK_LE(shape.size(), indices.size());

  const size_t rank = shape.size();
  Expr offset;
  for (size_t axis = 0; axis < rank; ++axis) {
    CHECK_EQ(shape[axis].type(), Int(32));
    Expr term = indices[axis];
    optim::CastSimplify(&term);
    // Scale the index by the extent of every dimension after `axis`.
    for (size_t inner = axis + 1; inner < rank; ++inner) {
      term = RampRelatedMul(term, shape[inner]);
    }
    if (!offset.defined()) {
      offset = term;
    } else {
      offset = RampRelatedAdd(offset, term);
    }
  }
  return common::AutoSimplify(offset);
}
// Convenience overload for a shape given as plain integers: wraps every
// extent in an Expr and forwards to the Expr-shaped overload.
Expr IndiceToAbsOffset(const std::vector<int> &shape,
                       const std::vector<Expr> &indices) {
  std::vector<Expr> shape_exprs;
  shape_exprs.reserve(shape.size());
  for (int extent : shape) shape_exprs.push_back(Expr(extent));
  // Forward the converted shape. Passing `shape` here would select this very
  // overload again and recurse without end.
  return IndiceToAbsOffset(shape_exprs, indices);
}
// Uses the first `preceding_n_axis` extents of `shape` as the index list and
// converts them to an offset with IndiceToAbsOffset.
// NOTE(review): IndiceToAbsOffset checks shape.size() <= indices.size(), so
// this looks like it only succeeds when preceding_n_axis covers the whole
// shape; confirm against callers.
Expr PrecedingAxisToAbsOffset(const std::vector<Expr> &shape,
                              int preceding_n_axis) {
  std::vector<Expr> leading_extents;
  for (int axis = 0; axis < preceding_n_axis; ++axis) {
    leading_extents.push_back(shape[axis]);
  }
  return IndiceToAbsOffset(shape, leading_extents);
}
namespace {
// IR mutator that replaces variable references with expressions.
//
// The replacement table is keyed by variable *name*: every `_Var_` node whose
// name appears in the table is overwritten with the mapped expression.
class SubstituteMutator : ir::IRMutator<ir::Expr *> {
 public:
  // Builds the name -> replacement table from a node -> replacement map.
  explicit SubstituteMutator(const std::map<const ir::_Var_ *, Expr> &var_map) {
    for (auto &item : var_map) {
      var_map_[item.first->name] = item.second;
    }
  }

  // Applies the substitution to `expr` in place.
  void operator()(ir::Expr *expr) { Visit(expr); }

 private:
  void Visit(Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }

  // Replaces the variable with its mapped expression; variables that are not
  // in the table are left untouched.
  void Visit(const ir::_Var_ *op, ir::Expr *expr) override {
    auto it = var_map_.find(op->name);
    if (it == var_map_.end()) return;
    *expr = it->second;
  }

  std::map<std::string, Expr> var_map_;
};
} // namespace
// Rewrites `expr` in place, replacing each variable listed in `var_map` with
// its mapped expression.
void Substitute(Expr *expr, const std::map<const ir::_Var_ *, Expr> &var_map) {
  SubstituteMutator(var_map)(expr);
}
// Returns true when `v` simplifies to an integer or floating-point immediate
// equal to zero. Anything that does not simplify to such an immediate yields
// false.
bool is_zero(Expr v) {
  v = AutoSimplify(v);
  if (auto *int_n = v.As<ir::IntImm>()) return int_n->value == 0;
  // Compare, do not assign: `value = 0.f` would overwrite the immediate and
  // always evaluate to false.
  if (auto *float_n = v.As<ir::FloatImm>()) return float_n->value == 0.f;
  return false;
}
// Returns `body` unchanged when it already has type `type`; otherwise wraps it
// in a Cast node to `type`.
Expr CastIfNeeded(Expr body, Type type) {
  const bool already_typed = (body.type() == type);
  if (!already_typed) {
    return ir::Cast::Make(type, body);
  }
  return body;
}
bool MathEqual(const Expr &a, const Expr &b) {
auto c = a - b;
c = AutoSimplify(c);
return is_zero(c);
}
// Builds a Select node yielding `true_value` or `false_value` based on `cond`.
Expr select(Expr cond, Expr true_value, Expr false_value) {
  Expr node = ir::Select::Make(cond, true_value, false_value);
  return node;
}
// Left-folds a non-empty list of conditions with logical And:
// ((c0 && c1) && c2) && ...
Expr and_all(const std::vector<Expr> &conds) {
  CHECK(!conds.empty());
  auto cursor = conds.begin();
  Expr conjunction = *cursor;
  for (++cursor; cursor != conds.end(); ++cursor) {
    conjunction = ir::And::Make(conjunction, *cursor);
  }
  return conjunction;
}
// Left-folds a non-empty list of conditions with logical Or:
// ((c0 || c1) || c2) || ...
Expr or_all(const std::vector<Expr> &conds) {
  CHECK(!conds.empty());
  Expr res = conds.front();
  for (size_t i = 1; i < conds.size(); ++i) {
    res = ir::Or::Make(res, conds[i]);
  }
  return res;
}

// The public header declares this helper under the name `or_any`; provide
// that symbol so the declaration has a definition to link against. It is the
// same fold as or_all().
Expr or_any(const std::vector<Expr> &conds) { return or_all(conds); }
// Verifies that every tensor name found in `expr` refers to one single tensor
// node; a second node carrying an already-seen name fails the check.
void CheckTensorUniqueInExpr(Expr expr) {
  auto tensor_nodes =
      ir::CollectIRNodes(expr, [](const Expr *x) { return x->as_tensor(); });
  absl::flat_hash_map<std::string, const ir::_Tensor_ *> first_seen;
  for (auto &node : tensor_nodes) {
    auto *tensor = node.as_tensor();
    if (first_seen.count(tensor->name)) {
      CHECK_EQ(first_seen[tensor->name], tensor)
          << "Found tensor not unique [" << tensor->name
          << "]\nThe original expression is \n"
          << expr;
      continue;
    }
    first_seen[tensor->name] = tensor;
  }
}
// Verifies that every buffer name reachable from `expr` refers to one single
// buffer node. Buffers are gathered from the tensors in the expression and
// from the temporary buffers of its lowered functions; tensor uniqueness is
// checked first.
void CheckBufferUniqueInExpr(Expr expr) {
  CheckTensorUniqueInExpr(expr);

  auto tensor_nodes =
      ir::CollectIRNodes(expr, [](const Expr *x) { return x->as_tensor(); });
  auto func_nodes = ir::CollectIRNodes(
      expr, [](const Expr *x) { return x->as_lowered_func(); });

  absl::flat_hash_map<std::string, const ir::_Buffer_ *> first_seen;
  // Records a buffer on first sight; later sightings of the same name must be
  // the very same node.
  auto require_unique = [&](const ir::_Buffer_ *buffer) {
    if (!first_seen.count(buffer->name)) {
      first_seen[buffer->name] = buffer->const_self();
      return;
    }
    CHECK_EQ(first_seen[buffer->name], buffer);
  };

  for (auto &node : tensor_nodes) {
    auto *tensor = node.as_tensor();
    if (!tensor->buffer.defined()) continue;
    require_unique(tensor->buffer->const_self());
  }
  for (auto &node : func_nodes) {
    auto *func = node.as_lowered_func();
    for (auto &temp_buf : func->temp_bufs) {
      if (!temp_buf.defined()) continue;
      require_unique(temp_buf->const_self());
    }
  }
}
// Casts `e` to `type`. A constant is folded into a new immediate of the target
// type; anything else is wrapped in a Cast node. A constant whose target type
// is not listed below is reported as "Not Implemented".
Expr cast(Expr e, Type type) {
  if (!e.is_constant()) return ir::Cast::Make(type, e);

  if (type.is_bool()) return Expr(static_cast<bool>(e.get_constant()));
  // Signed integers.
  if (type.is_int(8)) return Expr(static_cast<int8_t>(e.get_constant()));
  if (type.is_int(16)) return Expr(static_cast<int16_t>(e.get_constant()));
  if (type.is_int(32)) return Expr(static_cast<int32_t>(e.get_constant()));
  if (type.is_int(64)) return Expr(static_cast<int64_t>(e.get_constant()));
  // Unsigned integers.
  if (type.is_uint(8)) return Expr(static_cast<uint8_t>(e.get_constant()));
  if (type.is_uint(16)) return Expr(static_cast<uint16_t>(e.get_constant()));
  if (type.is_uint(32)) return Expr(static_cast<uint32_t>(e.get_constant()));
  if (type.is_uint(64)) return Expr(static_cast<uint64_t>(e.get_constant()));
  // Floating point.
  if (type.is_float(32)) return Expr(static_cast<float>(e.get_constant()));
  if (type.is_float(64)) return Expr(static_cast<double>(e.get_constant()));
  if (type.is_bfloat16()) {
    return Expr(static_cast<cinn::common::bfloat16>(e.get_constant()));
  }
  if (type.is_float16()) {
    return Expr(static_cast<cinn::common::float16>(e.get_constant()));
  }

  CINN_NOT_IMPLEMENTED
  return ir::Cast::Make(type, e);
}
std::vector<std::string> GatherItersToTensorProducer(
const std::string &target_tensor_name, Expr *expr) {
struct Visitor : public ir::IRMutator<> {
std::vector<std::string> iters;
const std::string &target_tensor_name;
explicit Visitor(const std::string &target_tensor_name)
: target_tensor_name(target_tensor_name) {}
std::vector<std::string> operator()(Expr *expr) {
ir::IRMutator<>::Visit(expr, expr);
return iters;
}
void Visit(const ir::Store *op, Expr *expr) {
if (op->tensor.as_tensor()->name == target_tensor_name) {
CHECK(iters.empty());
for (auto &e : for_stack) {
auto *for_n = e->As<ir::For>();
auto *polyfor_n = e->As<ir::PolyFor>();
if (for_n) {
iters.push_back(for_n->loop_var->name);
} else {
iters.push_back(polyfor_n->iterator->name);
}
}
}
}
void Visit(const ir::For *op, Expr *expr) {
for_stack.push_back(expr);
ir::IRMutator<>::Visit(op, expr);
for_stack.pop_back();
}
void Visit(const ir::PolyFor *op, Expr *expr) {
for_stack.push_back(expr);
ir::IRMutator<>::Visit(op, expr);
for_stack.pop_back();
}
std::vector<Expr *> for_stack;
};
return Visitor(target_tensor_name)(expr);
}
// Returns the chain of For / PolyFor nodes (outermost first) that encloses the
// first Store to the tensor named `tensor_name` found while traversing `expr`.
// The result is empty when no such Store is found.
std::vector<Expr *> GetForloopStackToStore(Expr *expr,
                                           const std::string &tensor_name) {
  VLOG(4) << "search store " << tensor_name << " in expr:\n";
  VLOG(4) << *expr;
  struct Mutator : public ir::IRMutator<> {
    std::vector<Expr *> forloop_stack;
    // Latched to true once the target Store has been reached.
    bool found{false};
    std::string tensor_name;

    explicit Mutator(const std::string &tensor_name)
        : tensor_name(tensor_name) {}

    std::vector<Expr *> operator()(Expr *expr) {
      ir::IRMutator<>::Visit(expr, expr);
      return forloop_stack;
    }

    void Visit(const ir::For *op, Expr *expr) {
      // The stack is complete once the Store was found; loops visited
      // afterwards must not be appended to it.
      if (found) return;
      auto *node = expr->As<ir::For>();
      forloop_stack.push_back(expr);
      ir::IRMutator<>::Visit(&node->body, &node->body);
      if (!found) forloop_stack.pop_back();
    }

    void Visit(const ir::PolyFor *op, Expr *expr) {
      if (found) return;
      auto *node = expr->As<ir::PolyFor>();
      forloop_stack.push_back(expr);
      ir::IRMutator<>::Visit(&node->body, &node->body);
      if (!found) forloop_stack.pop_back();
    }

    void Visit(const ir::Store *op, Expr *expr) {
      // Never clear the flag again: a later Store to a different tensor must
      // not undo the match and make enclosing loops pop themselves.
      if (found) return;
      found = op->tensor.as_tensor()->name == tensor_name;
    }
  };
  return Mutator(tensor_name)(expr);
}
// Builds a Max node; both operands must have the same type.
Expr max(Expr a, Expr b) {
  CHECK_EQ(a.type(), b.type());
  Expr node = ir::Max::Make(a, b);
  return node;
}
// Builds a Min node; both operands must have the same type.
Expr min(Expr a, Expr b) {
  CHECK_EQ(a.type(), b.type());
  Expr node = ir::Min::Make(a, b);
  return node;
}
} // namespace common
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <absl/container/flat_hash_map.h>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/cinn/common/bfloat16.h"
#include "paddle/cinn/common/float16.h"
#include "paddle/cinn/ir/ir.h"
namespace cinn {
namespace common {
//! Convert multi-dimensional \p indices into a single flattened offset for a
//! tensor of the given \p shape: each index is scaled by the extents of the
//! dimensions that follow it, the terms are summed and the sum is simplified.
Expr IndiceToAbsOffset(const std::vector<Expr> &shape,
const std::vector<Expr> &indices);
//! Same as above, for a shape given as plain integers.
Expr IndiceToAbsOffset(const std::vector<int> &shape,
const std::vector<Expr> &indices);
//! Compute the offset obtained by using the first \p preceding_n_axis extents
//! of \p shape as the indices.
Expr PrecedingAxisToAbsOffset(const std::vector<Expr> &shape,
int preceding_n_axis);
//! Return \p body unchanged if it already has type \p type, otherwise a Cast
//! of \p body to \p type.
Expr CastIfNeeded(Expr body, Type type);
//! Substitute vars to other expressions.
//! @param expr The expression to modify in place.
//! @param var_map The map from variables to the target expressions.
void Substitute(Expr *expr, const std::map<const ir::_Var_ *, Expr> &var_map);
//! Get the stack of for-loops (For and PolyFor nodes) that leads to the Store
//! node targeting the tensor named \p tensor_name.
std::vector<Expr *> GetForloopStackToStore(Expr *expr,
const std::string &tensor_name);
// Scalar constant makers: wrap a C++ value in an Expr of the matching type.
// @{
inline Expr make_const(int32_t x) { return Expr(x); }
inline Expr make_const(int64_t x) { return Expr(x); }
inline Expr make_const(bfloat16 x) { return Expr(x); }
inline Expr make_const(float16 x) { return Expr(x); }
inline Expr make_const(float x) { return Expr(x); }
inline Expr make_const(double x) { return Expr(x); }
inline Expr make_const(bool x) { return Expr(x); }
// @}
//! Makers for some commonly used constants.
// @{
//! The constant 0 of C++ type T (int32_t by default).
template <typename T = int32_t>
inline Expr make_zero() {
  const T zero = static_cast<T>(0);
  return make_const(zero);
}
//! The constant 1 of C++ type T (int32_t by default).
template <typename T = int32_t>
inline Expr make_one() {
  const T one = static_cast<T>(1);
  return make_const(one);
}
//! A scalar boolean immediate holding \p x.
inline Expr make_bool(bool x) {
  Expr imm = common::make_shared<ir::UIntImm>(Bool(), x);
  return imm;
}
//! A boolean immediate holding \p x whose type has \p lanes lanes.
inline Expr make_bool(bool x, int lanes) {
  Expr imm = common::make_shared<ir::UIntImm>(Bool(lanes), x);
  return imm;
}
// @}
/**
* \brief Check that all the tensors are unique in an expression, i.e. that no
* two distinct tensor nodes share the same name.
*/
void CheckTensorUniqueInExpr(Expr expr);
/**
* \brief Check that all the buffers are unique in an expression, i.e. that no
* two distinct buffer nodes share the same name.
*/
void CheckBufferUniqueInExpr(Expr expr);
//! Collect the names of the loop iterators enclosing the Store that writes to
//! the tensor named \p target_tensor_name.
std::vector<std::string> GatherItersToTensorProducer(
const std::string &target_tensor_name, Expr *expr);
//! Tell whether \p v simplifies to a zero immediate.
bool is_zero(Expr v);
//! Tell whether the difference of \p a and \p b simplifies to zero.
bool MathEqual(const Expr &a, const Expr &b);
//! helper function to get a ir::Select node.
Expr select(Expr cond, Expr true_value, Expr false_value);
//! helper function to get the And of all the conditions.
Expr and_all(const std::vector<Expr> &conds);
//! helper function to get the Or of all the conditions.
Expr or_any(const std::vector<Expr> &conds);
//! Cast the expression \p e to type \p type.
Expr cast(Expr e, Type type);
//! helper function to get a ir::Max node; operands must have the same type.
Expr max(Expr a, Expr b);
//! helper function to get a ir::Min node; operands must have the same type.
Expr min(Expr a, Expr b);
//! Make a constant of IR type \p t from the C++ value \p v.
//!
//! The value is converted to an immediate of the element type (IntImm,
//! UIntImm or FloatImm, tested in that order, then bool). For a vector type
//! the immediate is broadcast over all lanes. Unsupported types are reported
//! as "Not Implemented" and yield an undefined Expr.
template <typename T>
Expr make_const(Type t, T v) {
  const bool is_vec = t.is_vector();
  // Immediates carry the element type; for a scalar that is `t` itself.
  const Type imm_type = is_vec ? t.ElementOf() : t;
  // Broadcasts the immediate for vector types, passes it through otherwise.
  const auto finish = [&](Expr imm) -> Expr {
    if (is_vec) return ir::Broadcast::Make(imm, t.lanes());
    return imm;
  };

  if (t.is_int()) {
    return finish(make_shared<ir::IntImm>(imm_type, static_cast<int64_t>(v)));
  }
  if (t.is_uint()) {
    return finish(make_shared<ir::UIntImm>(imm_type, static_cast<uint64_t>(v)));
  }
  if (t.is_float()) {
    return finish(make_shared<ir::FloatImm>(imm_type, static_cast<double>(v)));
  }
  if (t.is_bool()) {
    return finish(make_shared<ir::UIntImm>(imm_type, static_cast<bool>(v)));
  }

  CINN_NOT_IMPLEMENTED
  return Expr();
}
//! Fold \p values into a single expression with the binary \p func_op.
//!
//! The first value seeds the accumulator; every following value `val` updates
//! it to `func_op(val, accumulator)`. An empty list yields an undefined Expr.
template <typename FuncOp>
Expr FoldExpr(FuncOp func_op, const std::vector<Expr> &values) {
  Expr accumulated;
  for (const Expr &val : values) {
    if (accumulated.defined()) {
      accumulated = func_op(val, accumulated);
      continue;
    }
    accumulated = val;
  }
  return accumulated;
}
} // namespace common
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <array>
#include <functional>
#include "paddle/cinn/common/bfs_walker.h"
namespace cinn {
namespace common {
// Answers "can `dst` be reached from `src`?" by walking the graph with a
// BfsWalker and pruning successors by depth.
//
// The graph is described by three callbacks: a minimum and a maximum depth per
// node and a visitor enumerating a node's successors. A successor whose
// minimum depth exceeds the maximum depth of `dst` is not expanded.
template <typename NodeType>
class IsReachablePredicator final {
 public:
  using NodeHandlerType = std::function<void(NodeType)>;
  using NodesVisitorType =
      std::function<void(NodeType, const NodeHandlerType&)>;
  using NodeDepthGetterType = std::function<size_t(NodeType)>;

  IsReachablePredicator(const NodeDepthGetterType& MinDepth4Node,
                        const NodeDepthGetterType& MaxDepth4Node,
                        const NodesVisitorType& VisitNextNodes)
      : MinDepth4Node_(MinDepth4Node),
        MaxDepth4Node_(MaxDepth4Node),
        VisitNextNodes_(VisitNextNodes) {}

  IsReachablePredicator(const IsReachablePredicator&) = delete;
  IsReachablePredicator(IsReachablePredicator&&) = delete;

  // Returns true iff the walk started at `src` visits `dst`. `HandleVisited`
  // is invoked for every node the walk visits.
  bool operator()(NodeType src,
                  NodeType dst,
                  const NodeHandlerType& HandleVisited) const {
    const size_t depth_limit = MaxDepth4Node_(dst);
    bool reached = false;

    BfsWalker<NodeType> bfs_walker(
        [&](NodeType node, const NodeHandlerType& Handler) {
          VisitNextNodes_(node, [&](NodeType next) {
            // Too deep to ever lead to `dst`.
            const bool beyond_limit = depth_limit < MinDepth4Node_(next);
            // Stop expanding once reachability is established.
            if (beyond_limit || reached) return;
            Handler(next);
          });
        });

    std::array<NodeType, 1> starts{src};
    bfs_walker(starts.begin(), starts.end(), [&](NodeType visited) {
      HandleVisited(visited);
      if (visited == dst) reached = true;
    });
    return reached;
  }

 private:
  NodeDepthGetterType MinDepth4Node_;
  NodeDepthGetterType MaxDepth4Node_;
  NodesVisitorType VisitNextNodes_;
};
} // namespace common
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/is_reachable_predicator.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
namespace cinn {
namespace common {
// Nodes are non-zero integers: a node's depth is its absolute value and its
// only successor is the next integer further away from zero. 99 is therefore
// reachable from 33, while -99 is not.
TEST(IsReachablePredicator, simple) {
  const auto DepthOf = [](int x) { return std::abs(x); };
  const auto VisitNext = [](int x, const std::function<void(int)>& Handler) {
    Handler(x + (x / std::abs(x)));
  };
  IsReachablePredicator<int> IsReachable(
      /*MinDepth4Node=*/DepthOf,
      /*MaxDepth4Node=*/DepthOf,
      /*VisitNextNodes=*/VisitNext);

  const auto IgnoreVisited = [](int) {};
  EXPECT_TRUE(IsReachable(33, 99, IgnoreVisited));
  EXPECT_FALSE(IsReachable(33, -99, IgnoreVisited));
}
} // namespace common
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
// CINN_DEBUG is defined whenever NDEBUG is not.
#if !defined(NDEBUG)
#define CINN_DEBUG
#endif
// Deletes the copy constructor and copy assignment of TypeName. Place inside
// the class body; the caller supplies the trailing semicolon.
#define CINN_DISALLOW_COPY_AND_ASSIGN(TypeName) \
TypeName(const TypeName&) = delete; \
void operator=(const TypeName&) = delete
// Logs a fatal "Not Implemented" message. The expansion already ends with a
// semicolon, so the macro is used as a bare statement.
#ifndef CINN_NOT_IMPLEMENTED
#define CINN_NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented";
#endif
// Function attribute asking the compiler to warn when the result is unused.
#define CINN_RESULT_SHOULD_USE __attribute__((warn_unused_result))
/**
* A trick to enforce the registry.
*
* CINN_REGISTER_HELPER(some_key) opens the definition of a bool-returning
* registrar function named after the key; CINN_USE_REGISTER(some_key) declares
* that function and calls it while initializing a file-local static, so the
* registration code runs in every translation unit that uses the key.
*
* usage:
*
* CINN_REGISTER_HELPER(some_key) {
* // register methods
* }
*
* CINN_USE_REGISTER(some_key);
*/
#define CINN_REGISTER_HELPER(symbol__) bool __cinn__##symbol__##__registrar()
#define CINN_USE_REGISTER(symbol__) \
extern bool __cinn__##symbol__##__registrar(); \
[[maybe_unused]] static bool __cinn_extern_registrar_##symbol__ = \
__cinn__##symbol__##__registrar();
// CINN_NODISCARD expands to [[nodiscard]] from C++17 on and to nothing before.
#if __cplusplus >= 201703L
#define CINN_NODISCARD [[nodiscard]]
#else
#define CINN_NODISCARD
#endif
// Deletes the copy and move constructors and assignments of classname. The
// expansion starts with `private:`, so members declared after it are private
// unless another access specifier follows.
#define DISABLE_COPY_AND_ASSIGN(classname) \
private: \
classname(const classname&) = delete; \
classname(classname&&) = delete; \
classname& operator=(const classname&) = delete; \
classname& operator=(classname&&) = delete
/**
* Check that the macro is used in the GLOBAL NAMESPACE.
*
* A marker struct is declared at the point of use and compared with the
* struct of the same name looked up in the global namespace; the static_assert
* fails with \p msg when they are not the same type, i.e. when the macro was
* expanded inside another namespace.
*/
#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
/**
 * Pulls the fusion pass `pass_name` into the current translation unit by
 * declaring its TouchFusionPassRegistrar_<pass_name>() function and calling it
 * while initializing a file-local static.
 *
 * Must be used in the global namespace; the static_assert message names this
 * macro so that a misuse points at the right place.
 */
#define USE_FUSION_PASS(pass_name)                                 \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                  \
      __use_fusion_pass_##pass_name,                               \
      "USE_FUSION_PASS must be called in global namespace");       \
  extern int TouchFusionPassRegistrar_##pass_name();               \
  [[maybe_unused]] static int __use_fusion_pass_##pass_name##_ =   \
      TouchFusionPassRegistrar_##pass_name()
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2021 CINN Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/infrt/common/object.h" #include "paddle/cinn/common/object.h"
namespace infrt { namespace cinn {
namespace common {} // namespace common namespace common {} // namespace common
} // namespace infrt } // namespace cinn
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2021 CINN Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -14,23 +14,21 @@ ...@@ -14,23 +14,21 @@
#pragma once #pragma once
#include <cstring> #include <cstring>
#include <iostream>
#include "paddle/infrt/common/shared.h" #include "paddle/cinn/common/shared.h"
namespace infrt { namespace cinn {
namespace common { namespace common {
template <typename T> template <typename T>
class Shared; class Shared;
/** /**
* Object is the basic element in the INFRT, with `Shared` wrapper, the object * Object is the basic element in the CINN, with `Shared` wrapper, the object
* can be shared across the system. * can be shared across the system.
*/ */
struct Object { struct Object {
//! Get the type representation of this object. //! Get the type representation of this object.
virtual const char* type_info() const = 0; virtual const char* type_info() const = 0;
virtual ~Object() {}
//! Cast to a derived type. //! Cast to a derived type.
template <typename T> template <typename T>
...@@ -78,4 +76,4 @@ using object_ptr = Object*; ...@@ -78,4 +76,4 @@ using object_ptr = Object*;
using shared_object = Shared<Object>; using shared_object = Shared<Object>;
} // namespace common } // namespace common
} // namespace infrt } // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/python_interpreter_guard.h"
#include <pybind11/embed.h>
namespace cinn {
namespace common {
// Starts the embedded Python interpreter through pybind11.
PythonInterpreterGuard::PythonInterpreterGuard() {
pybind11::initialize_interpreter();
}
// Shuts the embedded Python interpreter down through pybind11.
PythonInterpreterGuard::~PythonInterpreterGuard() {
pybind11::finalize_interpreter();
}
// Returns the process-wide guard. It is a function-local static, so the
// interpreter is initialized on the first call and finalized when static
// objects are destroyed.
PythonInterpreterGuard& PythonInterpreterGuard::Guard() {
  static PythonInterpreterGuard instance;
  return instance;
}
} // namespace common
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace cinn {
namespace common {
/**
 * Singleton to handle the Python interpreter life time. Since
 * pybind11::initialize_interpreter cannot be called again after
 * pybind11::finalize_interpreter, this singleton calls
 * pybind11::initialize_interpreter when it constructs and
 * pybind11::finalize_interpreter when it destructs.
 *
 * In this case, every caller can call this guard to make sure the pybind11
 * Python interpreter is alive.
 */
class PythonInterpreterGuard {
 public:
  // Destructor
  ~PythonInterpreterGuard();

  // Singleton get instance
  static PythonInterpreterGuard& Guard();

  // A copy would finalize the interpreter a second time when destroyed, so
  // the guard can only be reached through Guard().
  PythonInterpreterGuard(const PythonInterpreterGuard&) = delete;
  PythonInterpreterGuard& operator=(const PythonInterpreterGuard&) = delete;

 private:
  // Constructor
  PythonInterpreterGuard();
};
} // namespace common
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <functional>
#include <list>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/cinn/common/dfs_walker.h"
namespace cinn {
namespace common {
// Strongly connected components visitor.
//
// The graph is described by two callbacks that enumerate the predecessors and
// the successors of a node. Calling the walker on a range of start nodes
// invokes the given handler once per component, passing the component's nodes.
template <typename NodeType>
class SccWalker final {
public:
SccWalker(const SccWalker&) = delete;
SccWalker(SccWalker&&) = delete;
using NodeHandlerType = std::function<void(NodeType)>;
using NodesVisitorType =
std::function<void(NodeType, const NodeHandlerType&)>;
SccWalker(const NodesVisitorType& VisitPrevNodes,
const NodesVisitorType& VisitNextNodes)
: VisitPrevNodes_(VisitPrevNodes), VisitNextNodes_(VisitNextNodes) {}
using SccHandlerType = std::function<void(const std::vector<NodeType>&)>;
// Two-pass scheme, see
// https://en.wikipedia.org/wiki/Kosaraju%27s_algorithm
template <typename NodeIt>
void operator()(NodeIt begin,
NodeIt end,
const SccHandlerType& SccHandler) const {
// Pass 1: walk along successor edges from [begin, end) and prepend each node
// to the list in its "on pop" callback, so the node popped last comes first.
const std::list<NodeType>& dfs_ordered_nodes = [&]() {
std::list<NodeType> dfs_ordered_nodes;
DfsVisitor<NodeType> visitor(VisitNextNodes_);
visitor(
begin,
end,
/*on push*/ [](NodeType) {},
/*on pop*/
[&](NodeType node) { dfs_ordered_nodes.push_front(node); });
return dfs_ordered_nodes;
}();
// Maps every node already assigned to a component to that component's root.
std::unordered_map<NodeType, NodeType> node2root;
// Predecessor visitor that skips nodes already assigned to a component.
const auto& VisitPrevNode = [&](NodeType node,
const NodeHandlerType& NodeHandler) {
VisitPrevNodes_(node, [&](NodeType prev_node) {
if (node2root.count(prev_node) == 0) {
NodeHandler(prev_node);
}
});
};
// Pass 2: in the order computed above, every still-unassigned node becomes the
// root of a new component made of the nodes the predecessor walk hands to the
// callback.
for (NodeType root : dfs_ordered_nodes) {
if (node2root.count(root) > 0) {
continue;
}
std::vector<NodeType> scc;
// node2root is only read inside the dfs visitor.
DfsVisitor<NodeType> visitor(VisitPrevNode);
visitor(root, [&](NodeType node) { scc.push_back(node); });
SccHandler(scc);
// Update node2root outside the dfs visitor; each node must be new to the map.
for (NodeType node : scc) {
CHECK(node2root.emplace(node, root).second);
}
}
}
private:
NodesVisitorType VisitPrevNodes_;
NodesVisitorType VisitNextNodes_;
};
} // namespace common
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/scc_walker.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
namespace cinn {
namespace common {
// Acyclic graph (0->3, 1->2, 1->3, 2->4, 3->4): every node is expected to be
// reported as its own single-element component.
TEST(SccWalker, trivial) {
  std::list<std::pair<int, int>> edges{{0, 3}, {1, 2}, {1, 3}, {2, 4}, {3, 4}};
  SccWalker<int> visitor(
      // Predecessors: sources of all edges that end in `node`.
      [&](int node, const std::function<void(int)>& NodeHandler) {
        for (const auto& pair : edges) {
          if (pair.second == node) {
            NodeHandler(pair.first);
          }
        }
      },
      // Successors: targets of all edges that start at `node`.
      [&](int node, const std::function<void(int)>& NodeHandler) {
        for (const auto& pair : edges) {
          if (pair.first == node) {
            NodeHandler(pair.second);
          }
        }
      });
  std::vector<int> sources{0, 1};
  std::vector<std::vector<int>> outputs;
  visitor(sources.begin(), sources.end(), [&](const auto& nodes) {
    outputs.push_back(nodes);
  });
  std::vector<std::vector<int>> expected{{1}, {2}, {0}, {3}, {4}};
  // EXPECT_EQ prints both containers when they differ, unlike EXPECT_TRUE on
  // a pre-computed comparison.
  EXPECT_EQ(outputs, expected);
}
// Single cycle 0->1->2->3->4->0: all five nodes are expected to be reported
// together as one component.
TEST(SccWalker, circle) {
  std::list<std::pair<int, int>> edges{
      {0, 1},
      {1, 2},
      {2, 3},
      {3, 4},
      {4, 0},
  };
  SccWalker<int> visitor(
      // Predecessors: sources of all edges that end in `node`.
      [&](int node, const std::function<void(int)>& NodeHandler) {
        for (const auto& pair : edges) {
          if (pair.second == node) {
            NodeHandler(pair.first);
          }
        }
      },
      // Successors: targets of all edges that start at `node`.
      [&](int node, const std::function<void(int)>& NodeHandler) {
        for (const auto& pair : edges) {
          if (pair.first == node) {
            NodeHandler(pair.second);
          }
        }
      });
  std::vector<int> sources{0};
  std::vector<std::vector<int>> outputs;
  visitor(sources.begin(), sources.end(), [&](const auto& nodes) {
    outputs.push_back(nodes);
  });
  std::vector<std::vector<int>> expected{{0, 4, 3, 2, 1}};
  // EXPECT_EQ prints both containers when they differ, unlike EXPECT_TRUE on
  // a pre-computed comparison.
  EXPECT_EQ(outputs, expected);
}
// Two 2-cycles (0<->1 and 2<->3) joined by the one-way edge 1->2: two
// components, {0, 1} and {2, 3}, are expected.
TEST(SccWalker, double_circle) {
  std::list<std::pair<int, int>> edges{
      {0, 1},
      {1, 0},
      {1, 2},
      {2, 3},
      {3, 2},
  };
  SccWalker<int> visitor(
      // Predecessors: sources of all edges that end in `node`.
      [&](int node, const std::function<void(int)>& NodeHandler) {
        for (const auto& pair : edges) {
          if (pair.second == node) {
            NodeHandler(pair.first);
          }
        }
      },
      // Successors: targets of all edges that start at `node`.
      [&](int node, const std::function<void(int)>& NodeHandler) {
        for (const auto& pair : edges) {
          if (pair.first == node) {
            NodeHandler(pair.second);
          }
        }
      });
  std::vector<int> sources{0};
  std::vector<std::vector<int>> outputs;
  visitor(sources.begin(), sources.end(), [&](const auto& nodes) {
    outputs.push_back(nodes);
  });
  std::vector<std::vector<int>> expected{{0, 1}, {2, 3}};
  // EXPECT_EQ prints both containers when they differ, unlike EXPECT_TRUE on
  // a pre-computed comparison.
  EXPECT_EQ(outputs, expected);
}
} // namespace common
} // namespace cinn
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2021 CINN Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,4 +12,4 @@ ...@@ -12,4 +12,4 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/infrt/common/shared.h" #include "paddle/cinn/common/shared.h"
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2021 CINN Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <string> #include <string>
#include <type_traits> #include <type_traits>
namespace infrt { namespace cinn {
namespace common { namespace common {
class RefCount { class RefCount {
...@@ -55,7 +55,7 @@ struct Shared { ...@@ -55,7 +55,7 @@ struct Shared {
using object_ptr = T*; using object_ptr = T*;
Shared() = default; Shared() = default;
explicit Shared(T* p) : p_(p) { Shared(T* p) : p_(p) { // NOLINT
if (p) IncRef(p); if (p) IncRef(p);
} }
Shared(const Shared& other) : p_(other.p_) { IncRef(p_); } Shared(const Shared& other) : p_(other.p_) { IncRef(p_); }
...@@ -74,7 +74,7 @@ struct Shared { ...@@ -74,7 +74,7 @@ struct Shared {
inline const T* self() const { return p_; } inline const T* self() const { return p_; }
// @} // @}
inline bool same_as(const Shared& other) { return p_ == other.p_; } inline bool same_as(const Shared& other) const { return p_ == other.p_; }
inline bool defined() const { return p_; } inline bool defined() const { return p_; }
inline bool operator<(const Shared& other) const { return p_ < other.p_; } inline bool operator<(const Shared& other) const { return p_ < other.p_; }
inline Shared<T>& operator=(T* x); inline Shared<T>& operator=(T* x);
...@@ -111,8 +111,7 @@ template <typename T> ...@@ -111,8 +111,7 @@ template <typename T>
Shared<T>& Shared<T>::operator=(const Shared<T>& other) { Shared<T>& Shared<T>::operator=(const Shared<T>& other) {
if (other.p_ == p_) return *this; if (other.p_ == p_) return *this;
// Other can be inside of something owned by this, so we should be careful to // Other can be inside of something owned by this, so we should be careful to
// incref other before we decref // incref other before we decref ourselves.
// ourselves.
T* tmp = other.p_; T* tmp = other.p_;
IncRef(tmp); IncRef(tmp);
DecRef(p_); DecRef(p_);
...@@ -150,4 +149,4 @@ void Shared<T>::Reset(T* x) { ...@@ -150,4 +149,4 @@ void Shared<T>::Reset(T* x) {
} }
} // namespace common } // namespace common
} // namespace infrt } // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/shared.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/cinn/common/object.h"
namespace cinn {
namespace common {
// Object subclass used by the Shared tests in this file. It carries a
// Shared<A> member so that a test can store a handle to an A inside an A.
struct A : public Object {
// Reports the fixed type name "A".
const char *type_info() const override { return "A"; }
// Handle to another (or the same) A instance; empty by default.
Shared<A> other;
};
// Second Object subclass with no members of its own; it is not referenced by
// the tests visible in this file.
class B : public Object {};
// Asserts that copying a Shared handle raises the pointee's reference count
// from 1 to 2 and that the count drops back to 1 once the copy is destroyed.
TEST(Shared, test) {
Shared<A> a_ref(make_shared<A>());
// Only a_ref refers to the object at this point.
ASSERT_EQ(ref_count(a_ref.get()).val(), 1);
{ // local copy
Shared<A> b = a_ref;
// Both handles observe the same count of 2.
EXPECT_EQ(ref_count(a_ref.get()).val(), 2);
ASSERT_EQ(ref_count(b.get()).val(), 2);
}
// b has gone out of scope, so only a_ref is left.
ASSERT_EQ(ref_count(a_ref.get()).val(), 1);
}
// Asserts that storing a handle to an object inside that same object's
// `other` member raises its reference count to 2.
// NOTE(review): after a_ref leaves scope the object still references itself,
// so it is presumably never released - confirm this leak is intentional.
TEST(Shared, cycle_share) {
{
Shared<A> a_ref(make_shared<A>());
// Make the object hold a handle to itself.
a_ref->other = a_ref;
ASSERT_EQ(a_ref->__ref_count__.val(), 2);
}
}
} // namespace common
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef CINN_WITH_CUDA
#include <cuda_runtime_api.h>
#include <driver_types.h>
#endif
#include <glog/logging.h>
#include <sstream>
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/runtime/cinn_runtime.h"
namespace cinn {
namespace common {
// Two targets are equal when their os, arch, bits and features all match.
// The `libs` member does not take part in the comparison.
bool Target::operator==(const Target &other) const {
  if (os != other.os) {
    return false;
  }
  if (arch != other.arch) {
    return false;
  }
  if (bits != other.bits) {
    return false;
  }
  return features == other.features;
}
// Maps `arch` to the matching cinn_*_device runtime constant, returned as an
// int. Only Unk, X86 and ARM are mapped; any other architecture is reported
// through LOG(FATAL).
int Target::runtime_arch() const {
  if (arch == Arch::Unk) {
    return cinn_unk_device;
  }
  if (arch == Arch::X86) {
    return cinn_x86_device;
  }
  if (arch == Arch::ARM) {
    return cinn_arm_device;
  }
  LOG(FATAL) << "Not supported arch";
  // Fallback value after the fatal log statement.
  return -1;
}
// Returns the fixed thread limit of 1024. The target must be NVGPU; this is
// enforced with CHECK.
int Target::max_num_threads() const {
  CHECK(arch == Arch::NVGPU)
      << "The target is not NVGPU! Cannot get max number of threads.";
  constexpr int kMaxNumThreads = 1024;
  return kMaxNumThreads;
}
// Queries device 0 for its multiprocessor count. The target must be NVGPU
// (enforced with CHECK). Builds without CINN_WITH_CUDA always return 0.
int Target::get_multi_processor_count() const {
  CHECK(arch == Arch::NVGPU)
      << "The target is not NVGPU! Cannot get multi processor count";
#ifdef CINN_WITH_CUDA
  int sm_count = 0;
  // NOTE(review): the status returned by the CUDA call is not inspected.
  cudaDeviceGetAttribute(
      &sm_count, cudaDeviceAttr::cudaDevAttrMultiProcessorCount, 0);
  return sm_count;
#else
  return 0;
#endif
}
// Queries device 0 for the maximum number of threads per multiprocessor. The
// target must be NVGPU (enforced with CHECK). Builds without CINN_WITH_CUDA
// always return 0.
int Target::get_max_threads_per_sm() const {
  CHECK(arch == Arch::NVGPU)
      << "The target is not NVGPU! Cannot get max threads per stream processor";
#ifdef CINN_WITH_CUDA
  int threads_per_sm = 0;
  // NOTE(review): the status returned by the CUDA call is not inspected.
  cudaDeviceGetAttribute(
      &threads_per_sm,
      cudaDeviceAttr::cudaDevAttrMaxThreadsPerMultiProcessor,
      0);
  return threads_per_sm;
#else
  return 0;
#endif
}
// Queries device 0 for the maximum number of blocks per multiprocessor. The
// target must be NVGPU (enforced with CHECK). Builds without CINN_WITH_CUDA
// always return 1.
int Target::get_max_blocks_per_sm() const {
  CHECK(arch == Arch::NVGPU)
      << "The target is not NVGPU! Cannot get max blocks per stream processor";
#ifdef CINN_WITH_CUDA
  int blocks_per_sm = 1;
  // NOTE(review): the status returned by the CUDA call is not inspected.
  cudaDeviceGetAttribute(
      &blocks_per_sm,
      cudaDeviceAttr::cudaDevAttrMaxBlocksPerMultiprocessor,
      0);
  return blocks_per_sm;
#else
  return 1;
#endif
}
// Returns a copy of this target's `libs` list.
std::vector<Target::Lib> Target::get_target_libs() const { return libs; }
// Converts `bits` to a plain number: 32 for k32, 64 for k64 and 0 for Unk.
// Any other value is reported through LOG(FATAL).
int Target::get_target_bits() const {
  if (bits == Bit::k32) {
    return 32;
  }
  if (bits == Bit::k64) {
    return 64;
  }
  if (bits == Bit::Unk) {
    return 0;
  }
  LOG(FATAL) << "Not supported Bit";
  // Fallback value after the fatal log statement.
  return -1;
}
// Returns the text that streaming `arch` into an output stream produces.
std::string Target::arch_str() const {
  std::ostringstream text;
  text << arch;
  return text.str();
}
// Prints a target as "Target<os,arch,bits>" using lower-case names, e.g.
// "Target<linux,x86,64>". A field whose value matches none of the listed
// enumerators contributes no text.
std::ostream &operator<<(std::ostream &os, const Target &target) {
  os << "Target<";

  // Operating system.
  if (target.os == Target::OS::Linux) {
    os << "linux";
  } else if (target.os == Target::OS::Windows) {
    os << "windows";
  } else if (target.os == Target::OS::Unk) {
    os << "unk";
  }
  os << ",";

  // Architecture.
  if (target.arch == Target::Arch::X86) {
    os << "x86";
  } else if (target.arch == Target::Arch::ARM) {
    os << "arm";
  } else if (target.arch == Target::Arch::NVGPU) {
    os << "nvgpu";
  } else if (target.arch == Target::Arch::Unk) {
    os << "unk";
  }
  os << ",";

  // Bit width.
  if (target.bits == Target::Bit::k32) {
    os << "32";
  } else if (target.bits == Target::Bit::k64) {
    os << "64";
  } else if (target.bits == Target::Bit::Unk) {
    os << "unk";
  }
  os << ">";
  return os;
}
// Prints the enumerator name of an architecture ("Unk", "X86", "ARM" or
// "NVGPU"). A value matching none of these writes nothing.
std::ostream &operator<<(std::ostream &os, Target::Arch arch) {
  if (arch == Target::Arch::Unk) {
    os << "Unk";
  } else if (arch == Target::Arch::X86) {
    os << "X86";
  } else if (arch == Target::Arch::ARM) {
    os << "ARM";
  } else if (arch == Target::Arch::NVGPU) {
    os << "NVGPU";
  }
  return os;
}
// Returns a function-local static target whose os, arch and bits are all Unk
// and whose two trailing constructor arguments are empty.
const Target &UnkTarget() {
  static const Target kUnknown(
      Target::OS::Unk, Target::Arch::Unk, Target::Bit::Unk, {}, {});
  return kUnknown;
}
// Returns a function-local static target describing 64-bit x86 Linux, with
// the two trailing constructor arguments left empty.
const Target &DefaultHostTarget() {
  static const Target kHost(
      Target::OS::Linux, Target::Arch::X86, Target::Bit::k64, {}, {});
  return kHost;
}
// Returns a function-local static target describing 64-bit NVGPU on Linux,
// with the two trailing constructor arguments left empty.
const Target &DefaultNVGPUTarget() {
  static const Target kNvGpu(
      Target::OS::Linux, Target::Arch::NVGPU, Target::Bit::k64, {}, {});
  return kNvGpu;
}
// With CINN_WITH_CUDA: returns (max threads per multiprocessor) *
// (multiprocessor count * 4) for device 0. Without CUDA: returns 1.
int GetMaxThreads() {
#ifdef CINN_WITH_CUDA
  // cudaDeviceGetAttribute ( int* value, cudaDeviceAttr attr, int device )
  int sm_count = 1;
  int threads_per_sm = 1;
  cudaDeviceGetAttribute(
      &sm_count, cudaDeviceAttr::cudaDevAttrMultiProcessorCount, 0);
  cudaDeviceGetAttribute(
      &threads_per_sm,
      cudaDeviceAttr::cudaDevAttrMaxThreadsPerMultiProcessor,
      0);
  // Scale the per-multiprocessor limit by four times the multiprocessor
  // count.
  return threads_per_sm * (sm_count * 4);
#else
  return 1;
#endif
}
// With CINN_WITH_CUDA: returns (max blocks per multiprocessor) *
// (multiprocessor count) for device 0. Without CUDA: returns 1.
int GetMaxBlocks() {
#ifdef CINN_WITH_CUDA
  // cudaDeviceGetAttribute ( int* value, cudaDeviceAttr attr, int device )
  int sm_count = 1;
  int blocks_per_sm = 1;
  cudaDeviceGetAttribute(
      &sm_count, cudaDeviceAttr::cudaDevAttrMultiProcessorCount, 0);
  cudaDeviceGetAttribute(
      &blocks_per_sm,
      cudaDeviceAttr::cudaDevAttrMaxBlocksPerMultiprocessor,
      0);
  // Scale the per-multiprocessor limit by the multiprocessor count.
  return blocks_per_sm * sm_count;
#else
  return 1;
#endif
}
// Picks the build's default target: the NVGPU target when compiled with
// CINN_WITH_CUDA, the host target otherwise.
const Target &DefaultTarget() {
#ifdef CINN_WITH_CUDA
  const Target &selected = DefaultNVGPUTarget();
#else
  const Target &selected = DefaultHostTarget();
#endif
  return selected;
}
} // namespace common
} // namespace cinn
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2021 CINN Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -15,9 +15,10 @@ ...@@ -15,9 +15,10 @@
#pragma once #pragma once
#include <ostream> #include <ostream>
#include <string>
#include <vector> #include <vector>
namespace infrt { namespace cinn {
namespace common { namespace common {
struct Target { struct Target {
...@@ -77,36 +78,42 @@ struct Target { ...@@ -77,36 +78,42 @@ struct Target {
return os != OS::Unk && arch != Arch::Unk && bits != Bit::Unk; return os != OS::Unk && arch != Arch::Unk && bits != Bit::Unk;
} }
//! Get the Runtime architecture, it is casted to integer to avoid header file
//! depending.
int runtime_arch() const;
int max_num_threads() const; int max_num_threads() const;
int get_multi_processor_count() const;
int get_max_threads_per_sm() const;
int get_max_blocks_per_sm() const;
int get_target_bits() const; int get_target_bits() const;
std::vector<Lib> get_target_libs() const; std::vector<Lib> get_target_libs() const;
std::string arch_str() const;
bool operator==(const Target& other) const; bool operator==(const Target& other) const;
bool operator!=(const Target& other) const { return !(*this == other); } bool operator!=(const Target& other) const { return !(*this == other); }
friend std::ostream& operator<<(std::ostream& os, const Target& target); friend std::ostream& operator<<(std::ostream& os, const Target& target);
}; };
static const Target& UnkTarget() { const Target& UnkTarget();
static Target target(
Target::OS::Unk, Target::Arch::Unk, Target::Bit::Unk, {}, {}); const Target& DefaultHostTarget();
return target;
} const Target& DefaultNVGPUTarget();
const Target& DefaultTarget();
static const Target& DefaultHostTarget() { int GetMaxThreads();
static Target target(
Target::OS::Linux, Target::Arch::X86, Target::Bit::k64, {}, {});
return target;
}
static const Target& DefaultNVGPUTarget() { int GetMaxBlocks();
static Target target(
Target::OS::Linux, Target::Arch::NVGPU, Target::Bit::k64, {}, {});
return target;
}
std::ostream& operator<<(std::ostream& os, Target::Arch arch); std::ostream& operator<<(std::ostream& os, Target::Arch arch);
} // namespace common } // namespace common
} // namespace infrt } // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/test_helper.h"
namespace cinn {
namespace common {
// Creates an x86 buffer with the configured element type, shape and
// alignment, allocates its memory and fills it according to the selected
// InitType (zeros, random values, or one constant).
//
// Supported element types: float, double, int8_t, int32_t, int64_t and bool.
// Anything else hits CINN_NOT_IMPLEMENTED.
//
// Returns the newly created buffer.
// NOTE(review): nothing here releases the buffer, so the caller presumably
// owns it - confirm against cinn_buffer_t's API.
cinn_buffer_t* BufferBuilder::Build() {
  // Translate the IR-level type into the runtime type descriptor.
  cinn_type_t cinn_type;
  if (type_ == type_of<float>()) {
    cinn_type = cinn_float32_t();
  } else if (type_ == type_of<double>()) {
    cinn_type = cinn_float64_t();
  } else if (type_ == type_of<int8_t>()) {
    cinn_type = cinn_int8_t();
  } else if (type_ == type_of<int32_t>()) {
    cinn_type = cinn_int32_t();
  } else if (type_ == type_of<int64_t>()) {
    cinn_type = cinn_int64_t();
  } else if (type_ == type_of<bool>()) {
    cinn_type = cinn_bool_t();
  } else {
    CINN_NOT_IMPLEMENTED
  }

  auto* buffer = cinn_buffer_t::new_(
      cinn_device_kind_t::cinn_x86_device, cinn_type, shape_, align_);
  cinn_buffer_malloc(nullptr, buffer);

  switch (init_type_) {
    case InitType::kZero:
      memset(buffer->memory, 0, buffer->memory_size);
      break;
    case InitType::kRandom:
      if (type_ == type_of<float>()) {
        RandomFloat<float>(buffer->memory, buffer->num_elements());
      } else if (type_ == type_of<double>()) {
        RandomFloat<double>(buffer->memory, buffer->num_elements());
      } else if (type_ == type_of<bool>()) {
        // bool buffers are filled through the int8_t generator.
        RandomInt<int8_t>(buffer->memory, buffer->num_elements());
      } else if (type_ == type_of<int8_t>()) {
        RandomInt<int8_t>(buffer->memory, buffer->num_elements());
      } else if (type_ == type_of<int32_t>()) {
        RandomInt<int32_t>(buffer->memory, buffer->num_elements());
      } else if (type_ == type_of<int64_t>()) {
        RandomInt<int64_t>(buffer->memory, buffer->num_elements());
      }
      break;
    case InitType::kSetValue:
      // init_val_ is stored as float and converted to the element type.
      if (type_ == type_of<int>()) {
        SetVal<int>(buffer->memory, buffer->num_elements(), init_val_);
      } else if (type_ == type_of<int8_t>()) {
        SetVal<int8_t>(buffer->memory, buffer->num_elements(), init_val_);
      } else if (type_ == type_of<int64_t>()) {
        // Previously unsupported although int64 buffers can be created,
        // zeroed and randomised above.
        SetVal<int64_t>(buffer->memory, buffer->num_elements(), init_val_);
      } else if (type_ == type_of<float>()) {
        SetVal<float>(buffer->memory, buffer->num_elements(), init_val_);
      } else if (type_ == type_of<double>()) {
        // Previously unsupported although double buffers can be created,
        // zeroed and randomised above.
        SetVal<double>(buffer->memory, buffer->num_elements(), init_val_);
      } else {
        CINN_NOT_IMPLEMENTED
      }
      break;
  }
  return buffer;
}
} // namespace common
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/cinn/backends/llvm/codegen_llvm.h"
#include "paddle/cinn/backends/llvm/simple_jit.h"
#include "paddle/cinn/cinn.h"
namespace cinn {
namespace common {
/**
* Create buffer for test.
*
* usage:
*
* auto* buf = BufferBuilder(Float(32), {20, 20}).set_random().Build();
*/
// Fluent builder for test buffers: pick an element type and shape, optionally
// choose how the contents are initialised and the alignment, then call
// Build(). By default the buffer is zero-filled and the alignment is 0.
struct BufferBuilder {
  // How Build() fills the buffer's memory.
  enum class InitType {
    kRandom = 0,
    kZero = 1,
    kSetValue = 2,
  };

  // The initializer list follows the member declaration order (shape_ before
  // type_), which is the order in which the members are actually initialised.
  explicit BufferBuilder(Type type, const std::vector<int>& shape)
      : shape_(shape), type_(type) {}

  // Fill the buffer with random values.
  BufferBuilder& set_random() {
    init_type_ = InitType::kRandom;
    return *this;
  }

  // Fill the buffer with zeros (the default).
  BufferBuilder& set_zero() {
    init_type_ = InitType::kZero;
    return *this;
  }

  // Fill every element of the buffer with `x`.
  BufferBuilder& set_val(float x) {
    init_type_ = InitType::kSetValue;
    init_val_ = x;
    return *this;
  }

  // Set the alignment handed to the buffer on creation.
  BufferBuilder& set_align(int align) {
    align_ = align;
    return *this;
  }

  // Creates, allocates and initialises the buffer (defined out of line).
  cinn_buffer_t* Build();

 private:
  // Writes `len` values in [0, 1] (rand() / RAND_MAX) to `arr`, which is
  // treated as an array of T.
  template <typename T>
  void RandomFloat(void* arr, uint64_t len) {
    auto* data = static_cast<T*>(arr);
    for (uint64_t i = 0; i < len; i++) {
      data[i] = static_cast<T>(rand()) / RAND_MAX;  // NOLINT
    }
  }

  // Writes `len` non-negative values below numeric_limits<T>::max() to `arr`,
  // which is treated as an array of T. The length is uint64_t, matching
  // RandomFloat.
  template <typename T>
  void RandomInt(void* arr, uint64_t len) {
    auto* data = static_cast<T*>(arr);
    for (uint64_t i = 0; i < len; i++) {
      data[i] =
          static_cast<T>(rand() % std::numeric_limits<T>::max());  // NOLINT
    }
  }

  // Writes `x` to each of the `len` elements of `arr`, which is treated as an
  // array of T. The length is uint64_t, matching RandomFloat.
  template <typename T>
  void SetVal(void* arr, uint64_t len, T x) {
    auto* data = static_cast<T*>(arr);
    for (uint64_t i = 0; i < len; i++) {
      data[i] = x;
    }
  }

 private:
  std::vector<int> shape_;
  InitType init_type_{InitType::kZero};
  float init_val_{};
  int align_{};
  Type type_;
};
// Collects values into a list of cinn_pod_value_t, one per Add() call.
struct ArgsBuilder {
// Appends a cinn_pod_value_t constructed from `x` and returns *this so calls
// can be chained.
template <typename T>
ArgsBuilder& Add(T x) {
data_.emplace_back(x);
return *this;
}
// Returns a copy of the collected values. At least one value must have been
// added; this is enforced with CHECK.
std::vector<cinn_pod_value_t> Build() {
CHECK(!data_.empty());
return data_;
}
private:
// Values in the order they were added.
std::vector<cinn_pod_value_t> data_;
};
} // namespace common
} // namespace cinn
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment