Commit 006d530c authored by PanZezhong

issue/810 static compute graph infra

parent caa61e9e
......@@ -3,6 +3,8 @@
#include "../device.hpp"
#include "../memory.hpp"
#include "../graph/graph.hpp"
#include <infiniop.h>
#include <infinirt.h>
......@@ -40,6 +42,12 @@ void destroyEvent(infinirtEvent_t event);
float elapsedTime(infinirtEvent_t start, infinirtEvent_t end);
void streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event);
// Graph recording APIs
bool isGraphRecording();
void startGraphRecording();
void addGraphOperator(std::shared_ptr<graph::GraphOperator> op);
std::shared_ptr<graph::Graph> stopGraphRecording();
} // namespace context
} // namespace infinicore
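The four hooks above are the whole capture surface. A minimal usage sketch (hypothetical driver code, not part of this commit; assumes tensors `c`, `a`, `b` already live on the current device, and that operators such as `Gemm` route through `addGraphOperator` while recording, as wired up later in this commit):

void record_and_replay(infinicore::Tensor c, infinicore::Tensor a, infinicore::Tensor b) {
    infinicore::context::startGraphRecording();
    infinicore::op::Gemm::execute(c, a, b, /*alpha=*/1.0f, /*beta=*/0.0f); // planned, not run
    auto graph = infinicore::context::stopGraphRecording();
    graph->run(); // replay the captured operators without re-planning
}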
#pragma once
#include <memory>
#include <vector>
#include "../tensor.hpp"
namespace infinicore::graph {
// Forward declarations
class GraphManager;
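// GraphTensor snapshots a tensor via Tensor::to_blob(): it aliases the same
// buffer without owning it, so a recorded operator can replay against a
// stable address (the allocator's pin mode keeps that address alive).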
class GraphTensor : public Tensor {
public:
GraphTensor(const Tensor &);
};
class GraphOperator {
public:
void run() const;
~GraphOperator();
protected:
using run_schema = void (*)(void *);
using cleanup_schema = void (*)(void **);
    // Default-initialize so a partially constructed operator destructs safely.
    void *planned_meta_ = nullptr;
    run_schema runner_ = nullptr;
    cleanup_schema deleter_ = nullptr;
};
class Graph {
public:
Graph() = default;
~Graph() = default;
void run() const;
protected:
void add_operator(std::shared_ptr<GraphOperator> op);
std::vector<std::shared_ptr<GraphOperator>> op_list_;
friend class GraphManager;
};
} // namespace infinicore::graph
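For orientation, a sketch of the subclass contract (hypothetical `AddOp`, mirroring how `Gemm` is wired later in this commit): the constructor plans once, stashes backend state in `planned_meta_`, and installs matching `runner_`/`deleter_` callbacks.

class AddOp : public infinicore::graph::GraphOperator {
public:
    AddOp() {
        planned_meta_ = new int(0); // stand-in for real backend planned state
        runner_ = [](void *meta) { /* enqueue the kernel described by *meta */ };
        deleter_ = [](void **meta) { // invoked by ~GraphOperator
            delete static_cast<int *>(*meta);
            *meta = nullptr;
        };
    }
};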
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
namespace infinicore::op {
-class Gemm {
+class Gemm : public graph::GraphOperator {
public:
using schema = void (*)(Tensor, Tensor, Tensor, float, float);
using plan_schema = void *(*)(Tensor, Tensor, Tensor, float, float);
Gemm(Tensor c, Tensor a, Tensor b, float alpha, float beta);
static void execute(Tensor c, Tensor a, Tensor b, float alpha, float beta);
static common::OpDispatcher<schema> &dispatcher();
static common::OpDispatcher<plan_schema> &plan_dispatcher();
static common::OpDispatcher<run_schema> &run_dispatcher();
static common::OpDispatcher<cleanup_schema> &cleanup_dispatcher();
};
Tensor gemm(Tensor a, Tensor b, float alpha = 1.0f, float beta = 0.0f);
......
......@@ -133,6 +133,8 @@ public:
void debug() const;
Tensor to_blob() const;
///
/// Data Transfer APIs
///
......@@ -294,7 +296,7 @@ protected:
friend class Tensor;
-private:
+protected:
TensorMetaData meta_;
TensorData data_;
};
......
......@@ -8,7 +8,10 @@ from infinicore.context import (
get_device,
get_device_count,
get_stream,
is_graph_recording,
set_device,
start_graph_recording,
stop_graph_recording,
sync_device,
sync_stream,
)
......@@ -80,6 +83,9 @@ __all__ = [
"set_device",
"sync_device",
"sync_stream",
"is_graph_recording",
"start_graph_recording",
"stop_graph_recording",
# Data Types.
"bfloat16",
"bool",
......
import infinicore.device
from infinicore.graph import Graph
from infinicore.lib import _infinicore
......@@ -49,3 +50,24 @@ def get_stream():
stream: The current stream object
"""
return _infinicore.get_stream()

def is_graph_recording():
    """Check whether graph recording is currently active.

    Returns:
        bool: True if a graph is being recorded, False otherwise.
    """
    return _infinicore.is_graph_recording()


def start_graph_recording(device=None):
    """Start graph recording, optionally switching to `device` first."""
    if device is not None:
        set_device(device)

    _infinicore.start_graph_recording()


def stop_graph_recording():
    """Stop graph recording and return the captured graph.

    Returns:
        Graph: wrapper around the recorded graph.
    """
    return Graph(_infinicore.stop_graph_recording())
from infinicore.lib import _infinicore

class Graph:
    """Python wrapper around an InfiniCore Graph instance."""

    def __init__(self, graph: _infinicore.Graph):
        if not isinstance(graph, _infinicore.Graph):
            raise TypeError("Expected _infinicore.Graph")
        self._graph = graph

    def run(self):
        return self._graph.run()

    def __repr__(self):
        return f"<Graph wrapper of {self._graph!r}>"
#include "pinnable_block_allocator.hpp"
#include "../context_impl.hpp"
#include "../../utils.hpp"
#include <algorithm>
......
......@@ -2,8 +2,6 @@
#include "memory_allocator.hpp"
#include "../context_impl.hpp"
#include <mutex>
#include <unordered_map>
#include <vector>
......@@ -25,7 +23,7 @@ class PinnableBlockAllocator : public MemoryAllocator {
};
public:
-    explicit PinnableBlockAllocator(Device device);
+    PinnableBlockAllocator(Device device);
~PinnableBlockAllocator();
std::byte *allocate(size_t size) override;
......
......@@ -39,6 +39,10 @@ void ContextImpl::setDevice(Device device) {
return;
}
if (getCurrentRuntime()->isGraphRecording()) {
spdlog::warn("Switching device runtime during graph recording may break the graph!");
}
if (runtime_table_[int(device.getType())][device.getIndex()] == nullptr) {
// Lazy initialization of runtime if never set before.
runtime_table_[int(device.getType())][device.getIndex()] = std::unique_ptr<Runtime>(new Runtime(device));
......@@ -178,6 +182,21 @@ void streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) {
ContextImpl::singleton().getCurrentRuntime()->streamWaitEvent(stream, event);
}
bool isGraphRecording() {
return ContextImpl::singleton().getCurrentRuntime()->isGraphRecording();
}
void startGraphRecording() {
ContextImpl::singleton().getCurrentRuntime()->startGraphRecording();
}
void addGraphOperator(std::shared_ptr<graph::GraphOperator> op) {
ContextImpl::singleton().getCurrentRuntime()->addGraphOperator(op);
}
std::shared_ptr<graph::Graph> stopGraphRecording() {
return ContextImpl::singleton().getCurrentRuntime()->stopGraphRecording();
}
} // namespace context
} // namespace infinicore
......@@ -8,12 +8,12 @@
#include "../allocators/stream_ordered_allocator.hpp"
namespace infinicore {
-Runtime::Runtime(Device device) : device_(device) {
+Runtime::Runtime(Device device) : device_(device), graph_manager_(std::make_unique<graph::GraphManager>()) {
activate();
INFINICORE_CHECK_ERROR(infinirtStreamCreate(&stream_));
INFINICORE_CHECK_ERROR(infiniopCreateHandle(&infiniop_handle_));
if (device_.getType() == Device::Type::CPU) {
-        device_memory_allocator_ = std::make_unique<HostAllocator>();
+        device_memory_allocator_ = std::make_unique<PinnableBlockAllocator>(device);
} else {
device_memory_allocator_ = std::make_unique<PinnableBlockAllocator>(device);
pinned_host_memory_allocator_ = std::make_unique<DevicePinnedHostAllocator>(device);
......@@ -145,6 +145,25 @@ void Runtime::streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) {
INFINICORE_CHECK_ERROR(infinirtStreamWaitEvent(stream, event));
}
bool Runtime::isGraphRecording() const {
return graph_manager_->is_recording();
}
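// Pin the allocator while recording so buffer addresses captured in planned
// operator state remain valid when the graph is replayed later.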
void Runtime::startGraphRecording() {
device_memory_allocator_->set_pin_mode(true);
return graph_manager_->start_recording();
}
void Runtime::addGraphOperator(std::shared_ptr<graph::GraphOperator> op) {
return graph_manager_->add_operator(op);
}
std::shared_ptr<graph::Graph> Runtime::stopGraphRecording() {
auto graph = graph_manager_->stop_recording();
device_memory_allocator_->set_pin_mode(false);
return graph;
}
std::string Runtime::toString() const {
return fmt::format("Runtime({})", device_.toString());
}
......
#pragma once
#include "../allocators/memory_allocator.hpp"
#include "../allocators/pinnable_block_allocator.hpp"
#include "infinicore/context/context.hpp"
#include "../../graph/graph_manager.hpp"
#include <infiniop.h>
#include <infinirt.h>
......@@ -13,8 +16,9 @@ private:
Device device_;
infinirtStream_t stream_;
infiniopHandle_t infiniop_handle_;
-    std::unique_ptr<MemoryAllocator> device_memory_allocator_;
+    std::unique_ptr<PinnableBlockAllocator> device_memory_allocator_;
std::unique_ptr<MemoryAllocator> pinned_host_memory_allocator_;
std::unique_ptr<graph::GraphManager> graph_manager_;
protected:
Runtime(Device device);
......@@ -48,6 +52,12 @@ public:
float elapsedTime(infinirtEvent_t start, infinirtEvent_t end);
void streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event);
// Graph
bool isGraphRecording() const;
void startGraphRecording();
void addGraphOperator(std::shared_ptr<graph::GraphOperator> op);
std::shared_ptr<graph::Graph> stopGraphRecording();
std::string toString() const;
friend class ContextImpl;
......
#include "graph_manager.hpp"
#include "../utils.hpp"
namespace infinicore::graph {
/* =========================
* GraphTensor
* ========================= */
GraphTensor::GraphTensor(const Tensor &tensor) : Tensor(tensor->to_blob()) {
}
/* =========================
* GraphOperator
* ========================= */
void GraphOperator::run() const {
runner_(planned_meta_);
}
GraphOperator::~GraphOperator() {
if (deleter_) {
deleter_(&planned_meta_);
}
}
/* =========================
* Graph
* ========================= */
void Graph::run() const {
for (auto &op : op_list_) {
op->run();
}
}
void Graph::add_operator(std::shared_ptr<GraphOperator> op) {
op_list_.push_back(op);
}
/* =========================
* GraphManager
* ========================= */
bool GraphManager::is_recording() const {
return recording_;
}
void GraphManager::start_recording() {
recording_ = true;
graph_ = std::make_shared<Graph>();
}
void GraphManager::add_operator(std::shared_ptr<GraphOperator> op) {
INFINICORE_ASSERT(recording_);
graph_->add_operator(op);
}
std::shared_ptr<Graph> GraphManager::stop_recording() {
recording_ = false;
return std::exchange(graph_, nullptr);
}
} // namespace infinicore::graph
#pragma once
#include "infinicore/graph/graph.hpp"
#include <memory>
#include <vector>
namespace infinicore::graph {
class GraphManager {
public:
GraphManager() = default;
~GraphManager() = default;
bool is_recording() const;
void start_recording();
void add_operator(std::shared_ptr<GraphOperator> op);
std::shared_ptr<Graph> stop_recording();
private:
std::shared_ptr<Graph> graph_;
bool recording_ = false;
};
} // namespace infinicore::graph
......@@ -9,10 +9,34 @@ common::OpDispatcher<Gemm::schema> &Gemm::dispatcher() {
return dispatcher_;
};
-void Gemm::execute(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
+common::OpDispatcher<Gemm::plan_schema> &Gemm::plan_dispatcher() {
+    static common::OpDispatcher<Gemm::plan_schema> dispatcher_;
+    return dispatcher_;
+}
+
+common::OpDispatcher<Gemm::run_schema> &Gemm::run_dispatcher() {
+    static common::OpDispatcher<Gemm::run_schema> dispatcher_;
+    return dispatcher_;
+}
+
+common::OpDispatcher<Gemm::cleanup_schema> &Gemm::cleanup_dispatcher() {
+    static common::OpDispatcher<Gemm::cleanup_schema> dispatcher_;
+    return dispatcher_;
+}
+
+Gemm::Gemm(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
    infinicore::context::setDevice(c->device());
-    dispatcher().lookup(c->device().getType())(c, a, b, alpha, beta);
+    planned_meta_ = plan_dispatcher().lookup(c->device().getType())(c, a, b, alpha, beta);
+    runner_ = run_dispatcher().lookup(c->device().getType());
+    deleter_ = cleanup_dispatcher().lookup(c->device().getType());
}
+
+void Gemm::execute(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
+    auto op = std::make_shared<Gemm>(c, a, b, alpha, beta);
+    if (context::isGraphRecording()) {
+        context::addGraphOperator(op);
+    } else {
+        op->run();
+    }
+}
Tensor gemm(Tensor a, Tensor b, float alpha, float beta) {
......
......@@ -5,45 +5,89 @@
#include <infiniop.h>
namespace infinicore::op::gemm_impl::infiniop {
-thread_local common::OpCache<size_t, infiniopGemmDescriptor_t> caches(
-    100, // capacity
-    [](infiniopGemmDescriptor_t &desc) {
+// A desc holder to make it a shared pointer that can auto clean-up
+struct Descriptor {
+    infiniopGemmDescriptor_t desc;
+    Descriptor(infiniopGemmDescriptor_t desc) : desc(desc) {}
+    ~Descriptor() {
        if (desc != nullptr) {
-            INFINICORE_CHECK_ERROR(infiniopDestroyGemmDescriptor(desc));
+            infiniopDestroyGemmDescriptor(desc);
            desc = nullptr;
        }
-    });
+    }
+};

-void calculate(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
+thread_local common::OpCache<size_t, std::shared_ptr<Descriptor>>
+    caches(
+        // capacity
+        100,
+        // on evict
+        [](std::shared_ptr<Descriptor> &desc) {
+            desc = nullptr;
+        });
+
+struct PlannedMeta {
+    std::shared_ptr<Descriptor> descriptor;
+    graph::GraphTensor workspace, c, a, b;
+    float alpha, beta;
+};
+
+void *plan(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
    size_t seed = hash_combine(c, b, a, alpha, beta);
    auto device = context::getDevice();
    auto &cache = caches.getCache(device);
-    auto desc_opt = cache.get(seed);
-    infiniopGemmDescriptor_t desc = nullptr;
+    auto descriptor = cache.get(seed).value_or(nullptr);
-    if (!desc_opt) {
+    if (!descriptor) {
+        descriptor = std::make_shared<Descriptor>(nullptr);
        INFINICORE_CHECK_ERROR(infiniopCreateGemmDescriptor(
-            context::getInfiniopHandle(device), &desc,
+            context::getInfiniopHandle(device),
+            &descriptor->desc,
            c->desc(), a->desc(), b->desc()));
-        cache.put(seed, desc);
-    } else {
-        desc = *desc_opt;
+        cache.put(seed, descriptor);
    }

    size_t workspace_size = 0;
-    INFINICORE_CHECK_ERROR(infiniopGetGemmWorkspaceSize(desc, &workspace_size));
-    std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
+    INFINICORE_CHECK_ERROR(infiniopGetGemmWorkspaceSize(descriptor->desc, &workspace_size));
+    Tensor workspace = Tensor::empty({workspace_size}, DataType::U8, device);
+
+    auto planned = new PlannedMeta{
+        descriptor,
+        graph::GraphTensor(workspace),
+        graph::GraphTensor(c),
+        graph::GraphTensor(a),
+        graph::GraphTensor(b),
+        alpha, beta};
+    return planned;
+}
+
+void run(void *planned_meta) {
+    auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
    INFINICORE_CHECK_ERROR(infiniopGemm(
-        desc, workspace->data(), workspace_size,
-        c->data(), a->data(), b->data(), alpha, beta, context::getStream()));
+        planned->descriptor->desc, planned->workspace->data(), planned->workspace->numel(),
+        planned->c->data(), planned->a->data(), planned->b->data(), planned->alpha, planned->beta, context::getStream()));
}
+
+void cleanup(void **planned_meta_ptr) {
+    delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
+    *planned_meta_ptr = nullptr;
+}
+
+void calculate(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
+    auto planned = plan(c, a, b, alpha, beta);
+    run(planned);
+    cleanup(&planned);
+}

static bool registered = []() {
    Gemm::dispatcher().registerAll(&calculate, false);
+    Gemm::plan_dispatcher().registerAll(&plan, false);
+    Gemm::run_dispatcher().registerAll(&run, false);
+    Gemm::cleanup_dispatcher().registerAll(&cleanup, false);
    return true;
}();
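A sketch of why the plan/run/cleanup split pays off (hypothetical driver code, not part of this commit): planning builds the descriptor and workspace once, and each replay only enqueues the kernel.

void replay_example(Tensor c, Tensor a, Tensor b) {
    void *meta = plan(c, a, b, 1.0f, 0.0f); // descriptor + workspace created once
    for (int step = 0; step < 100; ++step) {
        run(meta); // enqueue only; no hashing, cache lookup, or allocation
    }
    cleanup(&meta); // frees the PlannedMeta and drops the descriptor reference
}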
......
......@@ -24,6 +24,11 @@ inline void bind(py::module &m) {
// Synchronization
m.def("sync_stream", &syncStream, "Synchronize the current stream");
m.def("sync_device", &syncDevice, "Synchronize the current device");
// Graph
m.def("is_graph_recording", &isGraphRecording, "Check if graph recording is turned on");
m.def("start_graph_recording", &startGraphRecording, "Start graph recording");
m.def("stop_graph_recording", &stopGraphRecording, "Stop graph recording and return the graph");
}
} // namespace infinicore::context
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "infinicore.hpp"
namespace py = pybind11;
namespace infinicore::graph {
inline void bind(py::module_ &m) {
py::class_<infinicore::graph::Graph,
std::shared_ptr<infinicore::graph::Graph>>(m, "Graph")
.def(py::init<>()) // allow construction
.def("run", &infinicore::graph::Graph::run);
}
} // namespace infinicore::graph
......@@ -6,6 +6,7 @@
#include "device.hpp"
#include "device_event.hpp"
#include "dtype.hpp"
#include "graph.hpp"
#include "ops.hpp"
#include "tensor.hpp"
......@@ -18,6 +19,7 @@ PYBIND11_MODULE(_infinicore, m) {
dtype::bind(m);
ops::bind(m);
tensor::bind(m);
graph::bind(m);
}
} // namespace infinicore
......@@ -275,4 +275,12 @@ std::shared_ptr<TensorImpl> TensorImpl::strided_from_blob(
return t;
}
Tensor TensorImpl::to_blob() const {
auto t = std::shared_ptr<TensorImpl>(new TensorImpl(shape(), strides(), dtype()));
t->data_.offset = this->data_.offset;
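    // Trailing nullptr: no deleter, so the blob aliases the buffer without
    // owning it and never frees the underlying allocation.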
t->data_.memory = std::make_shared<Memory>(this->data_.memory->data(), this->data_.memory->size(), this->data_.memory->device(), nullptr);
return Tensor{t};
}
} // namespace infinicore
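Aliasing semantics in brief (hypothetical usage, not part of this commit): the blob shares storage and layout with its source, but does not keep that storage alive; the source tensor, or the allocator's pin mode during graph recording, must outlive it.

void to_blob_example(infinicore::Tensor t) {
    infinicore::Tensor blob = t->to_blob();
    assert(blob->data() == t->data()); // same underlying buffer, not owned
}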