Unverified commit 01a4a0c8 authored by Haojie Wang, committed by GitHub

Merge pull request #882 from InfiniTensor/issue/810

issue/810 static compute graph infra
parents 3883f32f 39f9c349
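In short, this PR adds a record-and-replay path: while recording is active, operator calls are captured into a `graph::Graph` (descriptors planned and buffers pinned up front) instead of executing eagerly, and the captured graph can then be replayed repeatedly without re-planning or dispatch. A minimal sketch of the intended C++ flow, using only APIs added in this diff (tensor setup elided):

```cpp
#include "infinicore/context/context.hpp"
#include "infinicore/ops/gemm.hpp"

using namespace infinicore;

void record_and_replay(Tensor c, Tensor a, Tensor b) {
    // Capture: ops route through INFINICORE_GRAPH_OP_RECORD_OR_RUN and are
    // appended to the current runtime's graph instead of running eagerly.
    context::startGraphRecording();
    op::gemm_(c, a, b, /*alpha=*/1.0f, /*beta=*/0.0f);
    auto graph = context::stopGraphRecording();

    // Replay: each op re-runs from its pre-planned metadata.
    graph->run();
    graph->run();
}
```

The same flow is exposed to Python as `infinicore.start_graph_recording()` / `infinicore.stop_graph_recording()`, with the returned `Graph` wrapper's `run()` driving the replay.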
@@ -3,6 +3,8 @@
#include "../device.hpp"
#include "../memory.hpp"
#include "../graph/graph.hpp"
#include <infiniop.h>
#include <infinirt.h>
@@ -40,6 +42,12 @@ void destroyEvent(infinirtEvent_t event);
float elapsedTime(infinirtEvent_t start, infinirtEvent_t end);
void streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event);
// Graph recording APIs
bool isGraphRecording();
void startGraphRecording();
void addGraphOperator(std::shared_ptr<graph::GraphOperator> op);
std::shared_ptr<graph::Graph> stopGraphRecording();
} // namespace context
} // namespace infinicore
#pragma once
#include <memory>
#include <vector>
#include "../tensor.hpp"
namespace infinicore::graph {
// Forward declarations
class GraphManager;
class GraphTensor : public Tensor {
public:
GraphTensor(const Tensor &);
};
class GraphOperator {
public:
void run() const;
~GraphOperator();
protected:
using run_schema = void (*)(void *);
using cleanup_schema = void (*)(void **);
    // Set by INFINICORE_GRAPH_OP_DISPATCH in the concrete op's constructor;
    // default to null so the destructor is safe if planning never ran.
    void *planned_meta_ = nullptr;
    run_schema runner_ = nullptr;
    cleanup_schema deleter_ = nullptr;
};
class Graph {
public:
Graph() = default;
~Graph() = default;
void run() const;
protected:
void add_operator(std::shared_ptr<GraphOperator> op);
std::vector<std::shared_ptr<GraphOperator>> op_list_;
friend class GraphManager;
};
} // namespace infinicore::graph
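`GraphOperator` erases each op's planned state behind a `void *` and two function pointers, so `Graph::run()` can replay heterogeneous ops uniformly. A sketch of the contract a backend satisfies (names here are hypothetical, not from this PR; the real instance is the Gemm backend later in the diff):

```cpp
// Hypothetical planned state captured at record time.
struct FakePlannedMeta {
    int captured_arg;
};

// Matches GraphOperator::run_schema: replay using only the planned state.
void fake_run(void *meta) {
    auto *m = static_cast<FakePlannedMeta *>(meta);
    (void)m->captured_arg; // launch the real work from the recorded state
}

// Matches GraphOperator::cleanup_schema: free the state and null it out.
void fake_cleanup(void **meta) {
    delete static_cast<FakePlannedMeta *>(*meta);
    *meta = nullptr;
}
```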
#define INFINICORE_GRAPH_OP_CLASS(__OP_NAME__, ...) \
class __OP_NAME__ : public graph::GraphOperator { \
public: \
using schema = void (*)(__VA_ARGS__); \
using plan_schema = void *(*)(__VA_ARGS__); \
static common::OpDispatcher<plan_schema> &plan_dispatcher(); \
static common::OpDispatcher<run_schema> &run_dispatcher(); \
static common::OpDispatcher<cleanup_schema> &cleanup_dispatcher(); \
__OP_NAME__(__VA_ARGS__); \
static void execute(__VA_ARGS__); \
};
#define INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(__OP_NAME__) \
common::OpDispatcher<__OP_NAME__::plan_schema> &__OP_NAME__::plan_dispatcher() { \
static common::OpDispatcher<__OP_NAME__::plan_schema> dispatcher_; \
return dispatcher_; \
} \
common::OpDispatcher<__OP_NAME__::run_schema> &__OP_NAME__::run_dispatcher() { \
static common::OpDispatcher<__OP_NAME__::run_schema> dispatcher_; \
return dispatcher_; \
} \
common::OpDispatcher<__OP_NAME__::cleanup_schema> &__OP_NAME__::cleanup_dispatcher() { \
static common::OpDispatcher<__OP_NAME__::cleanup_schema> dispatcher_; \
return dispatcher_; \
}
#define INFINICORE_GRAPH_OP_DISPATCH(__DEVICE_TYPE__, ...) \
planned_meta_ = plan_dispatcher().lookup(__DEVICE_TYPE__)(__VA_ARGS__); \
runner_ = run_dispatcher().lookup(__DEVICE_TYPE__); \
deleter_ = cleanup_dispatcher().lookup(__DEVICE_TYPE__);
#define INFINICORE_GRAPH_OP_RECORD_OR_RUN(__OP_NAME__, ...) \
auto op = std::make_shared<__OP_NAME__>(__VA_ARGS__); \
if (context::isGraphRecording()) { \
context::addGraphOperator(op); \
} else { \
op->run(); \
}
#define INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(__OP_NAME__, __PLAN_F__, __RUN_F__, __CLEANUP_F__) \
static bool registered = []() { \
__OP_NAME__::plan_dispatcher().registerAll(__PLAN_F__, false); \
__OP_NAME__::run_dispatcher().registerAll(__RUN_F__, false); \
__OP_NAME__::cleanup_dispatcher().registerAll(__CLEANUP_F__, false); \
return true; \
}();
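Taken together, the four macros give every graph-capable op the same shape. A sketch of how a hypothetical unary `Relu` op would be wired through them (illustrative only; the real example in this PR is Gemm, below):

```cpp
// Public header: declares class Relu with plan/run/cleanup dispatchers.
INFINICORE_GRAPH_OP_CLASS(Relu, Tensor, Tensor);

// Op source file: defines the three dispatcher singletons.
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Relu);

Relu::Relu(Tensor out, Tensor in) {
    // Plans immediately and stashes the device-specific run/cleanup pointers.
    INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), out, in);
}

void Relu::execute(Tensor out, Tensor in) {
    // Appends the op to the recording graph, or runs it eagerly.
    INFINICORE_GRAPH_OP_RECORD_OR_RUN(Relu, out, in);
}

// Backend source file: registers its plan/run/cleanup triple.
// INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Relu, &plan, &run, &cleanup);
```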
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
namespace infinicore::op {
-class Gemm {
-public:
-    using schema = void (*)(Tensor, Tensor, Tensor, float, float);
-    static void execute(Tensor c, Tensor a, Tensor b, float alpha, float beta);
-    static common::OpDispatcher<schema> &dispatcher();
-};
+INFINICORE_GRAPH_OP_CLASS(Gemm, Tensor, Tensor, Tensor, float, float);
Tensor gemm(Tensor a, Tensor b, float alpha = 1.0f, float beta = 0.0f);
void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta);
......
@@ -133,6 +133,8 @@ public:
void debug() const;
Tensor to_blob() const;
///
/// Data Transfer APIs
///
@@ -294,7 +296,7 @@ protected:
friend class Tensor;
-private:
+protected:
TensorMetaData meta_;
TensorData data_;
};
......
@@ -8,7 +8,10 @@ from infinicore.context import (
get_device,
get_device_count,
get_stream,
is_graph_recording,
set_device,
start_graph_recording,
stop_graph_recording,
sync_device,
sync_stream,
)
@@ -81,6 +84,9 @@ __all__ = [
"set_device",
"sync_device",
"sync_stream",
"is_graph_recording",
"start_graph_recording",
"stop_graph_recording",
# Data Types.
"bfloat16",
"bool",
......
import infinicore.device
from infinicore.graph import Graph
from infinicore.lib import _infinicore
@@ -49,3 +50,24 @@ def get_stream():
stream: The current stream object
"""
return _infinicore.get_stream()
def is_graph_recording():
    """Check whether graph recording is currently active.

    Returns:
        bool: True if a graph is being recorded, False otherwise
    """
    return _infinicore.is_graph_recording()


def start_graph_recording(device=None):
    """Start recording a new graph, switching to `device` first if one is given."""
    if device is not None:
        set_device(device)

    _infinicore.start_graph_recording()


def stop_graph_recording():
    """Stop recording and return the captured graph."""
    return Graph(_infinicore.stop_graph_recording())
from infinicore.lib import _infinicore
class Graph:
"""
    Python wrapper around an InfiniCore Graph instance.
"""
def __init__(self, graph: _infinicore.Graph):
if not isinstance(graph, _infinicore.Graph):
raise TypeError("Expected _infinicore.Graph")
self._graph = graph
def run(self):
return self._graph.run()
def __repr__(self):
return f"<Graph wrapper of {self._graph!r}>"
@@ -12,12 +12,18 @@ DevicePinnedHostAllocator::~DevicePinnedHostAllocator() {
}
std::byte *DevicePinnedHostAllocator::allocate(size_t size) {
if (size == 0) {
return nullptr;
}
void *ptr;
INFINICORE_CHECK_ERROR(infinirtMallocHost(&ptr, size));
return (std::byte *)ptr;
}
void DevicePinnedHostAllocator::deallocate(std::byte *ptr) {
if (ptr == nullptr) {
return;
}
if (owner_ == context::getDevice()) {
INFINICORE_CHECK_ERROR(infinirtFreeHost(ptr));
gc();
......
@@ -4,10 +4,16 @@
namespace infinicore {
std::byte *HostAllocator::allocate(size_t size) {
if (size == 0) {
return nullptr;
}
return (std::byte *)std::malloc(size);
}
void HostAllocator::deallocate(std::byte *ptr) {
if (ptr == nullptr) {
return;
}
std::free(ptr);
}
......
#include "pinnable_block_allocator.hpp"
#include "../context_impl.hpp"
#include "../../utils.hpp"
#include <algorithm>
@@ -35,6 +37,9 @@ PinnableBlockAllocator::PinnableBlockAllocator(Device device)
// ------------------- allocate -------------------
std::byte *PinnableBlockAllocator::allocate(size_t size) {
if (size == 0) {
return nullptr;
}
std::lock_guard<std::mutex> lock(mutex_);
// Align size to 256 bytes for GPU
@@ -92,7 +97,7 @@ std::byte *PinnableBlockAllocator::allocate(size_t size) {
// ------------------- deallocate -------------------
void PinnableBlockAllocator::deallocate(std::byte *ptr) {
-    if (!ptr) {
+    if (ptr == nullptr) {
return;
}
......
@@ -2,8 +2,6 @@
#include "memory_allocator.hpp"
#include "../context_impl.hpp"
#include <mutex>
#include <unordered_map>
#include <vector>
@@ -25,7 +23,7 @@ class PinnableBlockAllocator : public MemoryAllocator {
};
public:
-    explicit PinnableBlockAllocator(Device device);
+    PinnableBlockAllocator(Device device);
~PinnableBlockAllocator();
std::byte *allocate(size_t size) override;
......
@@ -8,12 +8,18 @@ namespace infinicore {
StreamOrderedAllocator::StreamOrderedAllocator(Device device) : MemoryAllocator(), device_(device) {}
std::byte *StreamOrderedAllocator::allocate(size_t size) {
if (size == 0) {
return nullptr;
}
void *ptr = nullptr;
INFINICORE_CHECK_ERROR(infinirtMallocAsync(&ptr, size, context::getStream()));
return (std::byte *)ptr;
}
void StreamOrderedAllocator::deallocate(std::byte *ptr) {
if (ptr == nullptr) {
return;
}
INFINICORE_CHECK_ERROR(infinirtFreeAsync(ptr, context::getStream()));
}
} // namespace infinicore
@@ -39,6 +39,10 @@ void ContextImpl::setDevice(Device device) {
return;
}
if (getCurrentRuntime()->isGraphRecording()) {
spdlog::warn("Switching device runtime during graph recording may break the graph!");
}
if (runtime_table_[int(device.getType())][device.getIndex()] == nullptr) {
// Lazy initialization of runtime if never set before.
runtime_table_[int(device.getType())][device.getIndex()] = std::unique_ptr<Runtime>(new Runtime(device));
@@ -178,6 +182,21 @@ void streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) {
ContextImpl::singleton().getCurrentRuntime()->streamWaitEvent(stream, event);
}
bool isGraphRecording() {
return ContextImpl::singleton().getCurrentRuntime()->isGraphRecording();
}
void startGraphRecording() {
ContextImpl::singleton().getCurrentRuntime()->startGraphRecording();
}
void addGraphOperator(std::shared_ptr<graph::GraphOperator> op) {
ContextImpl::singleton().getCurrentRuntime()->addGraphOperator(op);
}
std::shared_ptr<graph::Graph> stopGraphRecording() {
return ContextImpl::singleton().getCurrentRuntime()->stopGraphRecording();
}
} // namespace context
} // namespace infinicore
@@ -8,12 +8,12 @@
#include "../allocators/stream_ordered_allocator.hpp"
namespace infinicore {
-Runtime::Runtime(Device device) : device_(device) {
+Runtime::Runtime(Device device) : device_(device), graph_manager_(std::make_unique<graph::GraphManager>()) {
activate();
INFINICORE_CHECK_ERROR(infinirtStreamCreate(&stream_));
INFINICORE_CHECK_ERROR(infiniopCreateHandle(&infiniop_handle_));
if (device_.getType() == Device::Type::CPU) {
-        device_memory_allocator_ = std::make_unique<HostAllocator>();
+        device_memory_allocator_ = std::make_unique<PinnableBlockAllocator>(device);
} else {
device_memory_allocator_ = std::make_unique<PinnableBlockAllocator>(device);
pinned_host_memory_allocator_ = std::make_unique<DevicePinnedHostAllocator>(device);
@@ -145,6 +145,25 @@ void Runtime::streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) {
INFINICORE_CHECK_ERROR(infinirtStreamWaitEvent(stream, event));
}
bool Runtime::isGraphRecording() const {
return graph_manager_->is_recording();
}
void Runtime::startGraphRecording() {
device_memory_allocator_->set_pin_mode(true);
return graph_manager_->start_recording();
}
void Runtime::addGraphOperator(std::shared_ptr<graph::GraphOperator> op) {
return graph_manager_->add_operator(op);
}
std::shared_ptr<graph::Graph> Runtime::stopGraphRecording() {
auto graph = graph_manager_->stop_recording();
device_memory_allocator_->set_pin_mode(false);
return graph;
}
std::string Runtime::toString() const {
return fmt::format("Runtime({})", device_.toString());
}
......
#pragma once
#include "../allocators/memory_allocator.hpp"
#include "../allocators/pinnable_block_allocator.hpp"
#include "infinicore/context/context.hpp"
#include "../../graph/graph_manager.hpp"
#include <infiniop.h>
#include <infinirt.h>
@@ -13,8 +16,9 @@ private:
Device device_;
infinirtStream_t stream_;
infiniopHandle_t infiniop_handle_;
-    std::unique_ptr<MemoryAllocator> device_memory_allocator_;
+    std::unique_ptr<PinnableBlockAllocator> device_memory_allocator_;
std::unique_ptr<MemoryAllocator> pinned_host_memory_allocator_;
std::unique_ptr<graph::GraphManager> graph_manager_;
protected:
Runtime(Device device);
@@ -48,6 +52,12 @@ public:
float elapsedTime(infinirtEvent_t start, infinirtEvent_t end);
void streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event);
// Graph
bool isGraphRecording() const;
void startGraphRecording();
void addGraphOperator(std::shared_ptr<graph::GraphOperator> op);
std::shared_ptr<graph::Graph> stopGraphRecording();
std::string toString() const;
friend class ContextImpl;
......
#include "graph_manager.hpp"
#include "../utils.hpp"
namespace infinicore::graph {
/* =========================
* GraphTensor
* ========================= */
GraphTensor::GraphTensor(const Tensor &tensor) : Tensor(tensor->to_blob()) {
}
/* =========================
* GraphOperator
* ========================= */
void GraphOperator::run() const {
runner_(planned_meta_);
}
GraphOperator::~GraphOperator() {
if (deleter_) {
deleter_(&planned_meta_);
}
}
/* =========================
* Graph
* ========================= */
void Graph::run() const {
for (auto &op : op_list_) {
op->run();
}
}
void Graph::add_operator(std::shared_ptr<GraphOperator> op) {
op_list_.push_back(op);
}
/* =========================
* GraphManager
* ========================= */
bool GraphManager::is_recording() const {
return recording_;
}
void GraphManager::start_recording() {
recording_ = true;
graph_ = std::make_shared<Graph>();
}
void GraphManager::add_operator(std::shared_ptr<GraphOperator> op) {
INFINICORE_ASSERT(recording_);
graph_->add_operator(op);
}
std::shared_ptr<Graph> GraphManager::stop_recording() {
recording_ = false;
return std::exchange(graph_, nullptr);
}
} // namespace infinicore::graph
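Two lifetime details this replay path appears to rely on, stated here as assumptions since the PR does not spell them out:

```cpp
// Assumed rationale for GraphTensor (comment-only sketch):
//  - plan() bakes raw data() pointers into each op's planned metadata, so
//    the backing storage must stay alive, and in place, until the Graph dies.
//  - Constructing through to_blob() yields storage with a stable address,
//    and holding GraphTensors by value inside the planned metadata (see the
//    Gemm PlannedMeta below) keeps that storage alive with the recorded op.
//  - This is also why Runtime::startGraphRecording() flips the device
//    allocator into pin mode before recording begins.
```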
#pragma once
#include "infinicore/graph/graph.hpp"
#include <memory>
#include <vector>
namespace infinicore::graph {
class GraphManager {
public:
GraphManager() = default;
~GraphManager() = default;
bool is_recording() const;
void start_recording();
void add_operator(std::shared_ptr<GraphOperator> op);
std::shared_ptr<Graph> stop_recording();
private:
std::shared_ptr<Graph> graph_;
bool recording_ = false;
};
} // namespace infinicore::graph
@@ -3,16 +3,15 @@
#include "../../utils.hpp"
namespace infinicore::op {
+INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Gemm);
-common::OpDispatcher<Gemm::schema> &Gemm::dispatcher() {
-    static common::OpDispatcher<Gemm::schema> dispatcher_;
-    return dispatcher_;
-};
+Gemm::Gemm(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
+    INFINICORE_GRAPH_OP_DISPATCH(c->device().getType(), c, a, b, alpha, beta);
+}
void Gemm::execute(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
-    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
-    infinicore::context::setDevice(c->device());
-    dispatcher().lookup(c->device().getType())(c, a, b, alpha, beta);
+    INFINICORE_GRAPH_OP_RECORD_OR_RUN(Gemm, c, a, b, alpha, beta);
}
Tensor gemm(Tensor a, Tensor b, float alpha, float beta) {
......
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "../infiniop_impl.hpp"
#include "infinicore/ops/gemm.hpp"
#include <infiniop.h>
namespace infinicore::op::gemm_impl::infiniop {
thread_local common::OpCache<size_t, infiniopGemmDescriptor_t> caches(
100, // capacity
[](infiniopGemmDescriptor_t &desc) {
if (desc != nullptr) {
INFINICORE_CHECK_ERROR(infiniopDestroyGemmDescriptor(desc));
desc = nullptr;
}
});
void calculate(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
size_t seed = hash_combine(c, b, a, alpha, beta);
auto device = context::getDevice();
auto &cache = caches.getCache(device);
auto desc_opt = cache.get(seed);
infiniopGemmDescriptor_t desc = nullptr;
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateGemmDescriptor(
context::getInfiniopHandle(device), &desc,
c->desc(), a->desc(), b->desc()));
cache.put(seed, desc);
} else {
desc = *desc_opt;
}
size_t workspace_size = 0;
INFINICORE_CHECK_ERROR(infiniopGetGemmWorkspaceSize(desc, &workspace_size));
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Gemm, 100);
struct PlannedMeta {
std::shared_ptr<Descriptor> descriptor;
graph::GraphTensor workspace, c, a, b;
float alpha, beta;
};
void *plan(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
size_t seed = hash_combine(c, a, b);
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
Descriptor, descriptor, Gemm,
seed, c->desc(), a->desc(), b->desc());
INFINIOP_WORKSPACE_TENSOR(workspace, Gemm, descriptor);
auto planned = new PlannedMeta{
descriptor,
graph::GraphTensor(workspace),
graph::GraphTensor(c),
graph::GraphTensor(a),
graph::GraphTensor(b),
alpha, beta};
return planned;
}
void run(void *planned_meta) {
auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
INFINICORE_CHECK_ERROR(infiniopGemm(
desc, workspace->data(), workspace_size,
c->data(), a->data(), b->data(), alpha, beta, context::getStream()));
planned->descriptor->desc, planned->workspace->data(), planned->workspace->numel(),
planned->c->data(), planned->a->data(), planned->b->data(), planned->alpha, planned->beta, context::getStream()));
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
}
static bool registered = []() {
Gemm::dispatcher().registerAll(&calculate, false);
return true;
}();
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Gemm, &plan, &run, &cleanup);
} // namespace infinicore::op::gemm_impl::infiniop
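End to end, the backend's lifecycle is: `plan()` runs once at record time (descriptor cache lookup, workspace sizing, pointer capture), `run()` is the hot path that `Graph::run()` replays, and `cleanup()` fires from `~GraphOperator()`. A simplified sketch of that lifecycle driven by hand (in real use the macros drive it through `GraphOperator`):

```cpp
void lifecycle_sketch(infinicore::Tensor c, infinicore::Tensor a, infinicore::Tensor b) {
    using namespace infinicore::op::gemm_impl::infiniop;

    void *meta = plan(c, a, b, /*alpha=*/1.0f, /*beta=*/0.0f); // record time: once
    run(meta);                                                 // replay: many times
    run(meta);
    cleanup(&meta);                                            // destruction: meta is nulled
}
```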
#pragma once
#include "../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include <infiniop.h>
#define INFINIOP_CACHABLE_DESCRIPTOR(__DESC_TYPE__, __OP_NAME__, __SIZE__) \
struct __DESC_TYPE__ { \
infiniop##__OP_NAME__##Descriptor_t desc; \
        __DESC_TYPE__(infiniop##__OP_NAME__##Descriptor_t desc) : desc(desc) {} \
        ~__DESC_TYPE__() { \
if (desc != nullptr) { \
infiniopDestroy##__OP_NAME__##Descriptor(desc); \
desc = nullptr; \
} \
} \
}; \
\
thread_local common::OpCache<size_t, std::shared_ptr<__DESC_TYPE__>> \
caches( \
__SIZE__, \
[](std::shared_ptr<__DESC_TYPE__> &desc) { \
desc = nullptr; \
});
#define INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(__DESC_TYPE__, __DESC_NAME__, __INFINIOP_NAME__, __HASH_KEY__, ...) \
std::shared_ptr<__DESC_TYPE__> __DESC_NAME__; \
{ \
auto device__ = context::getDevice(); \
auto &cache__ = caches.getCache(device__); \
__DESC_NAME__ = cache__.get(__HASH_KEY__).value_or(nullptr); \
if (!__DESC_NAME__) { \
__DESC_NAME__ = std::make_shared<__DESC_TYPE__>(nullptr); \
INFINICORE_CHECK_ERROR(infiniopCreate##__INFINIOP_NAME__##Descriptor( \
context::getInfiniopHandle(device__), \
&__DESC_NAME__->desc, \
__VA_ARGS__)); \
cache__.put(__HASH_KEY__, __DESC_NAME__); \
} \
}
#define INFINIOP_WORKSPACE_TENSOR(__TENSOR_NAME__, __INFINIOP_NAME__, __DESC_NAME__) \
Tensor __TENSOR_NAME__; \
{ \
auto device__ = context::getDevice(); \
size_t workspace_size = 0; \
INFINICORE_CHECK_ERROR(infiniopGet##__INFINIOP_NAME__##WorkspaceSize(__DESC_NAME__->desc, &workspace_size)); \
__TENSOR_NAME__ = Tensor::empty({workspace_size}, DataType::U8, device__); \
}
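For concreteness, `INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Gemm, 100)` as used in the Gemm backend above expands to roughly:

```cpp
struct Descriptor {
    infiniopGemmDescriptor_t desc;
    Descriptor(infiniopGemmDescriptor_t desc) : desc(desc) {}
    ~Descriptor() {
        if (desc != nullptr) {
            infiniopDestroyGemmDescriptor(desc);
            desc = nullptr;
        }
    }
};

thread_local common::OpCache<size_t, std::shared_ptr<Descriptor>>
    caches(100, [](std::shared_ptr<Descriptor> &desc) { desc = nullptr; });
```

Caching `std::shared_ptr<Descriptor>` rather than a raw `infiniopGemmDescriptor_t` is the key change from the old per-op caches: eviction now just drops a reference, so a descriptor still held by a recorded op's `PlannedMeta` survives until that op is destroyed.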