"platforms/cuda/vscode:/vscode.git/clone" did not exist on "d7b3a3c2ed7faf103152ff1fafbdca4a75c132ec"
Commit 992bec46 authored by “yuguo”'s avatar “yuguo”
Browse files

2.5

parent 0259837d
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/cinn_value.h"
#include <gtest/gtest.h>
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace cinn {
namespace common {
// Checks that scalar payloads (int and float) stored in a CINNValue can be
// read back through the value's conversion operators.
TEST(CINNValue, test) {
  {
    // Integer payload.
    CINNValue int_value(32);
    ASSERT_EQ(int(int_value), 32);  // NOLINT
  }
  {
    // Single-precision payload, compared with a small tolerance.
    CINNValue float_value(32.f);
    ASSERT_NEAR(float(float_value), 32.f, 1e-6);  // NOLINT
  }
}
// Checks that a null cinn_buffer_t pointer survives a round trip through
// CINNValue.
TEST(CINNValue, buffer) {
  cinn_buffer_t* null_buffer = nullptr;
  CINNValue wrapped(null_buffer);
  cinn_buffer_t* unwrapped = static_cast<cinn_buffer_t*>(wrapped);
  ASSERT_EQ(unwrapped, nullptr);
}
// Checks that an Expr wrapped in a CINNValue still compares equal to the
// original expression and to an equivalent constant.
TEST(CINNValue, Expr) {
  Expr one(1);

  {
    // Direct construction from the expression.
    CINNValue wrapped(one);
    ASSERT_TRUE(one == wrapped);
  }

  {
    // Construction from a temporary CINNValue.
    CINNValue copied = CINNValue(one);
    ASSERT_TRUE(copied == common::make_const(1));
  }
}
} // namespace common
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <absl/strings/string_view.h>
#include "paddle/cinn/common/axis.h"
#include "paddle/cinn/common/cinn_value.h"
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/common/graph_utils.h"
#include "paddle/cinn/common/macros.h"
#include "paddle/cinn/common/shared.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/common/type.h"
namespace cinn {
// export some general concepts.
using common::Context;
using common::make_shared;
using common::Object;
using common::ref_count;
using common::Shared;
using common::UniqName;
// Type related.
using common::Bool;
using common::Float;
using common::Int;
using common::UInt;
using common::Void;
using common::type_of;
using common::Target;
using common::Type;
using common::UnkTarget;
// Returns a mutable reference to the object that `x` points to by casting
// away the pointee's constness. `x` is dereferenced unconditionally.
template <typename T>
T& Reference(const T* x) {
  T* mutable_ptr = const_cast<T*>(x);
  return *mutable_ptr;
}
// Aborts (via CHECK) unless `name` is usable as a variable name:
//   * it must not be empty;
//   * it must not contain a space, '.', '@', '/', tab, newline or carriage
//     return;
//   * it must not be a name reserved for internal axes.
static void CheckVarNameValid(const absl::string_view name) {
  CHECK(!name.empty());

  // Every character that is rejected anywhere in a variable name.
  static const char kForbiddenChars[] = " .@/\t\n\r";
  // Report the offending name so the failure is actionable, matching the
  // reserved-name diagnostic below.
  CHECK(name.find_first_of(kForbiddenChars) == absl::string_view::npos)
      << "Some invalid character found in name [" << name << "]";

  CHECK(!common::IsAxisNameReserved(std::string(name)))
      << "The name [" << name << "] is reserved for internal axis";
}
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/context.h"
#include <glog/logging.h>
#include <isl/cpp.h>
#include <mutex>
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
namespace common {
namespace {
// Compile-time default for the runtime include directories: the value of the
// RUNTIME_INCLUDE_DIR macro when the build defines it, otherwise nullptr.
// The pointee must be `const char`: binding a string literal to a plain
// `char*` is ill-formed since C++11.
#ifdef RUNTIME_INCLUDE_DIR
constexpr const char* defined_runtime_include_dir = RUNTIME_INCLUDE_DIR;
#else
constexpr const char* defined_runtime_include_dir = nullptr;
#endif
}  // namespace
// Definitions of Context's static thread_local members: every thread gets its
// own instance of each. The isl context is obtained from isl_ctx_alloc().
thread_local isl::ctx Context::ctx_ = isl_ctx_alloc();
thread_local InfoRegistry Context::info_rgt_;
thread_local DebugManager Context::debug_mgr_;
// Returns the process-wide Context singleton. Every call also sets the
// calling thread's isl context to abort on error.
Context& Context::Global() {
  static Context instance;
  isl_options_set_on_error(ctx_.get(), ISL_ON_ERROR_ABORT);
  return instance;
}
// Returns the list of runtime include directories.
//
// When the list is still empty it is filled first: from the environment
// variable named by kRuntimeIncludeDirEnvironKey if that is set, otherwise
// from the compile-time RUNTIME_INCLUDE_DIR default if one exists. Either
// source is split on ':'. The fill happens while holding mutex_.
const std::vector<std::string>& Context::runtime_include_dir() {
  std::lock_guard<std::mutex> guard(mutex_);

  // Already populated (by an earlier call or by AddRuntimeIncludeDir).
  if (!runtime_include_dir_.empty()) {
    return runtime_include_dir_;
  }

  if (const char* from_env = std::getenv(kRuntimeIncludeDirEnvironKey)) {
    // use environment variable firstly
    VLOG(4) << "get runtime_include_dir from env: " << from_env;
    runtime_include_dir_ = cinn::utils::Split(from_env, ":");
  } else if (defined_runtime_include_dir) {
    VLOG(4) << "get runtime_include_dir from RUNTIME_INCLUDE_DIR: "
            << defined_runtime_include_dir;
    runtime_include_dir_ =
        cinn::utils::Split(defined_runtime_include_dir, ":");
  }

  return runtime_include_dir_;
}
void Context::AddRuntimeIncludeDir(std::string dir) {
// TODO(Shixiaowei02): path deduplication
runtime_include_dir_.emplace_back(std::move(dir));
}
// Name of the environment variable that Context::runtime_include_dir() reads
// to obtain the ':'-separated list of runtime include directories.
const char* kRuntimeIncludeDirEnvironKey = "runtime_include_dir";
// Produces a name based on `name_hint`.
//
// The first request for a hint returns the hint unchanged and records a
// counter for it. Each later request for the same hint increments that
// counter and returns "<hint>_<counter>". The counter is unsigned and seeded
// with -1, so the first suffix is 0. Serialized by mutex_.
std::string NameGenerator::New(const std::string& name_hint) {
  std::lock_guard<std::mutex> guard(mutex_);

  const auto found = name_hint_idx_.find(name_hint);
  if (found != name_hint_idx_.end()) {
    const auto suffix = ++found->second;
    return name_hint + "_" + std::to_string(suffix);
  }

  name_hint_idx_.emplace(name_hint, -1);
  return name_hint;
}
} // namespace common
// Defines the boolean flag `cinn_runtime_display_debug_info` (default: false).
DEFINE_bool(cinn_runtime_display_debug_info,
false,
"Whether to display debug information in runtime");
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <absl/types/any.h>
#include <gflags/gflags.h>
#include <isl/cpp.h>
#include <mutex>
#include <set>
#include <string>
#include <vector>
#include "paddle/cinn/common/debug_manager.h"
#include "paddle/cinn/common/info_registry.h"
#include "paddle/cinn/common/target.h"
namespace cinn {
DECLARE_bool(cinn_runtime_display_debug_info);
namespace ir {
class Expr;
} // namespace ir
namespace common {
extern const char* kRuntimeIncludeDirEnvironKey;
// Generates names from string hints, keeping one counter per hint.
// Both member functions take mutex_ before touching the counter table.
struct NameGenerator {
// Returns a name derived from `name_hint`; defined out of line.
std::string New(const std::string& name_hint);
// Reset id to initial.
// Forgets every hint seen so far by clearing the counter table.
void ResetID() {
std::lock_guard<std::mutex> lock(mutex_);
name_hint_idx_.clear();
}
private:
// Maps each hint that has been requested to its current counter value.
absl::flat_hash_map<std::string, uint32_t> name_hint_idx_;
// Guards name_hint_idx_.
mutable std::mutex mutex_;
};
// Holder of shared compiler state: a name generator, the list of runtime
// include directories, and per-thread isl / info-registry / debug-manager
// instances. Instances are only created through Global() (the constructor is
// private).
class Context {
public:
// Returns the singleton Context; defined out of line.
static Context& Global();
/**
* Generate a new unique name.
* @param name_hint The prefix.
* @return The name produced by the internal NameGenerator for this hint.
*/
std::string NewName(const std::string& name_hint) {
return name_generator_.New(name_hint);
}
// Resets the internal NameGenerator so previously seen hints are forgotten.
void ResetNameId() { name_generator_.ResetID(); }
// Returns the runtime include directories; defined out of line.
const std::vector<std::string>& runtime_include_dir();
// Appends `dir` to the runtime include directories; defined out of line.
void AddRuntimeIncludeDir(std::string dir);
/**
* The isl ctx of the calling thread (ctx_ is declared thread_local, so it is
* not shared between threads).
*/
static isl::ctx& isl_ctx() { return ctx_; }
// The calling thread's InfoRegistry (thread_local).
static InfoRegistry& info_rgt() { return info_rgt_; }
// The calling thread's DebugManager (thread_local).
static DebugManager& debug_mgr() { return debug_mgr_; }
private:
// Private so that Global() is the only way to obtain a Context.
Context() = default;
// Backs NewName() / ResetNameId().
NameGenerator name_generator_;
// Backs runtime_include_dir() / AddRuntimeIncludeDir().
std::vector<std::string> runtime_include_dir_;
mutable std::mutex mutex_;
// One instance per thread of each of the following.
static thread_local isl::ctx ctx_;
static thread_local InfoRegistry info_rgt_;
static thread_local DebugManager debug_mgr_;
};
static std::string UniqName(const std::string& prefix) {
return Context::Global().NewName(prefix);
}
} // namespace common
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
namespace cinn {
namespace auto_schedule {
/**
 * A C++ cost model virtual base class.
 *
 * Samples are passed as one feature vector per sample; labels are one float
 * per sample.
 */
class CostModel {
 public:
  // Virtual so that a concrete model can be destroyed safely through a
  // CostModel pointer.
  virtual ~CostModel() = default;

  // Trains the model on `samples` and their `labels`.
  virtual void Train(const std::vector<std::vector<float>>& samples,
                     const std::vector<float>& labels) = 0;

  // Returns the model's predictions for `samples`.
  virtual std::vector<float> Predict(
      const std::vector<std::vector<float>>& samples) const = 0;

  // Updates the model with additional `samples` and `labels`.
  virtual void Update(const std::vector<std::vector<float>>& samples,
                      const std::vector<float>& labels) = 0;

  // Saves the model to `path`.
  virtual void Save(const std::string& path) = 0;

  // Loads the model from `path`.
  virtual void Load(const std::string& path) = 0;
};
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/cuda_test_helper.h"
#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/nvrtc/nvrtc_util.h"
#include "paddle/cinn/runtime/cuda/cuda_module.h"
#include "paddle/cinn/runtime/cuda/cuda_util.h"
namespace cinn {
namespace common {
#ifdef CINN_WITH_CUDA
// Compiles module `m` for testing.
//
// The module is split into a host part and a device part (both must contain
// at least one function). The device part is turned into CUDA source, which
// is compiled to PTX; if `rewrite_cuda_code` is non-empty it is compiled
// instead of the generated source. Every device function is then looked up in
// the loaded CUDA module and the address of its stored handle is registered
// under "<kernel name>_ptr_". Finally the host part is linked into a fresh
// SimpleJIT.
void CudaModuleTester::Compile(const ir::Module& m,
                               const std::string& rewrite_cuda_code) {
  auto _host_module_device_module_ =
      backends::SplitCudaAndHostModule(m);  // NOLINT
  auto& host_module = std::get<0>(_host_module_device_module_);
  auto& device_module = std::get<1>(_host_module_device_module_);
  CHECK(!host_module.functions().empty());
  CHECK(!device_module.functions().empty());

  backends::CodeGenCUDA_Dev codegen(DefaultHostTarget());
  auto source_code = codegen.Compile(device_module);

  // compile CUDA kernel.
  backends::nvrtc::Compiler compiler;
  std::string ptx;
  if (rewrite_cuda_code.empty()) {
    ptx = compiler(source_code);
  } else {
    ptx = compiler(rewrite_cuda_code);
  }

  auto* cuda_module =
      new runtime::cuda::CUDAModule(ptx, runtime::cuda::CUDAModule::Kind::PTX);
  cuda_module_ = cuda_module;

  const auto& kernels = device_module.functions();
  // The registry stores the ADDRESS of each element of kernel_handles_, so
  // the vector must not reallocate while this loop appends to it; otherwise
  // the addresses registered so far would dangle. Reserve all slots up front.
  // NOTE(review): a later Compile() call on the same tester can still
  // reallocate the vector and invalidate addresses registered earlier.
  kernel_handles_.reserve(kernel_handles_.size() + kernels.size());
  for (const auto& fn : kernels) {
    std::string kernel_fn_name = fn->name;
    auto fn_kernel = cuda_module->GetFunction(0, kernel_fn_name);
    CHECK(fn_kernel);
    kernel_handles_.push_back(fn_kernel);
    backends::GlobalSymbolRegistry::Global().RegisterFn(
        kernel_fn_name + "_ptr_",
        reinterpret_cast<void*>(&kernel_handles_.back()));
  }

  jit_ = backends::SimpleJIT::Create();
  // compile host module
  jit_->Link<backends::CodeGenCUDA_Host>(host_module, false);
}
void* CudaModuleTester::CreateDeviceBuffer(const cinn_buffer_t* host_buffer) {
CHECK(host_buffer->memory);
int num_bytes = host_buffer->num_elements() * sizeof(float);
CUdeviceptr data;
cuMemAlloc(&data, num_bytes);
CUDA_CALL(cudaMemcpy(reinterpret_cast<void*>(data),
host_buffer->memory,
num_bytes,
cudaMemcpyHostToDevice));
return reinterpret_cast<void*>(data);
}
// Nothing to set up: all members rely on their in-class initializers.
CudaModuleTester::CudaModuleTester() = default;
// Calls the host function named `fn_name` from the JIT, passing it `args`,
// `arg_num` and the tester's stream.
//
// Aborts (CHECK) with a descriptive message if no JIT exists yet or if the
// function cannot be found, instead of calling through a null pointer.
void CudaModuleTester::operator()(const std::string& fn_name,
                                  void* args,
                                  int arg_num) {
  CHECK(jit_ != nullptr)
      << "Compile() must be called before invoking [" << fn_name << "]";
  auto fn = jit_->Lookup(fn_name);
  CHECK(fn) << "Host function [" << fn_name << "] not found in the JIT";
  auto fnp = reinterpret_cast<lower_func_ptr_g>(fn);
  (*fnp)(args, arg_num, stream_);
}
// Returns the handle of the kernel called `name` from the loaded CUDA module
// (queried with 0 as the first argument of GetFunction).
void* CudaModuleTester::LookupKernel(const std::string& name) {
  auto* module = static_cast<runtime::cuda::CUDAModule*>(cuda_module_);
  return module->GetFunction(0, name);
}
// Releases the CUDA module created by Compile(), if any.
CudaModuleTester::~CudaModuleTester() {
  // Deleting a null pointer is a no-op, so no explicit null test is needed.
  delete static_cast<runtime::cuda::CUDAModule*>(cuda_module_);
}
#endif
} // namespace common
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/cinn/backends/llvm/codegen_llvm.h"
#include "paddle/cinn/backends/llvm/simple_jit.h"
#include "paddle/cinn/cinn.h"
namespace cinn {
namespace common {
#ifdef CINN_WITH_CUDA
// Test helper that compiles an ir::Module into a CUDA device part and a
// JIT-compiled host part, and lets a test invoke the resulting host functions.
class CudaModuleTester {
public:
CudaModuleTester();
// Call the host function in JIT.
void operator()(const std::string& fn_name, void* args, int arg_num);
// Compiles `m`. A non-empty `rewrite_cuda_code` is compiled in place of the
// CUDA source generated from the module.
void Compile(const ir::Module& m, const std::string& rewrite_cuda_code = "");
// Returns the handle of the kernel called `name` in the compiled CUDA module.
void* LookupKernel(const std::string& name);
// Copies `host_buffer` to newly allocated device memory and returns the
// device address.
void* CreateDeviceBuffer(const cinn_buffer_t* host_buffer);
~CudaModuleTester();
private:
// JIT holding the compiled host module; created by Compile().
std::unique_ptr<backends::SimpleJIT> jit_;
// Stream passed to host functions; value-initialized (null).
void* stream_{};
// Kernel handles obtained in Compile(). The addresses of the elements are
// registered as global symbols there.
std::vector<void*> kernel_handles_;
// Type-erased pointer to the runtime::cuda::CUDAModule allocated in
// Compile(); deleted by the destructor.
void* cuda_module_{nullptr};
};
#endif
} // namespace common
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/debug_manager.h"
namespace cinn {
namespace common {
// Returns the vector of (key, value) entries stored inside `data`.
//
// DebugManager default-constructs its absl::any, which leaves it empty, and
// absl::any_cast on an empty any fails. The vector is therefore created on
// first access. If `data` already holds a value of a different type the cast
// still fails as before.
inline std::vector<std::pair<std::string, absl::any>> &GetVec(
    absl::any &data) {  // NOLINT
  using Entries = std::vector<std::pair<std::string, absl::any>>;
  if (!data.has_value()) {
    data = Entries{};
  }
  return absl::any_cast<Entries &>(data);
}
//! AppendTypeSuffix for multiple types.
// @{
// int32_t values: key gets the "_i32" suffix.
template <>
inline std::string DebugManager::AppendTypeSuffix<int32_t>(
    const std::string &key) {
  std::string suffixed(key);
  suffixed += "_i32";
  return suffixed;
}

// int64_t values: key gets the "_i64" suffix.
template <>
inline std::string DebugManager::AppendTypeSuffix<int64_t>(
    const std::string &key) {
  std::string suffixed(key);
  suffixed += "_i64";
  return suffixed;
}

// float values: key gets the "_f32" suffix.
template <>
inline std::string DebugManager::AppendTypeSuffix<float>(
    const std::string &key) {
  std::string suffixed(key);
  suffixed += "_f32";
  return suffixed;
}

// double values: key gets the "_f64" suffix.
template <>
inline std::string DebugManager::AppendTypeSuffix<double>(
    const std::string &key) {
  std::string suffixed(key);
  suffixed += "_f64";
  return suffixed;
}

// bool values: key gets the "_b" suffix.
template <>
inline std::string DebugManager::AppendTypeSuffix<bool>(
    const std::string &key) {
  std::string suffixed(key);
  suffixed += "_b";
  return suffixed;
}

// std::string values: key gets the "_s" suffix.
template <>
inline std::string DebugManager::AppendTypeSuffix<std::string>(
    const std::string &key) {
  std::string suffixed(key);
  suffixed += "_s";
  return suffixed;
}
// @}
// Appends (key, value) as-is, without adding a type suffix to the key.
void DebugManager::Append(const std::string &key, absl::any value) {
  auto &entries = GetVec(data_);
  entries.emplace_back(key, value);
}

// Appends an int32_t value under the key with the int32_t type suffix.
void DebugManager::Append(const std::string &key, int32_t value) {
  auto &entries = GetVec(data_);
  entries.emplace_back(AppendTypeSuffix<int32_t>(key), value);
}

// Appends a bool value under the key with the bool type suffix.
void DebugManager::Append(const std::string &key, bool value) {
  auto &entries = GetVec(data_);
  entries.emplace_back(AppendTypeSuffix<bool>(key), value);
}

// Appends a string value under the key with the std::string type suffix.
void DebugManager::Append(const std::string &key, const std::string &value) {
  auto &entries = GetVec(data_);
  entries.emplace_back(AppendTypeSuffix<std::string>(key), value);
}
// Removes every recorded entry.
void DebugManager::Clear() {
  auto &entries = GetVec(data_);
  entries.clear();
}
} // namespace common
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <absl/types/any.h>
#include <string>
#include <utility>
#include <vector>
namespace cinn {
namespace common {
/**
* Container for debug info.
* DebugManager is integrated into the global Context, and used to log
* something(but not print to stdout directly).
*/
// Collects (key, value) debug entries. The public Append overloads accept
// int32_t, bool and std::string values.
class DebugManager {
public:
// Record `value` under `key`; one overload per supported value type.
void Append(const std::string& key, int32_t value);
void Append(const std::string& key, bool value);
void Append(const std::string& key, const std::string& value);
// Remove all recorded entries.
void Clear();
protected:
// Record an already type-erased value under `key`.
void Append(const std::string& key, absl::any value);
// Returns the key to store a value of type T under. This primary template
// returns `key` unchanged.
template <typename T>
inline std::string AppendTypeSuffix(const std::string& key) {
return key;
}
private:
//! hide the type of vector<pair<string, any>>
// No constructor is declared, so this member starts out default-constructed.
absl::any data_;
};
} // namespace common
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <array>
#include <functional>
#include <stack>
#include <unordered_map>
#include <unordered_set>
namespace cinn {
namespace common {
// DFS Topological order walker.
// Try to walk in a depth first manner while ensuring topological order:
// a node is handled only after all of its predecessors have been handled.
// For example:
// Graph:
// 0 -> 1
// 2 -> 3
// 0 -> 3
// 1 -> 3
// 3 -> 4
// Start nodes: 0, 2
// Walking order: 0 -> 1 -> 2 -> 3 -> 4
//
// NodeType must be copyable and usable as a key with NodeHash / NodeEqual.
template <typename NodeType,
          typename NodeHash = std::hash<NodeType>,
          typename NodeEqual = std::equal_to<NodeType>>
class DfsTopoWalker final {
 public:
  DfsTopoWalker(const DfsTopoWalker&) = delete;
  DfsTopoWalker(DfsTopoWalker&&) = delete;

  // Callback invoked with a single node.
  using NodeHandlerType = std::function<void(NodeType)>;
  // Callback that enumerates the neighbours of its first argument by calling
  // its second argument once per neighbour.
  using NodesVisitorType =
      std::function<void(NodeType, const NodeHandlerType&)>;

  // VisitPreNodes enumerates the predecessors of a node, VisitNextNodes its
  // successors. Members are initialized in declaration order.
  DfsTopoWalker(const NodesVisitorType& VisitPreNodes,
                const NodesVisitorType& VisitNextNodes)
      : VisitNextNodes_(VisitNextNodes), VisitPreNodes_(VisitPreNodes) {}

  // Start walking from 1 node and make every effort to access all nodes that
  // meet the walking rules.
  // If there are more than 1 nodes with a degree of 0 in a graph,
  // only one part will be accessed.
  // If you want to access the entire graph,
  // you need to provide all starting nodes.
  void operator()(NodeType node, const NodeHandlerType& NodeHandler) const {
    std::array<NodeType, 1> nodes{node};
    (*this)(nodes.begin(), nodes.end(), NodeHandler);
  }

  // Start walking from a collection of node and make every effort to access all
  // nodes that meet the walking rules.
  // If there are other start nodes in a graph,
  // some nodes on the graph will not be accessed.
  // If you want to access the entire graph,
  // you need to provide all starting nodes.
  // NodeHandler is called exactly once for every node that gets walked.
  template <typename NodeIt>
  void operator()(NodeIt begin,
                  NodeIt end,
                  const NodeHandlerType& NodeHandler) const {
    std::stack<NodeType> node_stack;
    std::unordered_set<NodeType, NodeHash, NodeEqual> visited;
    // Number of predecessors of each node that have not been handled yet.
    std::unordered_map<NodeType, int, NodeHash, NodeEqual> in_degree;

    // Returns the pending in-degree of `node`, counting its predecessors the
    // first time the node is seen. References into an unordered_map stay
    // valid when other keys are inserted, so returning int& is safe.
    const auto PendingInDegree = [&](NodeType node) -> int& {
      const auto inserted = in_degree.emplace(node, 0);
      int& degree = inserted.first->second;
      if (inserted.second) {
        VisitPreNodes_(node, [&degree](NodeType) { ++degree; });
      }
      return degree;
    };
    // One predecessor of `node` has just been handled.
    const auto UpdateInDegree = [&](NodeType node) {
      --PendingInDegree(node);
    };
    // Schedules `node` once all of its predecessors have been handled.
    const auto TryPush = [&](NodeType node) {
      if (PendingInDegree(node) == 0 && visited.insert(node).second) {
        node_stack.push(node);
      }
    };

    for (NodeIt iter = begin; iter != end; ++iter) {
      TryPush(*iter);
      while (!node_stack.empty()) {
        NodeType cur = node_stack.top();
        node_stack.pop();
        NodeHandler(cur);
        // First release all successors, then schedule those that became ready.
        VisitNextNodes_(cur, UpdateInDegree);
        VisitNextNodes_(cur, TryPush);
      }
    }
  }

 private:
  NodesVisitorType VisitNextNodes_;
  NodesVisitorType VisitPreNodes_;
};
} // namespace common
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/cinn/common/dfs_topo_walker.h"
namespace cinn {
namespace common {
// Walks the graph 0->1, 2->3, 1->3, 0->3, 3->4 from the sources {0, 2} and
// expects the nodes to be handled in the order 0, 1, 2, 3, 4.
TEST(DfsTopoWalker, simple) {
  using Edge = std::pair<int, int>;
  using IntHandler = std::function<void(int)>;

  std::vector<Edge> edges{{0, 1}, {2, 3}, {1, 3}, {0, 3}, {3, 4}};

  // Reports every node that has an edge into `node`.
  const auto visit_predecessors = [&](int node, const IntHandler& handle) {
    for (const Edge& edge : edges) {
      if (edge.second == node) {
        handle(edge.first);
      }
    }
  };
  // Reports every node that `node` has an edge to.
  const auto visit_successors = [&](int node, const IntHandler& handle) {
    for (const Edge& edge : edges) {
      if (edge.first == node) {
        handle(edge.second);
      }
    }
  };
  DfsTopoWalker<int> walker(visit_predecessors, visit_successors);

  std::vector<int> sources{0, 2};
  std::vector<int> outputs;
  const auto record = [&](int node) { outputs.push_back(node); };
  walker(sources.begin(), sources.end(), record);

  for (size_t i = 0; i < outputs.size(); ++i) {
    LOG(INFO) << outputs[i];
  }

  const std::vector<int> expected{0, 1, 2, 3, 4};
  EXPECT_TRUE((outputs == expected));
}
} // namespace common
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <array>
#include <functional>
#include <iostream>
#include <queue>
#include <stack>
#include <unordered_set>
namespace cinn {
namespace common {
// depth-first search visitor
//
// Walks a graph depth first starting from one or more nodes. Each reachable
// node is discovered at most once per walk. Two callbacks are available: one
// fired when a node is discovered (pushed) and one fired when all of its
// successors have been processed (popped).
template <typename NodeType>
class DfsWalker final {
 public:
  DfsWalker(const DfsWalker&) = delete;
  DfsWalker(DfsWalker&&) = delete;

  // Callback invoked with a single node.
  using NodeHandlerType = std::function<void(NodeType)>;
  // Callback that enumerates the successors of its first argument by calling
  // its second argument once per successor.
  using NodesVisitorType =
      std::function<void(NodeType, const NodeHandlerType&)>;

  DfsWalker(const NodesVisitorType& VisitNextNodes)
      : VisitNextNodes_(VisitNextNodes) {}

  // Walks from a single start node, calling NodeHandler on discovery.
  void operator()(NodeType node, const NodeHandlerType& NodeHandler) const {
    std::array<NodeType, 1> nodes{node};
    (*this)(nodes.begin(), nodes.end(), NodeHandler, [](NodeType) {});
  }

  // Walks from every node in [begin, end), calling NodeHandler on discovery.
  template <typename NodeIt>
  void operator()(NodeIt begin,
                  NodeIt end,
                  const NodeHandlerType& NodeHandler) const {
    (*this)(begin, end, NodeHandler, [](NodeType) {});
  }

  // https://en.wikipedia.org/wiki/Depth-first_search
  // Walks from every node in [begin, end). NodeHandlerOnPush is called when a
  // node is first discovered; NodeHandlerOnPop is called once all of the
  // node's successors have been processed.
  template <typename NodeIt>
  void operator()(NodeIt begin,
                  NodeIt end,
                  const NodeHandlerType& NodeHandlerOnPush,
                  const NodeHandlerType& NodeHandlerOnPop) const {
    std::unordered_set<NodeType> discovered;
    // A node on the DFS stack together with its not-yet-processed successors.
    struct Neighbours {
      NodeType producer;
      std::queue<NodeType> consumers;
    };
    std::stack<Neighbours> stack;

    const auto TryPush = [&](NodeType node) {
      // insert() reports whether the node was new, so one lookup suffices.
      if (!discovered.insert(node).second) {
        return;
      }
      NodeHandlerOnPush(node);
      // Plain aggregate initialization: designated initializers are only
      // standard from C++20 on.
      stack.push(Neighbours{node, std::queue<NodeType>()});
      VisitNextNodes_(node, [&](NodeType next_node) {
        stack.top().consumers.push(next_node);
      });
    };

    for (NodeIt node_iter = begin; node_iter != end; ++node_iter) {
      TryPush(*node_iter);
      while (!stack.empty()) {
        // std::stack is backed by std::deque by default, whose push_back
        // keeps pointers to existing elements valid, so `neighbours` survives
        // the TryPush below.
        auto* neighbours = &stack.top();
        if (neighbours->consumers.empty()) {
          NodeHandlerOnPop(neighbours->producer);
          stack.pop();
        } else {
          TryPush(neighbours->consumers.front());
          neighbours->consumers.pop();
        }
      }
    }
  }

 private:
  NodesVisitorType VisitNextNodes_;
};
} // namespace common
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/dfs_walker.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
namespace cinn {
namespace common {
// Graph: 0 -> 3, 1 -> {2, 3}, 2 -> 4, 3 -> 4. Starting from {0, 1}, the
// on-push (discovery) order is expected to be 0, 3, 4, 1, 2.
TEST(DfsWalker, simple_on_push) {
  using IntHandler = std::function<void(int)>;

  // Enumerates the successors of `node` in the graph described above.
  const auto visit_successors = [](int node, const IntHandler& handle) {
    switch (node) {
      case 0:
        handle(3);
        break;
      case 1:
        handle(2);
        handle(3);
        break;
      case 2:
      case 3:
        handle(4);
        break;
      default:
        break;
    }
  };
  DfsWalker<int> visitor(visit_successors);

  std::vector<int> sources{0, 1};
  std::vector<int> outputs;
  const auto record_on_push = [&](int node) {
    LOG(ERROR) << node;
    outputs.push_back(node);
  };
  visitor(sources.begin(), sources.end(), record_on_push);

  const std::vector<int> expected{0, 3, 4, 1, 2};
  EXPECT_TRUE((outputs == expected));
}
// Graph: 0 -> 3, 1 -> {2, 3}, 2 -> 4, 3 -> 4. Starting from {0, 1}, the
// on-pop (finish) order is expected to be 4, 3, 0, 2, 1.
TEST(DfsWalker, simple_on_pop) {
  using IntHandler = std::function<void(int)>;

  // Enumerates the successors of `node` in the graph described above.
  const auto visit_successors = [](int node, const IntHandler& handle) {
    switch (node) {
      case 0:
        handle(3);
        break;
      case 1:
        handle(2);
        handle(3);
        break;
      case 2:
      case 3:
        handle(4);
        break;
      default:
        break;
    }
  };
  DfsWalker<int> visitor(visit_successors);

  std::vector<int> sources{0, 1};
  std::vector<int> outputs;
  const auto ignore_on_push = [](int) {};
  const auto record_on_pop = [&](int node) {
    LOG(ERROR) << node;
    outputs.push_back(node);
  };
  visitor(sources.begin(), sources.end(), ignore_on_push, record_on_pop);

  const std::vector<int> expected{4, 3, 0, 2, 1};
  EXPECT_TRUE((outputs == expected));
}
} // namespace common
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CINN_COMMON_FLOAT16_H
#define CINN_COMMON_FLOAT16_H
#ifdef __cplusplus
#pragma once
#endif // __cplusplus
#if defined(_M_X64) || defined(__x86_64__) || defined(_M_IX86) || \
defined(__i386__)
#define __CINN_x86__
#include <immintrin.h>
#endif
#include <stdint.h>
#include <cmath>
#ifdef CINN_WITH_CUDA
#include <cuda.h>
#if (defined(__CUDACC__) || defined(__CUDACC_RTC__)) && CUDA_VERSION >= 7050
#define CINN_CUDA_FP16
#include <cuda_fp16.h>
#define CUDA_ARCH_FP16_SUPPORTED(CUDA_ARCH) (CUDA_ARCH >= 600)
#endif // __CUDACC__
#endif // CINN_WITH_CUDA
#ifdef __cplusplus
#ifndef _WIN32
#define CINN_ALIGN(x) __attribute__((aligned(x)))
#else // _WIN32
#define CINN_ALIGN(x) __declspec(align(x))
#endif // _WIN32
#else // __cplusplus
#define CINN_ALIGN(x)
#endif // __cplusplus
// The `HOST` macro definition is not used here, it has a potential
// conflict with the enumeration `kHOST` representing the backend.
#ifndef __host__
#define __host__
#endif
#ifndef __device__
#define __device__
#endif
#ifdef __cplusplus
namespace cinn {
namespace common {
#endif // __cplusplus
// Use CINN_ALIGN(2) to ensure that each float16 will be allocated
// and aligned at least on a 2-byte boundary, which leads to efficient
// memory access of the float16 struct and also makes float16 compatible
// with CUDA half.
struct CINN_ALIGN(2) float16 {
uint16_t x;
#ifdef __cplusplus
// The following defaulted special class member functions
// are added to make float16 pass the std::is_trivial test
float16() = default;
float16(const float16& o) = default;
float16& operator=(const float16& o) = default;
float16(float16&& o) = default;
float16& operator=(float16&& o) = default;
~float16() = default;
// Constructors
#ifdef CINN_CUDA_FP16
__host__ __device__ inline explicit float16(const half& h) {
#if (CUDA_VERSION >= 9000)
x = reinterpret_cast<__half_raw*>(const_cast<half*>(&h))->x;
#else
x = h.x;
#endif // CUDA_VERSION >= 9000
}
#endif // CINN_CUDA_FP16
__host__ __device__ inline explicit float16(float val) {
#if defined(CINN_CUDA_FP16) && (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300)
half tmp = __float2half(val);
x = *reinterpret_cast<uint16_t*>(&tmp);
#elif defined(__F16C__) && defined(__CINN_x86__)
x = _cvtss_sh(val, 0);
#else
// Conversion routine adapted from
// http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
Bits v, s;
v.f = val;
uint32_t sign = v.si & sigN;
v.si ^= sign;
sign >>= shiftSign; // logical shift
s.si = mulN;
s.si = s.f * v.f; // correct subnormals
v.si ^= (s.si ^ v.si) & -(minN > v.si);
v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
v.ui >>= shift; // logical shift
v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
x = v.ui | sign;
#endif
}
__host__ __device__ inline explicit float16(bool b) : x(b ? 0x3c00 : 0) {}
template <class T>
__host__ __device__ inline explicit float16(const T& val)
: x(float16(static_cast<float>(val)).x) {}
// Assignment operators
#ifdef CINN_CUDA_FP16
__host__ __device__ inline float16& operator=(const half& rhs) {
#if CUDA_VERSION >= 9000
x = reinterpret_cast<__half_raw*>(const_cast<half*>(&rhs))->x;
#else
x = rhs.x;
#endif
return *this;
}
#endif
__host__ __device__ inline float16& operator=(bool b) {
x = b ? 0x3c00 : 0;
return *this;
}
__host__ __device__ inline float16& operator=(int8_t val) {
x = float16(val).x;
return *this;
}
__host__ __device__ inline float16& operator=(uint8_t val) {
x = float16(val).x;
return *this;
}
__host__ __device__ inline float16& operator=(int16_t val) {
x = float16(val).x;
return *this;
}
__host__ __device__ inline float16& operator=(uint16_t val) {
x = float16(val).x;
return *this;
}
__host__ __device__ inline float16& operator=(int32_t val) {
x = float16(val).x;
return *this;
}
__host__ __device__ inline float16& operator=(uint32_t val) {
x = float16(val).x;
return *this;
}
__host__ __device__ inline float16& operator=(int64_t val) {
x = float16(val).x;
return *this;
}
__host__ __device__ inline float16& operator=(uint64_t val) {
x = float16(val).x;
return *this;
}
__host__ __device__ inline float16& operator=(float val) {
x = float16(val).x;
return *this;
}
__host__ __device__ inline float16& operator=(double val) {
x = float16(val).x;
return *this;
}
// Conversion operators
#ifdef CINN_CUDA_FP16
// Return a CUDA `half` carrying the same raw 16 bits as this float16.
// Only compiled when CINN_CUDA_FP16 is defined.
__host__ __device__ inline half to_half() const {
#if CUDA_VERSION >= 9000
// CUDA >= 9.0: fill a __half_raw and construct the half from it.
__half_raw h;
h.x = x;
return half(h);
#else
// Older toolkits: write the `x` member of `half` directly.
half h;
h.x = x;
return h;
#endif  // CUDA_VERSION >= 9000
}
#endif  // CINN_CUDA_FP16
// Implicit widening to float. Three implementations are selected at compile
// time: the __half2float intrinsic in CUDA device code (CINN_CUDA_FP16 and
// __CUDA_ARCH__ >= 300), the _cvtsh_ss intrinsic when __F16C__ is defined,
// and otherwise a portable integer bit-manipulation routine.
__host__ __device__ inline operator float() const {
#if defined(CINN_CUDA_FP16) && (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300)
// Reinterpret this object's storage as a CUDA half and let the device convert.
half tmp = *reinterpret_cast<const half*>(this);
return __half2float(tmp);
#elif defined(__F16C__)
return _cvtsh_ss(this->x);
#else
// Conversion routine adapted from
// http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
Bits v;
v.ui = this->x;
// Isolate the sign bit, clear it from the working value, and move it up by
// shiftSign so it can be OR-ed back into the 32-bit result at the end.
int32_t sign = v.si & sigC;
v.si ^= sign;
sign <<= shiftSign;
// Branch-free selects: `-(cond)` is an all-ones mask when cond is true, so
// each line applies the exponent adjustment only when the comparison holds.
v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
// `s` holds the alternative result built by scaling mulC; it replaces `v`
// only when the pre-shift value is below norC (mask computed before shifting).
Bits s;
s.si = mulC;
s.f *= v.si;
int32_t mask = -(norC > v.si);
v.si <<= shift;
v.si ^= (s.si ^ v.si) & mask;
v.si |= sign;
return v.f;
#endif
}
// True when any bit other than the top (sign) bit is set, so both zero
// patterns (0x0000 and 0x8000) convert to false.
__host__ __device__ inline explicit operator bool() const {
  const auto without_sign = x & 0x7fff;
  return without_sign != 0;
}
// Explicit conversions to integer types. Each one widens to float first and
// then casts that float to the requested integer type.
__host__ __device__ inline explicit operator int8_t() const {
  const float widened = static_cast<float>(*this);
  return static_cast<int8_t>(widened);
}
__host__ __device__ inline explicit operator uint8_t() const {
  const float widened = static_cast<float>(*this);
  return static_cast<uint8_t>(widened);
}
__host__ __device__ inline explicit operator int16_t() const {
  const float widened = static_cast<float>(*this);
  return static_cast<int16_t>(widened);
}
__host__ __device__ inline explicit operator uint16_t() const {
  const float widened = static_cast<float>(*this);
  return static_cast<uint16_t>(widened);
}
__host__ __device__ inline explicit operator int32_t() const {
  const float widened = static_cast<float>(*this);
  return static_cast<int32_t>(widened);
}
__host__ __device__ inline explicit operator uint32_t() const {
  const float widened = static_cast<float>(*this);
  return static_cast<uint32_t>(widened);
}
__host__ __device__ inline explicit operator int64_t() const {
  const float widened = static_cast<float>(*this);
  return static_cast<int64_t>(widened);
}
__host__ __device__ inline explicit operator uint64_t() const {
  const float widened = static_cast<float>(*this);
  return static_cast<uint64_t>(widened);
}
// Implicit conversion to double, going through float.
__host__ __device__ inline operator double() const {
  const float widened = static_cast<float>(*this);
  return static_cast<double>(widened);
}
private:
// Scratch union used by the software conversion routines to view the same
// 32 bits as a float, a signed integer, or an unsigned integer.
union Bits {
float f;
int32_t si;
uint32_t ui;
};
// Constants for the software float32 <-> float16 conversion paths.
// infC/maxC/minC are infN/maxN/minN shifted right by `shift`; sigC is sigN
// shifted right by `shiftSign`.
static const int shift = 13;
static const int shiftSign = 16;
static const int32_t infN = 0x7F800000;
static const int32_t maxN = 0x477FE000;  // max flt16 as flt32
static const int32_t minN = 0x38800000;  // min flt16 normal as flt32
static const int32_t sigN = 0x80000000;  // sign bit
static constexpr int32_t infC = infN >> shift;
static constexpr int32_t nanN = (infC + 1)
<< shift;  // minimum flt16 nan as float32
static constexpr int32_t maxC = maxN >> shift;
static constexpr int32_t minC = minN >> shift;
static constexpr int32_t sigC = sigN >> shiftSign;
static const int32_t mulN = 0x52000000;  // (1 << 23) / minN
static const int32_t mulC = 0x33800000;  // minN / (1 << (23 - shift))
static const int32_t subC = 0x003FF;     // max flt32 subnormal downshifted
static const int32_t norC = 0x00400;     // min flt32 normal downshifted
// Exponent re-bias deltas added/subtracted by the conversion routines.
static constexpr int32_t maxD = infC - maxC - 1;
static constexpr int32_t minD = minC - subC - 1;
#endif // __cplusplus
};
// Eight packed floats, aligned via CINN_ALIGN(32).
struct CINN_ALIGN(32) float8 {
  float x;
  float y;
  float z;
  float w;
  float v;
  float u;
  float t;
  float s;
};
// Eight packed float16 values, aligned via CINN_ALIGN(16).
struct CINN_ALIGN(16) half8 {
  float16 x;
  float16 y;
  float16 z;
  float16 w;
  float16 v;
  float16 u;
  float16 t;
  float16 s;
};
// Four packed float16 values, aligned via CINN_ALIGN(8).
struct CINN_ALIGN(8) half4 {
  float16 x;
  float16 y;
  float16 z;
  float16 w;
};
#ifdef __cplusplus
// Arithmetic operators on GPU
// CUDA 9.0 provides built-in arithmetic operators for half while
// CUDA 7.5 and 8.0 do not. The arithmetic operators defined here are
// for users to write similar CUDA code in CUDA 7.5 and 8.0 as in
// CUDA 9.0 regarding the half data type.
// ROCm provides built-in arithmetic operators as long as
// __HIP_NO_HALF_OPERATORS__ is not defined.
#if defined(CINN_CUDA_FP16) && CUDA_VERSION < 9000
// Binary arithmetic on CUDA `half`. On __CUDA_ARCH__ >= 530 the half
// intrinsics are used (division goes through float); otherwise both operands
// are widened to float via float16 and the result is narrowed back.
__device__ inline half operator+(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hadd(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return float16(lhs + rhs).to_half();
#endif
}
__device__ inline half operator-(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hsub(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return float16(lhs - rhs).to_half();
#endif
}
__device__ inline half operator*(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hmul(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return float16(lhs * rhs).to_half();
#endif
}
__device__ inline half operator/(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  const float quotient = __half2float(a) / __half2float(b);
  return __float2half(quotient);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return float16(lhs / rhs).to_half();
#endif
}
// Unary minus on CUDA `half`: __hneg on __CUDA_ARCH__ >= 530, otherwise
// negate in float and narrow back.
__device__ inline half operator-(const half& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hneg(a);
#else
  const float value = static_cast<float>(float16(a));
  return float16(-value).to_half();
#endif
}
// Compound assignment on CUDA `half`, expressed via the binary operators.
__device__ inline half& operator+=(half& a, const half& b) {  // NOLINT
  const half result = a + b;
  a = result;
  return a;
}
__device__ inline half& operator-=(half& a, const half& b) {  // NOLINT
  const half result = a - b;
  a = result;
  return a;
}
__device__ inline half& operator*=(half& a, const half& b) {  // NOLINT
  const half result = a * b;
  a = result;
  return a;
}
__device__ inline half& operator/=(half& a, const half& b) {  // NOLINT
  const half result = a / b;
  a = result;
  return a;
}
// Comparisons on CUDA `half`: half intrinsics on __CUDA_ARCH__ >= 530,
// otherwise both operands are widened to float and compared there.
__device__ inline bool operator==(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __heq(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs == rhs;
#endif
}
__device__ inline bool operator!=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hne(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs != rhs;
#endif
}
__device__ inline bool operator<(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hlt(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs < rhs;
#endif
}
__device__ inline bool operator<=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hle(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs <= rhs;
#endif
}
__device__ inline bool operator>(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hgt(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs > rhs;
#endif
}
__device__ inline bool operator>=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hge(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs >= rhs;
#endif
}
#endif // CINN_CUDA_FP16
// Arithmetic operators for float16 on GPU
// Binary arithmetic on float16. Device code on __CUDA_ARCH__ >= 530 with
// CINN_CUDA_FP16 uses the half intrinsics (division goes through float);
// every other build computes in float and narrows the result.
__host__ __device__ inline float16 operator+(const float16& a,
                                             const float16& b) {
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  const half lhs = a.to_half();
  const half rhs = b.to_half();
  return float16(__hadd(lhs, rhs));
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return float16(lhs + rhs);
#endif
}
__host__ __device__ inline float16 operator-(const float16& a,
                                             const float16& b) {
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  const half lhs = a.to_half();
  const half rhs = b.to_half();
  return float16(__hsub(lhs, rhs));
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return float16(lhs - rhs);
#endif
}
__host__ __device__ inline float16 operator*(const float16& a,
                                             const float16& b) {
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  const half lhs = a.to_half();
  const half rhs = b.to_half();
  return float16(__hmul(lhs, rhs));
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return float16(lhs * rhs);
#endif
}
__host__ __device__ inline float16 operator/(const float16& a,
                                             const float16& b) {
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  // TODO(kexinzhao): check which cuda version starts to support __hdiv
  const float quotient =
      __half2float(a.to_half()) / __half2float(b.to_half());
  return float16(quotient);
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return float16(lhs / rhs);
#endif
}
// Unary minus on float16: __hneg in capable device code, otherwise the top
// bit (0x8000) of the raw pattern is flipped.
__host__ __device__ inline float16 operator-(const float16& a) {
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  const half negated = __hneg(a.to_half());
  return float16(negated);
#else
  float16 flipped;
  flipped.x = a.x ^ 0x8000;
  return flipped;
#endif
}
// Compound assignment on float16, expressed via the binary operators.
__host__ __device__ inline float16& operator+=(float16& a,          // NOLINT
                                               const float16& b) {  // NOLINT
  const float16 result = a + b;
  a = result;
  return a;
}
__host__ __device__ inline float16& operator-=(float16& a,          // NOLINT
                                               const float16& b) {  // NOLINT
  const float16 result = a - b;
  a = result;
  return a;
}
__host__ __device__ inline float16& operator*=(float16& a,          // NOLINT
                                               const float16& b) {  // NOLINT
  const float16 result = a * b;
  a = result;
  return a;
}
__host__ __device__ inline float16& operator/=(float16& a,          // NOLINT
                                               const float16& b) {  // NOLINT
  const float16 result = a / b;
  a = result;
  return a;
}
// Comparisons on float16: half intrinsics in capable device code, otherwise
// both operands are widened to float and compared there.
__host__ __device__ inline bool operator==(const float16& a, const float16& b) {
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __heq(a.to_half(), b.to_half());
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs == rhs;
#endif
}
__host__ __device__ inline bool operator!=(const float16& a, const float16& b) {
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hne(a.to_half(), b.to_half());
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs != rhs;
#endif
}
__host__ __device__ inline bool operator<(const float16& a, const float16& b) {
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hlt(a.to_half(), b.to_half());
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs < rhs;
#endif
}
__host__ __device__ inline bool operator<=(const float16& a, const float16& b) {
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hle(a.to_half(), b.to_half());
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs <= rhs;
#endif
}
__host__ __device__ inline bool operator>(const float16& a, const float16& b) {
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hgt(a.to_half(), b.to_half());
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs > rhs;
#endif
}
__host__ __device__ inline bool operator>=(const float16& a, const float16& b) {
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hge(a.to_half(), b.to_half());
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs >= rhs;
#endif
}
#endif // __cplusplus
// Wrap a raw 16-bit pattern in a float16 without any numeric conversion.
__host__ __device__ inline float16 raw_uint16_to_float16(uint16_t a) {
  float16 wrapped;
  wrapped.x = a;
  return wrapped;
}
// NaN test. Capable device code uses __hisnan; elsewhere the value is NaN when
// its bits without the sign exceed 0x7c00. The name is parenthesized so a
// function-like `isnan` macro cannot expand here.
__host__ __device__ inline bool(isnan)(const float16& a) {
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hisnan(a.to_half());
#else
  const auto magnitude = a.x & 0x7fff;
  return magnitude > 0x7c00;
#endif
}
// Infinity test: the bits without the sign must equal 0x7c00 exactly.
__host__ __device__ inline bool(isinf)(const float16& a) {
  const auto magnitude = a.x & 0x7fff;
  return magnitude == 0x7c00;
}
// Finite means neither NaN nor infinite.
__host__ __device__ inline bool(isfinite)(const float16& a) {
  return !((isnan)(a) || (isinf)(a));
}
// Absolute value: __habs in capable device code, fabsf on the widened float
// everywhere else.
__host__ __device__ inline float16(abs)(const float16& a) {
#if defined(CINN_CUDA_FP16) && (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
  const half magnitude = __habs(a.to_half());
  return static_cast<float16>(magnitude);
#else
  const float magnitude = fabsf(static_cast<float>(a));
  return static_cast<float16>(magnitude);
#endif
}
// Natural logarithm, evaluated in float and narrowed back to float16.
__host__ __device__ inline float16(log)(const float16& a) {
  const float widened = static_cast<float>(a);
  return float16(std::log(widened));
}
#ifdef __cplusplus
} // namespace common
} // namespace cinn
#endif // __cplusplus
#if defined(__cplusplus) && defined(CINN_CUDA_FP16)
// Warp-shuffle overloads for float16: each converts to `half`, calls the
// matching CUDA shuffle on it, and wraps the result back into a float16.
__device__ inline cinn::common::float16 __shfl_sync(unsigned mask,
                                                    cinn::common::float16 var,
                                                    int srcLane,
                                                    int width = warpSize) {
  const half shuffled = __shfl_sync(mask, var.to_half(), srcLane, width);
  return cinn::common::float16(shuffled);
}
__device__ inline cinn::common::float16 __shfl_up_sync(
    unsigned mask,
    cinn::common::float16 var,
    unsigned int delta,
    int width = warpSize) {
  const half shuffled = __shfl_up_sync(mask, var.to_half(), delta, width);
  return cinn::common::float16(shuffled);
}
__device__ inline cinn::common::float16 __shfl_down_sync(
    unsigned mask,
    cinn::common::float16 var,
    unsigned int delta,
    int width = warpSize) {
  const half shuffled = __shfl_down_sync(mask, var.to_half(), delta, width);
  return cinn::common::float16(shuffled);
}
__device__ inline cinn::common::float16 __shfl_xor_sync(
    unsigned mask,
    cinn::common::float16 var,
    int laneMask,
    int width = warpSize) {
  const half shuffled = __shfl_xor_sync(mask, var.to_half(), laneMask, width);
  return cinn::common::float16(shuffled);
}
// Returns `a` when a > b, otherwise `b` (so `b` is returned on ties and
// whenever the comparison is false).
__host__ __device__ inline cinn::common::float16 max(
    const cinn::common::float16& a, const cinn::common::float16& b) {
  if (a > b) {
    return a;
  }
  return b;
}
// Returns `a` when a < b, otherwise `b` (so `b` is returned on ties and
// whenever the comparison is false).
__host__ __device__ inline cinn::common::float16 min(
    const cinn::common::float16& a, const cinn::common::float16& b) {
  if (a < b) {
    return a;
  }
  return b;
}
#endif // __cplusplus && CINN_CUDA_FP16
#endif // CINN_COMMON_FLOAT16_H
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <random>
#include <vector>
#include "paddle/cinn/common/bfloat16.h"
#include "paddle/cinn/common/float16.h"
namespace cinn {
namespace common {
// Evaluates a CUDA runtime call exactly once and reports through LOG(FATAL)
// when the returned status differs from cudaSuccess.
// The body is wrapped in do { ... } while (0) so that `CUDA_CALL(x);` is a
// single statement and stays correct inside an unbraced if/else; the argument
// is parenthesized so any expression can be passed.
#define CUDA_CALL(func)                                            \
  do {                                                             \
    const auto status = (func);                                    \
    if (status != cudaSuccess) {                                   \
      LOG(FATAL) << "CUDA Error : " << cudaGetErrorString(status); \
    }                                                              \
  } while (0)
// Owning wrapper around a single cudaMalloc'd device buffer. The buffer is
// allocated lazily by mutable_data() and released with cudaFree in the
// destructor.
class CudaMem {
 public:
  CudaMem() = default;
  // The destructor frees `ptr`, so a copied object would free the same device
  // pointer a second time. Copying is therefore disabled.
  CudaMem(const CudaMem&) = delete;
  CudaMem& operator=(const CudaMem&) = delete;

  // Returns the device buffer, allocating `bytes` bytes on first use. Later
  // calls must request the same size; a different size fails the CHECK.
  void* mutable_data(size_t bytes) {
    CHECK_GT(bytes, 0) << "Cannot allocate empty memory!";
    if (ptr) {
      CHECK_EQ(bytes, bytes_) << "Try allocate memory twice!";
      return ptr;
    }
    CUDA_CALL(cudaMalloc(&ptr, bytes));
    bytes_ = bytes;
    return ptr;
  }
  // Typed variant: `num` is an element count, not a byte count.
  template <typename T>
  T* mutable_data(size_t num) {
    return reinterpret_cast<T*>(mutable_data(num * sizeof(T)));
  }
  // Returns the already-allocated buffer; CHECK-fails when nothing has been
  // allocated yet.
  void* data() const {
    CHECK(ptr) << "Try get nullptr!";
    return ptr;
  }
  template <typename T>
  T* data() const {
    return reinterpret_cast<T*>(data());
  }
  // Enqueues a host-to-device copy of `bytes` bytes on `stream`. The call is
  // asynchronous (cudaMemcpyAsync); `bytes` may not exceed the allocation.
  void MemcpyFromHost(const void* src,
                      size_t bytes,
                      cudaStream_t stream = nullptr) {
    CHECK_LE(bytes, bytes_) << "Too many data need copy";
    CUDA_CALL(cudaMemcpyAsync(ptr, src, bytes, cudaMemcpyHostToDevice, stream));
  }
  // Enqueues a device-to-host copy of `bytes` bytes on `stream`; the caller
  // must synchronize the stream before reading `dst`.
  void MemcpyToHost(void* dst, size_t bytes, cudaStream_t stream = nullptr) {
    CHECK_LE(bytes, bytes_) << "Too many data need copy";
    CUDA_CALL(cudaMemcpyAsync(dst, ptr, bytes, cudaMemcpyDeviceToHost, stream));
  }
  // Releases the buffer. The cudaFree status is deliberately not checked so
  // the destructor never aborts.
  ~CudaMem() {
    if (ptr) {
      cudaFree(ptr);
      ptr = nullptr;
    }
    bytes_ = 0;
  }

 private:
  void* ptr{nullptr};
  size_t bytes_{0};
};
// One thread per element: converts input[i] to float16 for every i < num.
__global__ void cast_fp32_to_fp16_cuda_kernel(const float* input,
                                              const int num,
                                              float16* out) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= num) return;  // threads past the end do nothing
  out[tid] = float16(input[tid]);
}
// One thread per element: widens input[i] to float for every i < num.
__global__ void cast_fp16_to_fp32_cuda_kernel(const float16* input,
                                              const int num,
                                              float* out) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= num) return;  // threads past the end do nothing
  out[tid] = static_cast<float>(input[tid]);
}
// Per element, in float16 arithmetic: out = ((x + 1) + y) * ((x + 1) - y).
__global__ void test_fp16_cuda_kernel(const float16* x,
                                      const float16* y,
                                      const int num,
                                      float16* out) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= num) return;  // threads past the end do nothing
  float16 lhs = x[tid];
  const float16 rhs = y[tid];
  lhs += float16(1);
  out[tid] = (lhs + rhs) * (lhs - rhs);
}
// One thread per element: converts input[i] to bfloat16 for every i < num.
__global__ void cast_fp32_to_bf16_cuda_kernel(const float* input,
                                              const int num,
                                              bfloat16* out) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= num) return;  // threads past the end do nothing
  out[tid] = bfloat16(input[tid]);
}
// One thread per element: widens input[i] to float for every i < num.
__global__ void cast_bf16_to_fp32_cuda_kernel(const bfloat16* input,
                                              const int num,
                                              float* out) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= num) return;  // threads past the end do nothing
  out[tid] = static_cast<float>(input[tid]);
}
// Per element, in bfloat16 arithmetic: out = ((x + 1) + y) * ((x + 1) - y).
__global__ void test_bf16_cuda_kernel(const bfloat16* x,
                                      const bfloat16* y,
                                      const int num,
                                      bfloat16* out) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= num) return;  // threads past the end do nothing
  bfloat16 lhs = x[tid];
  const bfloat16 rhs = y[tid];
  lhs += bfloat16(1);
  out[tid] = (lhs + rhs) * (lhs - rhs);
}
// Per element, in float arithmetic: out = ((x + 1) + y) * ((x + 1) - y).
__global__ void test_fp32_cuda_kernel(const float* x,
                                      const float* y,
                                      const int num,
                                      float* out) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= num) return;  // threads past the end do nothing
  float lhs = x[tid];
  const float rhs = y[tid];
  lhs += 1.0f;
  out[tid] = (lhs + rhs) * (lhs - rhs);
}
// Computes the same expression on the GPU in fp32, fp16 and bf16 from one set
// of random inputs and checks that the reduced-precision results stay within
// 1e-2 (fp16) and 1e-1 (bf16) of the fp32 result.
// All copies and kernel launches are enqueued on one stream; the host only
// reads results after the single cudaStreamSynchronize near the end.
TEST(FP16_BF16, basic_cuda) {
#ifdef CUDA_VERSION
LOG(INFO) << "CUDA version: " << CUDA_VERSION;
#endif
int num = 2048;
cudaStream_t stream;
CUDA_CALL(cudaStreamCreate(&stream));
// 1024 threads per block; the grid is rounded up so every element is covered.
dim3 block = 1024;
dim3 grid = (num + block.x - 1) / block.x;
std::vector<float> x_fp32_host(num), y_fp32_host(num);
{ // step1 : generate input data
std::random_device r;
std::default_random_engine eng(r());
// Inputs are drawn from [1e-5, 1.0).
std::uniform_real_distribution<float> dis(1e-5f, 1.0f);
for (int i = 0; i < num; ++i) {
x_fp32_host[i] = dis(eng);
y_fp32_host[i] = dis(eng);
}
}
CudaMem x_fp32_device, y_fp32_device, out_fp32_device;
{ // step2 : compute fp32 result
auto x_fp32_ptr = x_fp32_device.mutable_data<float>(num);
auto y_fp32_ptr = y_fp32_device.mutable_data<float>(num);
auto out_fp32_ptr = out_fp32_device.mutable_data<float>(num);
x_fp32_device.MemcpyFromHost(
x_fp32_host.data(), num * sizeof(float), stream);
y_fp32_device.MemcpyFromHost(
y_fp32_host.data(), num * sizeof(float), stream);
test_fp32_cuda_kernel<<<grid, block, 0, stream>>>(
x_fp32_ptr, y_fp32_ptr, num, out_fp32_ptr);
}
CudaMem x_fp16_device, y_fp16_device, out_fp16_device;
CudaMem x_bf16_device, y_bf16_device, out_bf16_device;
{ // step3 : compute fp16/bf16 result
// step3.1 : compute fp16 result
// The fp16 inputs are produced on the device from the fp32 device buffers.
auto x_fp16_ptr = x_fp16_device.mutable_data<float16>(num);
auto y_fp16_ptr = y_fp16_device.mutable_data<float16>(num);
auto out_fp16_ptr = out_fp16_device.mutable_data<float16>(num);
cast_fp32_to_fp16_cuda_kernel<<<grid, block, 0, stream>>>(
x_fp32_device.data<float>(), num, x_fp16_ptr);
cast_fp32_to_fp16_cuda_kernel<<<grid, block, 0, stream>>>(
y_fp32_device.data<float>(), num, y_fp16_ptr);
test_fp16_cuda_kernel<<<grid, block, 0, stream>>>(
x_fp16_ptr, y_fp16_ptr, num, out_fp16_ptr);
// step3.2 : compute bf16 result
auto x_bf16_ptr = x_bf16_device.mutable_data<bfloat16>(num);
auto y_bf16_ptr = y_bf16_device.mutable_data<bfloat16>(num);
auto out_bf16_ptr = out_bf16_device.mutable_data<bfloat16>(num);
cast_fp32_to_bf16_cuda_kernel<<<grid, block, 0, stream>>>(
x_fp32_device.data<float>(), num, x_bf16_ptr);
cast_fp32_to_bf16_cuda_kernel<<<grid, block, 0, stream>>>(
y_fp32_device.data<float>(), num, y_bf16_ptr);
test_bf16_cuda_kernel<<<grid, block, 0, stream>>>(
x_bf16_ptr, y_bf16_ptr, num, out_bf16_ptr);
}
CudaMem fp32res_fp16_device;
CudaMem fp32res_bf16_device;
{ // step4 : cast fp16/bf16 result to fp32 result
// step4.1 : cast fp16 result to fp32 result
auto fp32res_fp16_ptr = fp32res_fp16_device.mutable_data<float>(num);
cast_fp16_to_fp32_cuda_kernel<<<grid, block, 0, stream>>>(
out_fp16_device.data<float16>(), num, fp32res_fp16_ptr);
// step4.2 : cast bf16 result to fp32 result
auto fp32res_bf16_ptr = fp32res_bf16_device.mutable_data<float>(num);
cast_bf16_to_fp32_cuda_kernel<<<grid, block, 0, stream>>>(
out_bf16_device.data<bfloat16>(), num, fp32res_bf16_ptr);
}
std::vector<float> out_fp32_host(num), out_fp16_host(num), out_bf16_host(num);
{ // step5 : copy result from device to host
out_fp32_device.MemcpyToHost(
out_fp32_host.data(), num * sizeof(float), stream);
fp32res_fp16_device.MemcpyToHost(
out_fp16_host.data(), num * sizeof(float), stream);
fp32res_bf16_device.MemcpyToHost(
out_bf16_host.data(), num * sizeof(float), stream);
}
// Wait for every queued copy and kernel before touching the host buffers.
CUDA_CALL(cudaStreamSynchronize(stream));
for (int i = 0; i < num; ++i) {
ASSERT_NEAR(out_fp32_host[i], out_fp16_host[i], 1e-2f);
ASSERT_NEAR(out_fp32_host[i], out_bf16_host[i], 1e-1f);
}
CUDA_CALL(cudaStreamDestroy(stream));
}
} // namespace common
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <random>
#include <vector>
#include "paddle/cinn/common/bfloat16.h"
#include "paddle/cinn/common/float16.h"
namespace cinn {
namespace common {
std::vector<float16> test_fp16_host_kernel(const float16* x,
const float16* y,
const int num) {
std::vector<float16> out(num);
for (int idx = 0; idx < num; ++idx) {
float16 x_i = x[idx], y_i = y[idx];
x_i += float16(1);
out[idx] = (x_i + y_i) * (x_i - y_i);
}
return out;
}
std::vector<bfloat16> test_bf16_host_kernel(const bfloat16* x,
const bfloat16* y,
const int num) {
std::vector<bfloat16> out(num);
for (int idx = 0; idx < num; ++idx) {
bfloat16 x_i = x[idx], y_i = y[idx];
x_i += bfloat16(1);
out[idx] = (x_i + y_i) * (x_i - y_i);
}
return out;
}
// Host reference in float arithmetic: for each of the `num` elements,
// out = ((x + 1) + y) * ((x + 1) - y).
std::vector<float> test_fp32_host_kernel(const float* x,
                                         const float* y,
                                         const int num) {
  std::vector<float> result(num);
  for (int i = 0; i < num; ++i) {
    float lhs = x[i];
    const float rhs = y[i];
    lhs += 1.0f;
    result[i] = (lhs + rhs) * (lhs - rhs);
  }
  return result;
}
// Evaluates the same expression on the host in fp16, bf16 and fp32 from one
// set of random inputs and checks that the reduced-precision results stay
// within 1e-2 (fp16) and 1e-1 (bf16) of the fp32 result.
TEST(FP16_BF16, basic_host) {
  const int num = 2048;
  const size_t expected_size = static_cast<size_t>(num);
  std::vector<float16> x_fp16(num), y_fp16(num);
  std::vector<bfloat16> x_bf16(num), y_bf16(num);
  std::vector<float> x_fp32(num), y_fp32(num);
  std::random_device r;
  std::default_random_engine eng(r());
  // Inputs are drawn from [1e-5, 1.0).
  std::uniform_real_distribution<float> dis(1e-5f, 1.0f);
  for (int i = 0; i < num; ++i) {
    // Draw x before y, then derive both reduced-precision copies from the
    // fp32 value so all three variants see the same input.
    x_fp32[i] = dis(eng);
    y_fp32[i] = dis(eng);
    x_fp16[i] = x_fp32[i];
    y_fp16[i] = y_fp32[i];
    x_bf16[i] = x_fp32[i];
    y_bf16[i] = y_fp32[i];
  }
  auto out_fp16 = test_fp16_host_kernel(x_fp16.data(), y_fp16.data(), num);
  ASSERT_EQ(out_fp16.size(), expected_size);
  auto out_bf16 = test_bf16_host_kernel(x_bf16.data(), y_bf16.data(), num);
  ASSERT_EQ(out_bf16.size(), expected_size);
  auto out_fp32 = test_fp32_host_kernel(x_fp32.data(), y_fp32.data(), num);
  ASSERT_EQ(out_fp32.size(), expected_size);
  for (int i = 0; i < num; ++i) {
    ASSERT_NEAR(static_cast<float>(out_fp16[i]), out_fp32[i], 1e-2f);
    ASSERT_NEAR(static_cast<float>(out_bf16[i]), out_fp32[i], 1e-1f);
  }
}
} // namespace common
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iostream>
#include <limits>
#include "paddle/cinn/common/bfloat16.h"
#include "paddle/cinn/common/float16.h"
namespace std {
// Override the std::is_pod::value for float16 and bfloat16
// The reason is that different compilers implemented std::is_pod based on
// different C++ standards. float16 class is a plain old data in C++11 given
// that it is both trivial and standard_layout.
// However, std::is_pod in nvcc 8.0 host c++ compiler follows C++0x and is
// more restricted in that you cannot provide any customized
// constructor in float16. Hence, we override is_pod here following C++11
// so that .cu files can be successfully compiled by nvcc.
// for float16
template <>
struct is_pod<cinn::common::float16> {
static const bool value = is_trivial<cinn::common::float16>::value &&
is_standard_layout<cinn::common::float16>::value;
};
// float16 is reported as a floating-point type. The base is
// integral_constant<bool, true>, spelled here as std::true_type.
template <>
struct is_floating_point<cinn::common::float16> : std::true_type {};
// float16 is a signed type. Derives from std::true_type so the specialization
// offers the same trait interface as the is_floating_point specialization
// (`value`, `type`, conversion to bool), not only `value`.
template <>
struct is_signed<cinn::common::float16> : std::true_type {};
// float16 is not an unsigned type. Derives from std::false_type so the
// specialization offers the full trait interface (`value`, `type`, conversion
// to bool), not only `value`.
template <>
struct is_unsigned<cinn::common::float16> : std::false_type {};
// std::abs overload for float16; forwards to cinn::common::abs.
__host__ __device__ inline cinn::common::float16 abs(
    const cinn::common::float16& a) {
  const cinn::common::float16 magnitude = cinn::common::abs(a);
  return magnitude;
}
inline bool isnan(const cinn::common::float16& a) {
return cinn::common::isnan(a);
}
inline bool isinf(const cinn::common::float16& a) {
return cinn::common::isinf(a);
}
inline bool isfinite(const cinn::common::float16& a) {
return cinn::common::isfinite(a);
}
// std::numeric_limits for cinn::common::float16. The type stores a 16-bit
// pattern with 1 sign bit, 5 exponent bits and 10 stored mantissa bits; all
// special values below are built from raw bit patterns.
template <>
struct numeric_limits<cinn::common::float16> {
  static const bool is_specialized = true;
  static const bool is_signed = true;
  static const bool is_integer = false;
  static const bool is_exact = false;
  static const bool has_infinity = true;
  static const bool has_quiet_NaN = true;
  static const bool has_signaling_NaN = true;
  static const float_denorm_style has_denorm = denorm_present;
  static const bool has_denorm_loss = false;
  static const std::float_round_style round_style = std::round_to_nearest;
  static const bool is_iec559 = false;
  // The set of representable values is finite, so the type is bounded.
  static const bool is_bounded = true;
  static const bool is_modulo = false;
  static const int digits = 11;
  static const int digits10 = 3;
  static const int max_digits10 = 5;
  static const int radix = 2;
  static const int min_exponent = -13;
  static const int min_exponent10 = -4;
  static const int max_exponent = 16;
  static const int max_exponent10 = 4;
  static const bool traps = true;
  static const bool tinyness_before = false;
  // Smallest positive normal value: exponent field 1, mantissa 0 (2^-14).
  __host__ __device__ static cinn::common::float16(min)() {
    return cinn::common::raw_uint16_to_float16(0x400);
  }
  // Most negative finite value (-65504).
  __host__ __device__ static cinn::common::float16 lowest() {
    return cinn::common::raw_uint16_to_float16(0xfbff);
  }
  // Largest finite value (65504).
  __host__ __device__ static cinn::common::float16(max)() {
    return cinn::common::raw_uint16_to_float16(0x7bff);
  }
  // Gap between 1.0 (0x3c00) and the next value (0x3c01): 2^-10, whose bit
  // pattern is exponent field 5, mantissa 0, i.e. 0x1400.
  __host__ __device__ static cinn::common::float16 epsilon() {
    return cinn::common::raw_uint16_to_float16(0x1400);
  }
  __host__ __device__ static cinn::common::float16 round_error() {
    return cinn::common::float16(0.5);
  }
  __host__ __device__ static cinn::common::float16 infinity() {
    return cinn::common::raw_uint16_to_float16(0x7c00);
  }
  // Quiet NaN: all exponent bits set and the top mantissa bit set.
  __host__ __device__ static cinn::common::float16 quiet_NaN() {
    return cinn::common::raw_uint16_to_float16(0x7e00);
  }
  // Signaling NaN: all exponent bits set, top mantissa bit clear, and a
  // non-zero payload so the pattern differs from infinity.
  __host__ __device__ static cinn::common::float16 signaling_NaN() {
    return cinn::common::raw_uint16_to_float16(0x7d00);
  }
  // Smallest positive subnormal value.
  __host__ __device__ static cinn::common::float16 denorm_min() {
    return cinn::common::raw_uint16_to_float16(0x1);
  }
};
// for bfloat16
template <>
struct is_pod<cinn::common::bfloat16> {
static const bool value = is_trivial<cinn::common::bfloat16>::value &&
is_standard_layout<cinn::common::bfloat16>::value;
};
// bfloat16 is reported as a floating-point type. The base is
// integral_constant<bool, true>, spelled here as std::true_type.
template <>
struct is_floating_point<cinn::common::bfloat16> : std::true_type {};
// bfloat16 is a signed type. Derives from std::true_type so the specialization
// offers the same trait interface as the is_floating_point specialization
// (`value`, `type`, conversion to bool), not only `value`.
template <>
struct is_signed<cinn::common::bfloat16> : std::true_type {};
// bfloat16 is not an unsigned type. Derives from std::false_type so the
// specialization offers the full trait interface (`value`, `type`, conversion
// to bool), not only `value`.
template <>
struct is_unsigned<cinn::common::bfloat16> : std::false_type {};
// std::isnan overload for bfloat16; forwards to cinn::common::isnan.
inline bool isnan(const cinn::common::bfloat16& a) {
  const bool is_nan = cinn::common::isnan(a);
  return is_nan;
}
// std::isinf overload for bfloat16; forwards to cinn::common::isinf.
inline bool isinf(const cinn::common::bfloat16& a) {
  const bool is_inf = cinn::common::isinf(a);
  return is_inf;
}
// std::numeric_limits for cinn::common::bfloat16. The constants assume the
// usual bfloat16 layout (1 sign bit, 8 exponent bits, 7 stored mantissa bits),
// which matches the exponent limits and raw patterns used here; all special
// values are built from raw 16-bit patterns.
template <>
struct numeric_limits<cinn::common::bfloat16> {
  static const bool is_specialized = true;
  static const bool is_signed = true;
  static const bool is_integer = false;
  static const bool is_exact = false;
  static const bool has_infinity = true;
  static const bool has_quiet_NaN = true;
  static const bool has_signaling_NaN = true;
  static const float_denorm_style has_denorm = denorm_present;
  static const bool has_denorm_loss = false;
  static const std::float_round_style round_style = std::round_to_nearest;
  static const bool is_iec559 = false;
  // The set of representable values is finite, so the type is bounded.
  static const bool is_bounded = true;
  static const bool is_modulo = false;
  static const int digits = 8;
  static const int digits10 = 2;
  // ceil(1 + digits * log10(2)) = ceil(3.41) = 4 decimal digits round-trip.
  static const int max_digits10 = 4;
  static const int radix = 2;
  static const int min_exponent = -125;
  static const int min_exponent10 = -37;
  static const int max_exponent = 128;
  static const int max_exponent10 = 38;
  static const bool traps = true;
  static const bool tinyness_before = false;
  // Smallest positive normal value: exponent field 1, mantissa 0 (2^-126).
  // 0x007f would be the largest subnormal instead.
  __host__ __device__ static cinn::common::bfloat16(min)() {
    return cinn::common::raw_uint16_to_bfloat16(0x0080);
  }
  // Most negative finite value.
  __host__ __device__ static cinn::common::bfloat16 lowest() {
    return cinn::common::raw_uint16_to_bfloat16(0xff7f);
  }
  // Largest finite value.
  __host__ __device__ static cinn::common::bfloat16(max)() {
    return cinn::common::raw_uint16_to_bfloat16(0x7f7f);
  }
  // Gap between 1.0 (0x3f80) and the next value (0x3f81): 2^-7, whose bit
  // pattern is exponent field 120, mantissa 0, i.e. 0x3c00.
  __host__ __device__ static cinn::common::bfloat16 epsilon() {
    return cinn::common::raw_uint16_to_bfloat16(0x3c00);
  }
  __host__ __device__ static cinn::common::bfloat16 round_error() {
    return cinn::common::bfloat16(0.5);
  }
  __host__ __device__ static cinn::common::bfloat16 infinity() {
    return cinn::common::raw_uint16_to_bfloat16(0x7f80);
  }
  // Quiet NaN: all exponent bits set and the top mantissa bit set.
  __host__ __device__ static cinn::common::bfloat16 quiet_NaN() {
    return cinn::common::raw_uint16_to_bfloat16(0xffc1);
  }
  // Signaling NaN: all exponent bits set, top mantissa bit clear, non-zero
  // payload.
  __host__ __device__ static cinn::common::bfloat16 signaling_NaN() {
    return cinn::common::raw_uint16_to_bfloat16(0xff81);
  }
  // Smallest positive subnormal value.
  __host__ __device__ static cinn::common::bfloat16 denorm_min() {
    return cinn::common::raw_uint16_to_bfloat16(0x0001);
  }
};
} // namespace std
namespace cinn {
namespace common {
// Streams a float16 as its float value. std::showpoint is applied to the
// stream and, being a sticky format flag, remains set afterwards.
inline std::ostream& operator<<(std::ostream& os, const float16& a) {
  const float printable = static_cast<float>(a);
  return os << std::showpoint << printable;
}
// Streams a bfloat16 as its float value. std::showpoint is applied to the
// stream and, being a sticky format flag, remains set afterwards.
inline std::ostream& operator<<(std::ostream& os, const bfloat16& a) {
  const float printable = static_cast<float>(a);
  return os << std::showpoint << printable;
}
} // namespace common
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/graph_utils.h"
#include <glog/logging.h>
#include <deque>
#include <functional>
#include <set>
#include <stack>
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/utils/dot_lang.h"
namespace cinn {
namespace common {
namespace {
// Placeholder helper for a DFS-based sort. The body is empty: `node` is not
// visited and nothing is appended to `order`.
void DFSSortUtil(const GraphNode *node, std::vector<GraphNode *> *order) {}
// Unimplemented: logs a FATAL "not implemented" message. The trailing return
// only satisfies the signature.
std::vector<GraphNode *> DFSSort(const std::vector<GraphNode *> &nodes) {
  LOG(FATAL) << "not implemented";
  return std::vector<GraphNode *>{};
}
} // namespace
// Returns every node that `targets` transitively depend on, i.e. every node
// outside `targets` from which some target is reachable by following
// outlinks. The targets themselves are not included.
//
// Implementation: a naive fixpoint. `reached` starts as the target set; any
// node with an outlink into `reached` is a dependency and joins `reached`, and
// the sweep repeats until a full pass adds nothing.
std::set<GraphNode *> Graph::dependencies(
    const std::vector<GraphNode *> &targets) {
  std::set<GraphNode *> reached(targets.begin(), targets.end());
  std::set<GraphNode *> res;
  const std::vector<GraphNode *> all_nodes = nodes();
  bool grew = true;
  while (grew) {
    grew = false;
    for (auto *node : all_nodes) {
      if (reached.count(node)) continue;
      for (auto &edge : node->outlinks()) {
        if (reached.count(edge->sink())) {
          // `node` feeds something already reached, so record the upstream
          // node itself (not the sink, which is already in the set).
          res.insert(node);
          reached.insert(node);
          grew = true;
          break;
        }
      }
    }
  }
  return res;
}
std::vector<const GraphNode *> Graph::nodes() const {
std::vector<const GraphNode *> res;
for (auto &s : nodes_) res.push_back(s.get());
return res;
}
std::vector<GraphNode *> Graph::nodes() {
std::vector<GraphNode *> res;
for (auto &s : nodes_) res.push_back(s.get());
return res;
}
std::tuple<std::vector<GraphNode *>, std::vector<GraphEdge *>>
Graph::topological_order() const {
std::vector<GraphNode *> node_order;
std::vector<GraphEdge *> edge_order;
std::deque<GraphNode *> queue;
// collect indegreee.
std::map<std::string, int> indegree;
for (auto *n : nodes()) {
indegree[n->id()] = n->inlinks().size();
}
// insert start points first.
for (auto *n : start_points()) {
queue.push_back(&Reference(n));
}
// start to visit
int count = 0;
while (!queue.empty()) {
auto *top_node = queue.front();
top_node->set_index(count);
node_order.push_back(top_node);
count++;
queue.pop_front();
for (auto &edge : top_node->outlinks()) {
CHECK_EQ(edge->source(), top_node);
edge_order.push_back(edge.get());
auto *sink = edge->sink();
if ((--indegree[sink->id()]) == 0) {
queue.push_back(sink);
}
}
}
CHECK_EQ(node_order.size(), nodes().size())
<< "circle detected in the schedule graph:\n\n"
<< Visualize();
return std::make_tuple(node_order, edge_order);
}
//! DFS order of the graph. Currently a stub that always yields an empty list.
std::vector<GraphNode *> Graph::dfs_order() {
  std::vector<GraphNode *> empty_order;
  return empty_order;
}
std::vector<const GraphNode *> Graph::start_points() const {
std::vector<const GraphNode *> res;
for (auto *node : nodes()) {
if (node->inlinks().empty()) res.push_back(node);
}
return res;
}
std::vector<GraphNode *> Graph::start_points() {
std::vector<GraphNode *> res;
for (auto *node : nodes()) {
if (node->inlinks().empty()) res.push_back(node);
}
return res;
}
/**
 * Add \p node to the graph under the hash key \p key.
 *
 * The node is always appended to nodes_; the registry entry is only created
 * when \p key is not registered yet (an existing entry is left untouched).
 *
 * @return \p node, unchanged.
 */
GraphNode *Graph::RegisterNode(size_t key, GraphNode *node) {
  registry_.insert({key, node});
  nodes_.emplace_back(node);
  return node;
}
//! Add \p node under the std::hash of the string \p key; forwards to the
//! size_t overload.
GraphNode *Graph::RegisterNode(const std::string &key, GraphNode *node) {
  const size_t hashed_key = std::hash<std::string>{}(key);
  return RegisterNode(hashed_key, node);
}
//! Look up the node registered under \p key; nullptr when there is none.
GraphNode *Graph::RetrieveNode(size_t key) const {
  const auto found = registry_.find(key);
  if (found == registry_.end()) {
    return nullptr;
  }
  return found->second;
}
//! Look up the node registered under the std::hash of the string \p key;
//! forwards to the size_t overload.
GraphNode *Graph::RetrieveNode(const std::string &key) const {
  const size_t hashed_key = std::hash<std::string>()(key);
  return RetrieveNode(hashed_key);
}
//! Render the graph through utils::DotLang: one dot node per graph node and
//! one dot edge per outlink. Returns the string produced by the DotLang object.
std::string Graph::Visualize() const {
  utils::DotLang dot;

  // First pass: declare every node.
  for (auto &node : nodes_) {
    dot.AddNode(node->id(), {}, "", "", true);
  }

  // Second pass: connect each node to the sinks of its outlinks.
  for (auto &from : nodes_) {
    const std::string from_id = from->id();
    for (auto &edge : from->outlinks()) {
      dot.AddEdge(from_id, edge->sink()->id(), {});
    }
  }

  return dot();
}
/**
 * Remove every node that has neither inlinks nor outlinks.
 *
 * For each removed node the entries keyed by its id() are also erased from
 * the three dictionaries.
 *
 * NOTE(review): registry_ is not touched here, so a key registered for a
 * removed node keeps mapping to the old pointer.
 *
 * @param shape_dict  Map keyed by node id; must not be null.
 * @param type_dict   Map keyed by node id; must not be null.
 * @param layout_dict Map keyed by node id; must not be null.
 */
void Graph::ClearUnlinkedNodes(
    absl::flat_hash_map<std::string, std::vector<int>> *shape_dict,
    absl::flat_hash_map<std::string, Type> *type_dict,
    absl::flat_hash_map<std::string, std::string> *layout_dict) {
  CHECK(shape_dict);
  CHECK(type_dict);
  CHECK(layout_dict);
  for (auto it = nodes_.begin(); it != nodes_.end();) {
    const bool unlinked =
        (*it)->inlinks().empty() && (*it)->outlinks().empty();
    if (!unlinked) {
      ++it;
      continue;
    }
    // Copy the id before erasing: the erase may release the node.
    const std::string node_id = (*it)->id();
    VLOG(2) << "delete unlinked node: " << node_id;
    // vector::erase invalidates `it`; continue from the iterator it returns
    // instead of stepping the stale one backwards.
    it = nodes_.erase(it);
    // erase-by-key is a no-op when the key is absent.
    shape_dict->erase(node_id);
    type_dict->erase(node_id);
    layout_dict->erase(node_id);
  }
}
//! Out-of-class definition of GraphNode's static type tag, the string that
//! GraphNode::type_info() returns.
const char *GraphNode::__type_info__ = "GraphNode";
/**
 * Ordering of two edges, compared key by key:
 *   1. source id, ascending;
 *   2. sink id, descending;
 *   3. edge index, ascending.
 */
bool GraphEdgeCompare::operator()(const Shared<GraphEdge> &a,
                                  const Shared<GraphEdge> &b) const {
  const std::string lhs_source = a->source()->id();
  const std::string rhs_source = b->source()->id();
  if (lhs_source != rhs_source) {
    return lhs_source < rhs_source;
  }

  const std::string lhs_sink = a->sink()->id();
  const std::string rhs_sink = b->sink()->id();
  if (lhs_sink != rhs_sink) {
    return lhs_sink > rhs_sink;
  }

  return a->index() < b->index();
}
//! Gather every node of the graph for which \p teller returns true.
std::set<GraphNode *> Graph::CollectNodes(
    std::function<bool(const common::GraphNode *)> &&teller) {
  std::set<GraphNode *> matched;
  for (GraphNode *candidate : nodes()) {
    if (!teller(candidate)) continue;
    matched.insert(candidate);
  }
  return matched;
}
} // namespace common
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
//! \file This file contains the utilities of graph.
#include <absl/container/flat_hash_map.h>
#include <glog/logging.h>
#include <algorithm>
#include <functional>
#include <list>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <vector>
#include "paddle/cinn/common/object.h"
#include "paddle/cinn/common/shared.h"
#include "paddle/cinn/common/type.h"
namespace cinn {
namespace common {
#ifdef As
#undef As
#endif
class GraphNode;
/**
 * Edge in the graph, which can hold some attributes.
 *
 * An edge is directed from source() to sink() and carries an integer index
 * (-1 when none was given at construction).
 */
class GraphEdge : public Object {
 public:
  //! Create an edge from \p source to \p sink with the given \p index.
  GraphEdge(GraphNode* source, GraphNode* sink, int index = -1)
      // Listed in declaration order, which is the order members are
      // initialized in regardless of how the list is written.
      : index_(index), source_(source), sink_(sink) {}

  //! The node this edge starts from.
  GraphNode* source() const { return source_; }
  //! The node this edge ends at.
  GraphNode* sink() const { return sink_; }
  const char* type_info() const override { return __type_info__; }
  //! The index given at construction.
  int index() const { return index_; }

 private:
  //! the index in sink node's inlinks_ or source node's outlinks_
  //! this is used to keep the input/output tensor's order of operator node
  int index_{-1};
  //! Source of this edge.
  GraphNode* source_{};
  //! End of this edge.
  GraphNode* sink_{};
  //! Type tag returned by type_info(). Must point to const char: a string
  //! literal cannot initialize a plain `char*` in C++11 and later.
  static constexpr const char* __type_info__ = "graph_edge";
};
struct GraphEdgeCompare {
bool operator()(const common::Shared<GraphEdge>& a,
const common::Shared<GraphEdge>& b) const;
};
/**
 * @brief The base class of all node of graph.
 * This is used to normalize and share the graph operations.
 *
 * A node keeps two ordered sets of shared edges: inlinks_ (edges ending at
 * this node) and outlinks_ (edges starting at this node). Every link is
 * represented by two distinct GraphEdge objects, one in the source's
 * outlinks_ and one in the sink's inlinks_.
 */
class GraphNode : public Object {
 public:
  //! The unique identifier of the node.
  virtual std::string id() const = 0;

  //! Index last stored with set_index() (0 by default).
  inline int get_index() const { return index; }
  inline void set_index(int index) { this->index = index; }

  /**
   * Links from this to other.
   *
   * Creates one edge for this node's outlinks_ and one for \p other's
   * inlinks_, each tagged with the respective running link counter.
   *
   * @param other The sink node; must be non-null and different from this.
   * @return The (outlink edge, inlink edge) pair that was just inserted.
   */
  template <typename EdgeT = GraphEdge>
  std::tuple<EdgeT*, EdgeT*> LinkTo(GraphNode* other) {
    // Start from nullptr so the CHECKs below are meaningful when a lookup
    // fails instead of reading indeterminate pointers.
    EdgeT* a = nullptr;
    EdgeT* b = nullptr;
    CHECK(other);
    CHECK_NE(other, this) << "Cannot link to itself";

    auto outlink_edge = make_shared<GraphEdge>(this, other, index_outlinks);
    auto inlink_edge =
        make_shared<GraphEdge>(this, other, other->index_inlinks);
    index_outlinks++;
    other->index_inlinks++;
    outlinks_.insert(outlink_edge);
    other->inlinks_.insert(inlink_edge);

    // Fetch the edges back from the sets by the index they were created with.
    for (auto& item : outlinks_) {
      if (item->index() == index_outlinks - 1) {
        a = static_cast<EdgeT*>(item.get());
        break;
      }
    }
    for (auto& item : other->inlinks_) {
      if (item->index() == other->index_inlinks - 1) {
        b = static_cast<EdgeT*>(item.get());
        break;
      }
    }
    CHECK(a);
    CHECK(b);
    return std::make_tuple(a, b);
  }

  //! Link this node to \p other unless a link between the two ids already
  //! exists. Both sides must agree on whether the link exists.
  void Controls(GraphNode* other) {
    bool outlink_linked = false;
    bool inlink_linked = false;
    for (auto& item : outlinks_) {
      if (item->sink()->id() == other->id()) {
        outlink_linked = true;
        break;
      }
    }
    for (auto& item : other->inlinks_) {
      if (item->source()->id() == this->id()) {
        inlink_linked = true;
        break;
      }
    }
    CHECK_EQ(outlink_linked, inlink_linked);
    if (outlink_linked)
      return;
    else
      this->LinkTo(other);
  }

  //! Remove every edge from this node to \p other, on both sides.
  //! No-op when \p other is this node.
  void UnLinkAllTo(GraphNode* other) {
    if (other == this) return;
    const auto links_this_to_other = [&](const Shared<GraphEdge>& x) {
      return x->source() == this && x->sink() == other;
    };
    // remove all this node's outlink
    for (auto it = outlinks_.begin(); it != outlinks_.end();) {
      if (links_this_to_other(*it)) {
        it = outlinks_.erase(it);
      } else {
        ++it;
      }
    }
    // remove all other node's inlink
    for (auto it = other->inlinks_.begin(); it != other->inlinks_.end();) {
      if (links_this_to_other(*it)) {
        it = other->inlinks_.erase(it);
      } else {
        ++it;
      }
    }
  }

  //! Remove at most one edge from this node to \p other on each side (the
  //! first match in set order). No-op when \p other is this node.
  void UnLinkSingleTo(GraphNode* other) {
    if (other == this) return;
    const auto links_this_to_other = [&](const Shared<GraphEdge>& x) {
      return x->source() == this && x->sink() == other;
    };
    // remove single outlink
    {
      auto it =
          std::find_if(outlinks_.begin(), outlinks_.end(), links_this_to_other);
      if (it != outlinks_.end()) outlinks_.erase(it);
    }
    // remove single inlink
    {
      auto it = std::find_if(
          other->inlinks_.begin(), other->inlinks_.end(), links_this_to_other);
      if (it != other->inlinks_.end()) other->inlinks_.erase(it);
    }
  }

  //! Whether some outlink of this node ends at a node with \p other's id.
  bool IsLinkedTo(GraphNode* other) const {
    for (auto& e : outlinks_) {
      if (e->sink()->id() == other->id()) return true;
    }
    return false;
  }

  //! Get the input links of the node.
  virtual const std::set<Shared<GraphEdge>, GraphEdgeCompare>& inlinks() const {
    return inlinks_;
  }
  //! Get the output links of the node.
  virtual const std::set<Shared<GraphEdge>, GraphEdgeCompare>& outlinks()
      const {
    return outlinks_;
  }

  //! Reset graph traversal meta info.
  void ResetVisitMeta() { visited_time_ = 0; }
  //! Count one more visit of this node.
  void VisitOnce() const { visited_time_++; }
  //! True when the node has no inlinks or was visited once per inlink.
  bool visited() const {
    return inlinks_.empty() ||
           static_cast<size_t>(visited_time_) == inlinks_.size();
  }

  const char* type_info() const override { return __type_info__; }

  GraphNode() = default;

  static const char* __type_info__;

 protected:
  //! The input links of the node.
  //! \note The edges are held through Shared handles; the edges themselves
  //! refer to their end nodes by raw pointer.
  std::set<common::Shared<GraphEdge>, GraphEdgeCompare> inlinks_;
  //! The output links of the node.
  //! \note The edges are held through Shared handles; the edges themselves
  //! refer to their end nodes by raw pointer.
  std::set<common::Shared<GraphEdge>, GraphEdgeCompare> outlinks_;

  //! Number of VisitOnce() calls since the last ResetVisitMeta().
  mutable int visited_time_{};
  //! used to mark the index of node's input/output tensors
  int index_inlinks{0};
  int index_outlinks{0};
  //! Value exposed through get_index() / set_index().
  int index{0};
};
/**
 * @brief The base class of all the graph.
 *
 * Holds the nodes in nodes_ and a key-to-node lookup table in registry_.
 */
class Graph {
 public:
  using node_order_t = std::vector<GraphNode*>;
  using edge_order_t = std::vector<GraphEdge*>;

  //! Add a node to the graph, keyed either by a hash value or by a string.
  //! @{
  GraphNode* RegisterNode(size_t key, GraphNode* node);
  GraphNode* RegisterNode(const std::string& key, GraphNode* node);
  //! @}

  //! Retrieve a node by the key it was registered with.
  //! @{
  GraphNode* RetrieveNode(size_t key) const;
  GraphNode* RetrieveNode(const std::string& key) const;
  //! @}

  //! Get the start points of the graph (the nodes that have no inlinks).
  std::vector<const GraphNode*> start_points() const;
  std::vector<GraphNode*> start_points();

  //! Return the graph's nodes and edges(visited) in topological order.
  std::tuple<std::vector<GraphNode*>, std::vector<GraphEdge*>>
  topological_order() const;

  //! Return the graph's DFS order.
  std::vector<GraphNode*> dfs_order();

  //! Return the dependency nodes of a set of nodes.
  std::set<GraphNode*> dependencies(const std::vector<GraphNode*>& nodes);

  //! Raw-pointer views of all nodes held by the graph.
  std::vector<const GraphNode*> nodes() const;
  std::vector<GraphNode*> nodes();

  //! Collect the nodes match the condition defined by \p teller in the graph.
  std::set<GraphNode*> CollectNodes(
      std::function<bool(const common::GraphNode*)>&& teller);

  //! Remove the first entry of nodes_ that holds \p n, if there is one.
  //! registry_ is left as it is.
  void DropNode(GraphNode* n) {
    const auto holds_n = [n](const Shared<GraphNode>& holder) {
      return holder.get() == n;
    };
    const auto pos = std::find_if(nodes_.begin(), nodes_.end(), holds_n);
    if (pos == nodes_.end()) return;
    nodes_.erase(pos);
  }

  //! Get a string representation to visualize a graph.
  std::string Visualize() const;

  //! Remove the nodes without any link; all three dictionaries are required.
  void ClearUnlinkedNodes(
      absl::flat_hash_map<std::string, std::vector<int>>* shape_dict,
      absl::flat_hash_map<std::string, common::Type>* type_dict,
      absl::flat_hash_map<std::string, std::string>* layout_dict);

  //! Number of nodes currently held in nodes_.
  size_t num_nodes() const { return nodes_.size(); }

 protected:
  //! A lookup table that map from hash key to graph node, note that it doesn't
  //! own the graph node.
  std::map<size_t, GraphNode*> registry_;
  //! A list owns the graph nodes.
  std::vector<Shared<GraphNode>> nodes_;
};
} // namespace common
} // namespace cinn
namespace std {
//! Hash support for cinn::common::GraphNode: a node hashes like its id()
//! string.
template <>
struct hash<cinn::common::GraphNode> {
  // The call operator is const so the functor is usable through const
  // instances, as the standard Hash requirements expect.
  size_t operator()(const cinn::common::GraphNode& x) const {
    return hash<std::string>()(x.id());
  }
};
}  // namespace std
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/cinn/common/common.h"
namespace cinn {
namespace common {
//! Minimal concrete GraphNode for the tests: its id() is the name it was
//! constructed with.
struct GraphNodeWithName : public GraphNode {
  std::string name;

  explicit GraphNodeWithName(std::string node_name) : name(node_name) {}

  std::string id() const override { return name; }
};
// A simple five-node graph with the edges
//   A -> B, A -> C, B -> D, C -> D, C -> E.
std::unique_ptr<Graph> CreateGraph0() {
  std::unique_ptr<Graph> graph(new Graph);

  auto* node_a = make_shared<GraphNodeWithName>("A");
  auto* node_b = make_shared<GraphNodeWithName>("B");
  auto* node_c = make_shared<GraphNodeWithName>("C");
  auto* node_d = make_shared<GraphNodeWithName>("D");
  auto* node_e = make_shared<GraphNodeWithName>("E");

  // Register under the same names the nodes report as their ids.
  graph->RegisterNode("A", node_a);
  graph->RegisterNode("B", node_b);
  graph->RegisterNode("C", node_c);
  graph->RegisterNode("D", node_d);
  graph->RegisterNode("E", node_e);

  node_a->LinkTo(node_b);
  node_a->LinkTo(node_c);
  node_b->LinkTo(node_d);
  node_c->LinkTo(node_d);
  node_c->LinkTo(node_e);

  return graph;
}
// A two-node graph with the single edge B -> A; A is registered first.
std::unique_ptr<Graph> CreateGraph1() {
  std::unique_ptr<Graph> graph(new Graph);

  auto* node_a = make_shared<GraphNodeWithName>("A");
  auto* node_b = make_shared<GraphNodeWithName>("B");

  graph->RegisterNode("A", node_a);
  graph->RegisterNode("B", node_b);

  node_b->LinkTo(node_a);

  return graph;
}
// Smoke test: building the visualization of graph 0 must not crash; the
// output is only logged, not checked.
TEST(Graph, Visualize) {
  const std::unique_ptr<Graph> graph = CreateGraph0();
  LOG(INFO) << "graph:\n" << graph->Visualize();
}
// The topological order of graph 1 (B -> A) must be B followed by A.
TEST(Graph, simple) {
  auto graph = CreateGraph1();

  Graph::node_order_t sorted_nodes;
  Graph::edge_order_t sorted_edges;
  std::tie(sorted_nodes, sorted_edges) = graph->topological_order();

  LOG(INFO) << "graph1 " << graph->Visualize();

  std::vector<GraphNode*> expected_nodes(
      {graph->RetrieveNode("B"), graph->RetrieveNode("A")});

  ASSERT_EQ(sorted_nodes.size(), expected_nodes.size());
  for (size_t pos = 0; pos < sorted_nodes.size(); ++pos) {
    EXPECT_EQ(sorted_nodes[pos]->id(), expected_nodes[pos]->id());
  }
}
} // namespace common
} // namespace cinn
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment