Unverified Commit f00c06d0 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #955 from InfiniTensor/issue/811

issue/811 support cuda graph capture
parents 148b475b 3a8c6860
......@@ -31,17 +31,21 @@ protected:
// A recorded sequence of GraphOperator objects that can be executed with run().
// Operators are stored in op_list_; an optional backend "device graph" is held
// behind device_graph_ (its type is only forward-declared in this header).
// NOTE(review): both the defaulted and the plain constructor/destructor
// declarations appear below. This looks like the before/after lines of a diff
// shown together -- confirm which pair the real header keeps.
class Graph {
public:
Graph() = default;
~Graph() = default;
Graph();
~Graph();
// Executes the graph.
void run() const;
protected:
// Registers `op` with this graph.
void add_operator(std::shared_ptr<GraphOperator> op);
// Prepares the backend device graph for this graph, if the backend allows it.
void instantiate();
// Operators owned by this graph.
std::vector<std::shared_ptr<GraphOperator>> op_list_;
// GraphManager needs access to the protected members above.
friend class GraphManager;
private:
// Backend graph state; declared here and defined in the implementation file.
struct DeviceGraph;
std::unique_ptr<DeviceGraph> device_graph_;
};
} // namespace infinicore::graph
......
......@@ -6,6 +6,9 @@
// Opaque runtime handles. Each one is a plain `void *` at the API boundary.
// The CUDA backend casts the stream handle to cudaStream_t and the graph
// handles to cudaGraph_t / cudaGraphNode_t / cudaGraphExec_t.
typedef void *infinirtStream_t;
typedef void *infinirtEvent_t;
// Handle to a captured graph.
typedef void *infinirtGraph_t;
// Handle to a graph node (filled in by infinirtGraphInstantiate).
typedef void *infinirtGraphNode_t;
// Handle to an instantiated graph that can be launched.
typedef void *infinirtGraphExec_t;
__C __export infiniStatus_t infinirtInit();
......@@ -63,4 +66,24 @@ __C __export infiniStatus_t infinirtMemcpyAsync(void *dst, const void *src, size
__C __export infiniStatus_t infinirtMallocAsync(void **p_ptr, size_t size, infinirtStream_t stream);
__C __export infiniStatus_t infinirtFreeAsync(void *ptr, infinirtStream_t stream);
// Graph
// Capture modes accepted by infinirtStreamBeginCapture. The CUDA backend maps
// each value onto the cudaStreamCaptureMode enumerator with the matching name
// (Global / ThreadLocal / Relaxed).
typedef enum {
INFINIRT_STREAM_CAPTURE_MODE_GLOBAL = 0,
INFINIRT_STREAM_CAPTURE_MODE_THREAD_LOCAL = 1,
INFINIRT_STREAM_CAPTURE_MODE_RELAXED = 2,
} infinirtStreamCaptureMode_t;
// Begins capture on `stream` using `mode`. Backends without graph support
// return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED.
__C __export infiniStatus_t infinirtStreamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode);
// Ends capture on `stream` and stores the captured graph in `*graph_ptr`.
__C __export infiniStatus_t infinirtStreamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr);
// Destroys a graph obtained from infinirtStreamEndCapture.
__C __export infiniStatus_t infinirtGraphDestroy(infinirtGraph_t graph);
// Instantiates `graph` into `*graph_exec_ptr`. `node_ptr`, `log_buffer` and
// `buffer_size` are forwarded unchanged to the backend (cudaGraphInstantiate
// on CUDA).
__C __export infiniStatus_t infinirtGraphInstantiate(
infinirtGraphExec_t *graph_exec_ptr,
infinirtGraph_t graph,
infinirtGraphNode_t *node_ptr,
char *log_buffer,
size_t buffer_size);
// Destroys an instantiated graph.
__C __export infiniStatus_t infinirtGraphExecDestroy(infinirtGraphExec_t graph_exec);
// Launches an instantiated graph on `stream`.
// NOTE(review): "Luanch" is presumably a typo for "Launch"; renaming it would
// change the exported symbol, so it is left as-is here.
__C __export infiniStatus_t infinirtGraphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream);
#endif // __INFINIRT_API_H__
#include "graph_manager.hpp"
#include "../utils.hpp"
#include "infinicore/context/context.hpp"
#include <infinirt.h>
namespace infinicore::graph {
......@@ -33,9 +35,40 @@ GraphOperator::~GraphOperator() {
* Graph
* ========================= */
// Owns the backend graph handles that belong to one Graph.
//
// All handles start out null. Graph::run() and the destructor below both test
// `exec` / `graph` against null to decide whether a device graph exists, so an
// uninitialised handle would be read (and possibly destroyed or launched)
// whenever capture is unsupported or fails part-way through.
struct Graph::DeviceGraph {
    infinirtGraph_t graph = nullptr;    // captured graph, or null
    infinirtGraphExec_t exec = nullptr; // instantiated graph, or null
    infinirtGraphNode_t node = nullptr; // node slot passed to infinirtGraphInstantiate
    std::vector<char> log_buffer;       // zero-filled scratch buffer for instantiate diagnostics

    DeviceGraph() : log_buffer(4 * 1024) {}

    // The struct owns raw handles; copying it would destroy them twice.
    DeviceGraph(const DeviceGraph &) = delete;
    DeviceGraph &operator=(const DeviceGraph &) = delete;

    // Releases the executable graph first, then the captured graph.
    ~DeviceGraph() {
        if (exec) {
            infinirtGraphExecDestroy(exec);
        }
        if (graph) {
            infinirtGraphDestroy(graph);
        }
    }

    // Launches the instantiated graph on the current context stream.
    // Callers are expected to check `exec` for null first (Graph::run does).
    void launch() {
        INFINICORE_CHECK_ERROR(infinirtGraphLuanch(exec, context::getStream()));
    }
};
// Nothing to set up: op_list_ starts empty and device_graph_ starts null.
Graph::Graph() = default;
// Executes the graph. When a device graph has been instantiated (non-null
// `exec`), it is launched; otherwise every operator in op_list_ is run in order.
// NOTE(review): the first `for` / `op->run();` pair below is never closed and
// duplicates the loop in the else-branch. It looks like the pre-change lines of
// a diff left in place -- confirm against the real source.
void Graph::run() const {
for (auto &op : op_list_) {
op->run();
if (device_graph_ != nullptr && device_graph_.get()->exec != nullptr) {
device_graph_.get()->launch();
} else {
// Eager fallback: no usable device graph.
for (auto &op : op_list_) {
op->run();
}
}
}
......@@ -43,6 +76,50 @@ void Graph::add_operator(std::shared_ptr<GraphOperator> op) {
op_list_.push_back(op);
}
void Graph::instantiate() {
// Reset device graph
device_graph_ = std::make_unique<DeviceGraph>();
// warmup
for (size_t iter = 0; iter < 5; ++iter) {
this->run();
}
infinicore::context::syncStream();
if (infinirtStreamBeginCapture(
context::getStream(),
INFINIRT_STREAM_CAPTURE_MODE_GLOBAL)
!= INFINI_STATUS_SUCCESS) {
return;
}
// Run and record
this->run();
if (infinirtStreamEndCapture(
context::getStream(),
&device_graph_.get()->graph)
!= INFINI_STATUS_SUCCESS) {
return;
}
if (infinirtGraphInstantiate(
&device_graph_.get()->exec,
device_graph_.get()->graph,
&device_graph_.get()->node,
device_graph_.get()->log_buffer.data(),
device_graph_.get()->log_buffer.size())
!= INFINI_STATUS_SUCCESS) {
static bool warned_once = false;
if (!warned_once) {
warned_once = true;
spdlog::warn("Fail to instantiate device graph: {}", std::string(device_graph_.get()->log_buffer.data()));
}
}
}
// Defined in this translation unit, after DeviceGraph has been fully defined,
// so that the std::unique_ptr<DeviceGraph> member can be destroyed.
Graph::~Graph() = default;
/* =========================
* GraphManager
* ========================= */
......@@ -52,19 +129,26 @@ bool GraphManager::is_recording() const {
}
// Begins a new recording session with a fresh, empty Graph. If a session is
// already active, a warning is logged and the previous graph is replaced.
void GraphManager::start_recording() {
    const bool already_recording = is_recording();
    if (already_recording) {
        spdlog::warn("Graph is already recording. Previous recording will be dropped.");
    }

    recording_ = true;
    graph_ = std::make_shared<Graph>();
}
// Adds `op` to the graph that is currently being recorded. Recording must be
// active; this is asserted before the operator is forwarded to the graph.
// NOTE(review): two assertions appear below (one on `recording_`, one on
// is_recording()). They look like the old and new lines of a diff shown
// together -- confirm which one the real source keeps.
void GraphManager::add_operator(std::shared_ptr<GraphOperator> op) {
INFINICORE_ASSERT(recording_);
INFINICORE_ASSERT(is_recording());
graph_->add_operator(op);
}
// Ends the current recording session and hands the recorded graph to the
// caller. The graph is instantiated before it is returned, and the manager
// gives up its own reference. Returns nullptr (after logging a warning) when
// no recording is in progress.
std::shared_ptr<Graph> GraphManager::stop_recording() {
    if (is_recording()) {
        recording_ = false;
        graph_->instantiate();
        std::shared_ptr<Graph> finished = std::exchange(graph_, nullptr);
        return finished;
    }

    spdlog::warn("Graph is not recording. Please start recording first.");
    return nullptr;
}
......
......@@ -23,6 +23,10 @@ Handle::Internal::Internal(int device_id) {
_grid_size[0] = prop.maxGridSize[0];
_grid_size[1] = prop.maxGridSize[1];
_grid_size[2] = prop.maxGridSize[2];
this->useCublas(nullptr, [](cublasHandle_t handle) { return INFINI_STATUS_SUCCESS; });
#ifdef ENABLE_CUDNN_API
this->useCudnn(nullptr, [](cudnnHandle_t handle) { return INFINI_STATUS_SUCCESS; });
#endif
}
infiniStatus_t Handle::Internal::useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const {
......
......@@ -150,5 +150,35 @@ infiniStatus_t mallocAsync(void **p_ptr, size_t size, infinirtStream_t stream) {
// Frees `ptr` by delegating to freeDevice and returns its status.
// `stream` is not used by this implementation.
infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
return freeDevice(ptr);
}
// Graph capture entry points for the ascend backend. None of them is
// implemented: each one ignores its arguments and returns
// INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED.

// Stream capture is not supported on this backend.
infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Stream capture is not supported on this backend; `*graph_ptr` is not written.
infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Graphs are not supported on this backend.
infiniStatus_t graphDestroy(infinirtGraph_t graph) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Graphs are not supported on this backend; no output argument is written.
infiniStatus_t graphInstantiate(
infinirtGraphExec_t *graph_exec_ptr,
infinirtGraph_t graph,
infinirtGraphNode_t *node_ptr,
char *log_buffer,
size_t buffer_size) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Graphs are not supported on this backend.
infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Graphs are not supported on this backend.
infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
} // namespace infinirt::ascend
#undef CHECK_ACLRT
......@@ -142,4 +142,34 @@ infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
CHECK_BANGRT(cnrtFree(ptr));
return INFINI_STATUS_SUCCESS;
}
// Graph capture entry points for the bang backend. None of them is
// implemented: each one ignores its arguments and returns
// INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED.

// Stream capture is not supported on this backend.
infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Stream capture is not supported on this backend; `*graph_ptr` is not written.
infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Graphs are not supported on this backend.
infiniStatus_t graphDestroy(infinirtGraph_t graph) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Graphs are not supported on this backend; no output argument is written.
infiniStatus_t graphInstantiate(
infinirtGraphExec_t *graph_exec_ptr,
infinirtGraph_t graph,
infinirtGraphNode_t *node_ptr,
char *log_buffer,
size_t buffer_size) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Graphs are not supported on this backend.
infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Graphs are not supported on this backend.
infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
} // namespace infinirt::bang
......@@ -116,4 +116,33 @@ infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
return freeDevice(ptr);
}
// Graph capture entry points for the cpu backend. None of them is
// implemented: each one ignores its arguments and returns
// INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED.

// Stream capture is not supported on this backend.
infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Stream capture is not supported on this backend; `*graph_ptr` is not written.
infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Graphs are not supported on this backend.
infiniStatus_t graphDestroy(infinirtGraph_t graph) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Graphs are not supported on this backend; no output argument is written.
infiniStatus_t graphInstantiate(
infinirtGraphExec_t *graph_exec_ptr,
infinirtGraph_t graph,
infinirtGraphNode_t *node_ptr,
char *log_buffer,
size_t buffer_size) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Graphs are not supported on this backend.
infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Graphs are not supported on this backend.
infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
} // namespace infinirt::cpu
......@@ -176,4 +176,53 @@ infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
RUN_CUDART(cudaFreeAsync(ptr, (cudaStream_t)stream));
return INFINI_STATUS_SUCCESS;
}
// Starts capturing work submitted to `stream` via cudaStreamBeginCapture.
// `mode` is translated to the CUDA capture mode with the matching name; any
// other value yields INFINI_STATUS_BAD_PARAM without touching the stream.
infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) {
    cudaStreamCaptureMode cuda_mode;
    switch (mode) {
    case INFINIRT_STREAM_CAPTURE_MODE_GLOBAL:
        cuda_mode = cudaStreamCaptureModeGlobal;
        break;
    case INFINIRT_STREAM_CAPTURE_MODE_THREAD_LOCAL:
        cuda_mode = cudaStreamCaptureModeThreadLocal;
        break;
    case INFINIRT_STREAM_CAPTURE_MODE_RELAXED:
        cuda_mode = cudaStreamCaptureModeRelaxed;
        break;
    default:
        return INFINI_STATUS_BAD_PARAM;
    }

    CHECK_CUDART(cudaStreamBeginCapture((cudaStream_t)stream, cuda_mode));
    return INFINI_STATUS_SUCCESS;
}
// Ends capture on `stream` via cudaStreamEndCapture and hands the resulting
// graph back through `graph_ptr`. The out-parameter is written only after the
// CUDA call has passed CHECK_CUDART.
infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) {
    cudaGraph_t captured_graph;
    CHECK_CUDART(cudaStreamEndCapture((cudaStream_t)stream, &captured_graph));

    *graph_ptr = captured_graph;
    return INFINI_STATUS_SUCCESS;
}
// Thin CUDA wrappers for graph lifetime and launch. Each one casts the opaque
// infinirt handles to the corresponding CUDA types and forwards the call.
// NOTE(review): RUN_CUDART and CHECK_CUDART are defined elsewhere; presumably
// they differ in how a failing CUDA call is reported -- confirm before relying
// on the returned status.

// Destroys a captured graph with cudaGraphDestroy.
infiniStatus_t graphDestroy(infinirtGraph_t graph) {
RUN_CUDART(cudaGraphDestroy((cudaGraph_t)graph));
return INFINI_STATUS_SUCCESS;
}
// Instantiates `graph` with cudaGraphInstantiate, passing the node slot and
// log buffer through unchanged.
infiniStatus_t graphInstantiate(
infinirtGraphExec_t *graph_exec_ptr,
infinirtGraph_t graph,
infinirtGraphNode_t *node_ptr,
char *log_buffer,
size_t buffer_size) {
CHECK_CUDART(cudaGraphInstantiate((cudaGraphExec_t *)graph_exec_ptr, (cudaGraph_t)graph, (cudaGraphNode_t *)node_ptr, log_buffer, buffer_size));
return INFINI_STATUS_SUCCESS;
}
// Destroys an instantiated graph with cudaGraphExecDestroy.
infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) {
RUN_CUDART(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec));
return INFINI_STATUS_SUCCESS;
}
// Launches an instantiated graph on `stream` with cudaGraphLaunch.
infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) {
CHECK_CUDART(cudaGraphLaunch((cudaGraphExec_t)graph_exec, (cudaStream_t)stream));
return INFINI_STATUS_SUCCESS;
}
}
......@@ -192,3 +192,32 @@ __C infiniStatus_t infinirtMallocAsync(void **p_ptr, size_t size, infinirtStream
// Public C entry points. Each one hands its arguments to the backend function
// of the same purpose through INFINIRT_CALL_DEVICE_API.
// NOTE(review): the macro is defined elsewhere; presumably it selects the
// backend for the current device and returns that backend's status -- confirm.

// Forwards to the backend's freeAsync.
__C infiniStatus_t infinirtFreeAsync(void *ptr, infinirtStream_t stream) {
INFINIRT_CALL_DEVICE_API(freeAsync, (ptr, stream));
}
// Forwards to the backend's streamBeginCapture.
__C infiniStatus_t infinirtStreamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) {
INFINIRT_CALL_DEVICE_API(streamBeginCapture, (stream, mode));
}
// Forwards to the backend's streamEndCapture.
__C infiniStatus_t infinirtStreamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) {
INFINIRT_CALL_DEVICE_API(streamEndCapture, (stream, graph_ptr));
}
// Forwards to the backend's graphDestroy.
__C infiniStatus_t infinirtGraphDestroy(infinirtGraph_t graph) {
INFINIRT_CALL_DEVICE_API(graphDestroy, (graph));
}
// Forwards to the backend's graphInstantiate.
__C infiniStatus_t infinirtGraphInstantiate(
infinirtGraphExec_t *graph_exec_ptr,
infinirtGraph_t graph,
infinirtGraphNode_t *node_ptr,
char *log_buffer,
size_t buffer_size) {
INFINIRT_CALL_DEVICE_API(graphInstantiate, (graph_exec_ptr, graph, node_ptr, log_buffer, buffer_size));
}
// Forwards to the backend's graphExecDestroy.
__C infiniStatus_t infinirtGraphExecDestroy(infinirtGraphExec_t graph_exec) {
INFINIRT_CALL_DEVICE_API(graphExecDestroy, (graph_exec));
}
// Forwards to the backend's graphLuanch (spelling matches the declared API).
__C infiniStatus_t infinirtGraphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) {
INFINIRT_CALL_DEVICE_API(graphLuanch, (graph_exec, stream));
}
......@@ -30,7 +30,19 @@
INLINE infiniStatus_t memcpyAsync(void *dst, const void *src, size_t size, infinirtMemcpyKind_t kind, infinirtStream_t stream) IMPL; \
\
INLINE infiniStatus_t mallocAsync(void **p_ptr, size_t size, infinirtStream_t stream) IMPL; \
INLINE infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) IMPL;
INLINE infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) IMPL; \
\
INLINE infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) IMPL; \
INLINE infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) IMPL; \
INLINE infiniStatus_t graphDestroy(infinirtGraph_t graph) IMPL; \
INLINE infiniStatus_t graphInstantiate( \
infinirtGraphExec_t *graph_exec_ptr, \
infinirtGraph_t graph, \
infinirtGraphNode_t *node_ptr, \
char *log_buffer, \
size_t buffer_size) IMPL; \
INLINE infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) IMPL; \
INLINE infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) IMPL;
#define INFINIRT_DEVICE_API_IMPL INFINIRT_DEVICE_API(, , )
#define INFINIRT_DEVICE_API_NOOP INFINIRT_DEVICE_API( \
......
......@@ -153,4 +153,33 @@ infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
return INFINI_STATUS_SUCCESS;
}
// Graph capture entry points for the kunlun backend. None of them is
// implemented: each one ignores its arguments and returns
// INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED.

// Stream capture is not supported on this backend.
infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Stream capture is not supported on this backend; `*graph_ptr` is not written.
infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Graphs are not supported on this backend.
infiniStatus_t graphDestroy(infinirtGraph_t graph) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Graphs are not supported on this backend; no output argument is written.
infiniStatus_t graphInstantiate(
infinirtGraphExec_t *graph_exec_ptr,
infinirtGraph_t graph,
infinirtGraphNode_t *node_ptr,
char *log_buffer,
size_t buffer_size) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Graphs are not supported on this backend.
infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Graphs are not supported on this backend.
infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
} // namespace infinirt::kunlun
......@@ -152,4 +152,34 @@ infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
CHECK_MACART(hcFreeAsync(ptr, (hcStream_t)stream));
return INFINI_STATUS_SUCCESS;
}
// Graph capture entry points for the metax backend. None of them is
// implemented: each one ignores its arguments and returns
// INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED.

// Stream capture is not supported on this backend.
infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Stream capture is not supported on this backend; `*graph_ptr` is not written.
infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Graphs are not supported on this backend.
infiniStatus_t graphDestroy(infinirtGraph_t graph) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Graphs are not supported on this backend; no output argument is written.
infiniStatus_t graphInstantiate(
infinirtGraphExec_t *graph_exec_ptr,
infinirtGraph_t graph,
infinirtGraphNode_t *node_ptr,
char *log_buffer,
size_t buffer_size) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Graphs are not supported on this backend.
infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Graphs are not supported on this backend.
infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
} // namespace infinirt::metax
......@@ -138,4 +138,34 @@ infiniStatus_t mallocAsync(void **p_ptr, size_t size, infinirtStream_t stream) {
// Frees `ptr` by delegating to freeDevice and returns its status.
// `stream` is not used by this implementation.
infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
return freeDevice(ptr);
}
// Graph capture entry points for the musa backend. None of them is
// implemented: each one ignores its arguments and returns
// INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED.

// Stream capture is not supported on this backend.
infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Stream capture is not supported on this backend; `*graph_ptr` is not written.
infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Graphs are not supported on this backend.
infiniStatus_t graphDestroy(infinirtGraph_t graph) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Graphs are not supported on this backend; no output argument is written.
infiniStatus_t graphInstantiate(
infinirtGraphExec_t *graph_exec_ptr,
infinirtGraph_t graph,
infinirtGraphNode_t *node_ptr,
char *log_buffer,
size_t buffer_size) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Graphs are not supported on this backend.
infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Graphs are not supported on this backend.
infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
} // namespace infinirt::musa
......@@ -49,6 +49,7 @@ target("infiniop-nvidia")
add_cuflags("--extended-lambda")
add_culdflags("-Xcompiler=-fPIC")
add_cxxflags("-fPIC")
add_cflags("-fPIC")
add_cuflags("--expt-relaxed-constexpr")
if CUDNN_ROOT ~= nil then
add_linkdirs(CUDNN_ROOT .. "/lib")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment