jerrrrry / infinicore · Commits · 01a4a0c8

Commit 01a4a0c8 (Unverified)
Authored Jan 10, 2026 by Haojie Wang; committed by GitHub on Jan 10, 2026

Merge pull request #882 from InfiniTensor/issue/810

issue/810 static compute graph infra

Parents: 3883f32f, 39f9c349

Changes: 29. Showing 20 changed files with 415 additions and 62 deletions (+415 −62); the remaining changed files are listed on a second page.
Changed files shown on this page:

include/infinicore/context/context.hpp                          +8   -0
include/infinicore/graph/graph.hpp                              +92  -0
include/infinicore/ops/gemm.hpp                                 +2   -6
include/infinicore/tensor.hpp                                   +3   -1
python/infinicore/__init__.py                                   +6   -0
python/infinicore/context.py                                    +22  -0
python/infinicore/graph.py                                      +18  -0
src/infinicore/context/allocators/device_pinned_allocator.cc    +6   -0
src/infinicore/context/allocators/host_allocator.cc             +6   -0
src/infinicore/context/allocators/pinnable_block_allocator.cc   +6   -1
src/infinicore/context/allocators/pinnable_block_allocator.hpp  +1   -3
src/infinicore/context/allocators/stream_ordered_allocator.cc   +6   -0
src/infinicore/context/context_impl.cc                          +19  -0
src/infinicore/context/runtime/runtime.cc                       +21  -2
src/infinicore/context/runtime/runtime.hpp                      +12  -2
src/infinicore/graph/graph.cc                                   +67  -0
src/infinicore/graph/graph_manager.hpp                          +25  -0
src/infinicore/ops/gemm/gemm.cc                                 +6   -7
src/infinicore/ops/gemm/gemm_infiniop.cc                        +39  -40
src/infinicore/ops/infiniop_impl.hpp                            +50  -0
include/infinicore/context/context.hpp

@@ -3,6 +3,8 @@
 #include "../device.hpp"
 #include "../memory.hpp"
+#include "../graph/graph.hpp"

 #include <infiniop.h>
 #include <infinirt.h>
...
@@ -40,6 +42,12 @@ void destroyEvent(infinirtEvent_t event);
 float elapsedTime(infinirtEvent_t start, infinirtEvent_t end);
 void streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event);

+// Graph recording APIs
+bool isGraphRecording();
+void startGraphRecording();
+void addGraphOperator(std::shared_ptr<graph::GraphOperator> op);
+std::shared_ptr<graph::Graph> stopGraphRecording();
+
 } // namespace context

 } // namespace infinicore
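
These four context hooks, together with the Python bindings further down, form the public surface of the recording feature. A minimal C++ sketch of the intended flow, assuming already-constructed tensors c, a and b and using the graph-aware op::gemm_ from this commit (illustrative only, not part of the diff):

#include "infinicore/context/context.hpp"
#include "infinicore/ops/gemm.hpp"

using namespace infinicore;

void record_and_replay(Tensor c, Tensor a, Tensor b) {
    // While recording, graph-aware ops plan themselves and are appended to
    // the current graph instead of executing immediately.
    context::startGraphRecording();
    op::gemm_(c, a, b, /*alpha=*/1.0f, /*beta=*/0.0f);
    std::shared_ptr<graph::Graph> g = context::stopGraphRecording();

    // Replaying the graph runs each recorded operator in order.
    g->run();
    g->run();
}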
include/infinicore/graph/graph.hpp (new file)

#pragma once

#include <memory>
#include <vector>

#include "../tensor.hpp"

namespace infinicore::graph {

// Forward declarations
class GraphManager;

class GraphTensor : public Tensor {
public:
    GraphTensor(const Tensor &);
};

class GraphOperator {
public:
    void run() const;
    ~GraphOperator();

protected:
    using run_schema = void (*)(void *);
    using cleanup_schema = void (*)(void **);

    void *planned_meta_;
    run_schema runner_;
    cleanup_schema deleter_;
};

class Graph {
public:
    Graph() = default;
    ~Graph() = default;

    void run() const;

protected:
    void add_operator(std::shared_ptr<GraphOperator> op);

    std::vector<std::shared_ptr<GraphOperator>> op_list_;

    friend class GraphManager;
};

} // namespace infinicore::graph
#define INFINICORE_GRAPH_OP_CLASS(__OP_NAME__, ...) \
class __OP_NAME__ : public graph::GraphOperator { \
public: \
using schema = void (*)(__VA_ARGS__); \
using plan_schema = void *(*)(__VA_ARGS__); \
static common::OpDispatcher<plan_schema> &plan_dispatcher(); \
static common::OpDispatcher<run_schema> &run_dispatcher(); \
static common::OpDispatcher<cleanup_schema> &cleanup_dispatcher(); \
__OP_NAME__(__VA_ARGS__); \
static void execute(__VA_ARGS__); \
};
#define INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(__OP_NAME__) \
common::OpDispatcher<__OP_NAME__::plan_schema> &__OP_NAME__::plan_dispatcher() { \
static common::OpDispatcher<__OP_NAME__::plan_schema> dispatcher_; \
return dispatcher_; \
} \
common::OpDispatcher<__OP_NAME__::run_schema> &__OP_NAME__::run_dispatcher() { \
static common::OpDispatcher<__OP_NAME__::run_schema> dispatcher_; \
return dispatcher_; \
} \
common::OpDispatcher<__OP_NAME__::cleanup_schema> &__OP_NAME__::cleanup_dispatcher() { \
static common::OpDispatcher<__OP_NAME__::cleanup_schema> dispatcher_; \
return dispatcher_; \
}
#define INFINICORE_GRAPH_OP_DISPATCH(__DEVICE_TYPE__, ...) \
planned_meta_ = plan_dispatcher().lookup(__DEVICE_TYPE__)(__VA_ARGS__); \
runner_ = run_dispatcher().lookup(__DEVICE_TYPE__); \
deleter_ = cleanup_dispatcher().lookup(__DEVICE_TYPE__);
#define INFINICORE_GRAPH_OP_RECORD_OR_RUN(__OP_NAME__, ...) \
auto op = std::make_shared<__OP_NAME__>(__VA_ARGS__); \
if (context::isGraphRecording()) { \
context::addGraphOperator(op); \
} else { \
op->run(); \
}
#define INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(__OP_NAME__, __PLAN_F__, __RUN_F__, __CLEANUP_F__) \
static bool registered = []() { \
__OP_NAME__::plan_dispatcher().registerAll(__PLAN_F__, false); \
__OP_NAME__::run_dispatcher().registerAll(__RUN_F__, false); \
__OP_NAME__::cleanup_dispatcher().registerAll(__CLEANUP_F__, false); \
return true; \
}();
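
These macros are the adoption pattern that turns an ordinary op into a graph-aware one; the Gemm changes later in this commit are the real example. As a rough sketch, a hypothetical elementwise op (the name Neg and the helper neg_ are illustrative, not part of the commit) would use them like this:

// Header (sketch): declares the op class with plan/run/cleanup dispatchers.
namespace infinicore::op {
INFINICORE_GRAPH_OP_CLASS(Neg, Tensor, Tensor);
void neg_(Tensor out, Tensor in);
} // namespace infinicore::op

// Source (sketch): one dispatcher implementation per op, plus the
// record-or-run entry point that callers go through.
namespace infinicore::op {
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Neg);

Neg::Neg(Tensor out, Tensor in) {
    // Plans the op for the output's device and captures runner_/deleter_.
    INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), out, in);
}

void Neg::execute(Tensor out, Tensor in) {
    // Appended to the current graph while recording, run immediately otherwise.
    INFINICORE_GRAPH_OP_RECORD_OR_RUN(Neg, out, in);
}

void neg_(Tensor out, Tensor in) { Neg::execute(out, in); }
} // namespace infinicore::op

A backend still has to register concrete plan/run/cleanup functions, for example via INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE as gemm_infiniop.cc does below.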
include/infinicore/ops/gemm.hpp

 #pragma once

 #include "../device.hpp"
+#include "../graph/graph.hpp"
 #include "common/op.hpp"

 namespace infinicore::op {

-class Gemm {
-public:
-    using schema = void (*)(Tensor, Tensor, Tensor, float, float);
-    static void execute(Tensor c, Tensor a, Tensor b, float alpha, float beta);
-    static common::OpDispatcher<schema> &dispatcher();
-};
+INFINICORE_GRAPH_OP_CLASS(Gemm, Tensor, Tensor, Tensor, float, float);

 Tensor gemm(Tensor a, Tensor b, float alpha = 1.0f, float beta = 0.0f);

 void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta);
...
include/infinicore/tensor.hpp

@@ -133,6 +133,8 @@ public:
     void debug() const;

+    Tensor to_blob() const;
+
     ///
     /// Data Transfer APIs
     ///
...
@@ -294,7 +296,7 @@ protected:
     friend class Tensor;

-private:
+protected:
     TensorMetaData meta_;
     TensorData data_;
 };
...
python/infinicore/__init__.py

@@ -8,7 +8,10 @@ from infinicore.context import (
     get_device,
     get_device_count,
     get_stream,
+    is_graph_recording,
     set_device,
+    start_graph_recording,
+    stop_graph_recording,
     sync_device,
     sync_stream,
 )
...
@@ -81,6 +84,9 @@ __all__ = [
     "set_device",
     "sync_device",
     "sync_stream",
+    "is_graph_recording",
+    "start_graph_recording",
+    "stop_graph_recording",
     # Data Types.
     "bfloat16",
     "bool",
...
python/infinicore/context.py

 import infinicore.device
+from infinicore.graph import Graph
 from infinicore.lib import _infinicore
...
@@ -49,3 +50,24 @@ def get_stream():
         stream: The current stream object
     """
     return _infinicore.get_stream()
+
+
+def is_graph_recording():
+    """Check if the current graph is recording.
+
+    Returns:
+        bool: True if the current graph is recording, False otherwise
+    """
+    return _infinicore.is_graph_recording()
+
+
+def start_graph_recording(device=None):
+    """Start recording the current graph."""
+    if device is not None:
+        set_device(device)
+    _infinicore.start_graph_recording()
+
+
+def stop_graph_recording():
+    """Stop recording the current graph."""
+    return Graph(_infinicore.stop_graph_recording())
python/infinicore/graph.py (new file)

from infinicore.lib import _infinicore


class Graph:
    """
    Python wrapper around a InfiniCore Graph instance.
    """

    def __init__(self, graph: _infinicore.Graph):
        if not isinstance(graph, _infinicore.Graph):
            raise TypeError("Expected _infinicore.Graph")
        self._graph = graph

    def run(self):
        return self._graph.run()

    def __repr__(self):
        return f"<Graph wrapper of {self._graph!r}>"
src/infinicore/context/allocators/device_pinned_allocator.cc

@@ -12,12 +12,18 @@ DevicePinnedHostAllocator::~DevicePinnedHostAllocator() {
 }

 std::byte *DevicePinnedHostAllocator::allocate(size_t size) {
+    if (size == 0) {
+        return nullptr;
+    }
     void *ptr;
     INFINICORE_CHECK_ERROR(infinirtMallocHost(&ptr, size));
     return (std::byte *)ptr;
 }

 void DevicePinnedHostAllocator::deallocate(std::byte *ptr) {
+    if (ptr == nullptr) {
+        return;
+    }
     if (owner_ == context::getDevice()) {
         INFINICORE_CHECK_ERROR(infinirtFreeHost(ptr));
         gc();
...
src/infinicore/context/allocators/host_allocator.cc

@@ -4,10 +4,16 @@
 namespace infinicore {

 std::byte *HostAllocator::allocate(size_t size) {
+    if (size == 0) {
+        return nullptr;
+    }
     return (std::byte *)std::malloc(size);
 }

 void HostAllocator::deallocate(std::byte *ptr) {
+    if (ptr == nullptr) {
+        return;
+    }
     std::free(ptr);
 }
...
src/infinicore/context/allocators/pinnable_block_allocator.cc

 #include "pinnable_block_allocator.hpp"

+#include "../context_impl.hpp"
 #include "../../utils.hpp"

 #include <algorithm>
...
@@ -35,6 +37,9 @@ PinnableBlockAllocator::PinnableBlockAllocator(Device device)
 // ------------------- allocate -------------------
 std::byte *PinnableBlockAllocator::allocate(size_t size) {
+    if (size == 0) {
+        return nullptr;
+    }
     std::lock_guard<std::mutex> lock(mutex_);

     // Align size to 256 bytes for GPU
...
@@ -92,7 +97,7 @@ std::byte *PinnableBlockAllocator::allocate(size_t size) {
 // ------------------- deallocate -------------------
 void PinnableBlockAllocator::deallocate(std::byte *ptr) {
-    if (!ptr) {
+    if (ptr == nullptr) {
         return;
     }
...
src/infinicore/context/allocators/pinnable_block_allocator.hpp

@@ -2,8 +2,6 @@
 #include "memory_allocator.hpp"
-
-#include "../context_impl.hpp"

 #include <mutex>
 #include <unordered_map>
 #include <vector>
...
@@ -25,7 +23,7 @@ class PinnableBlockAllocator : public MemoryAllocator {
     };

 public:
-    explicit PinnableBlockAllocator(Device device);
+    PinnableBlockAllocator(Device device);
     ~PinnableBlockAllocator();

     std::byte *allocate(size_t size) override;
...
src/infinicore/context/allocators/stream_ordered_allocator.cc

@@ -8,12 +8,18 @@ namespace infinicore {

 StreamOrderedAllocator::StreamOrderedAllocator(Device device) : MemoryAllocator(), device_(device) {}

 std::byte *StreamOrderedAllocator::allocate(size_t size) {
+    if (size == 0) {
+        return nullptr;
+    }
     void *ptr = nullptr;
     INFINICORE_CHECK_ERROR(infinirtMallocAsync(&ptr, size, context::getStream()));
     return (std::byte *)ptr;
 }

 void StreamOrderedAllocator::deallocate(std::byte *ptr) {
+    if (ptr == nullptr) {
+        return;
+    }
     INFINICORE_CHECK_ERROR(infinirtFreeAsync(ptr, context::getStream()));
 }

 } // namespace infinicore
src/infinicore/context/context_impl.cc

@@ -39,6 +39,10 @@ void ContextImpl::setDevice(Device device) {
         return;
     }

+    if (getCurrentRuntime()->isGraphRecording()) {
+        spdlog::warn("Switching device runtime during graph recording may break the graph!");
+    }
+
     if (runtime_table_[int(device.getType())][device.getIndex()] == nullptr) {
         // Lazy initialization of runtime if never set before.
         runtime_table_[int(device.getType())][device.getIndex()] = std::unique_ptr<Runtime>(new Runtime(device));
...
@@ -178,6 +182,21 @@ void streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) {
     ContextImpl::singleton().getCurrentRuntime()->streamWaitEvent(stream, event);
 }

+bool isGraphRecording() {
+    return ContextImpl::singleton().getCurrentRuntime()->isGraphRecording();
+}
+
+void startGraphRecording() {
+    ContextImpl::singleton().getCurrentRuntime()->startGraphRecording();
+}
+
+void addGraphOperator(std::shared_ptr<graph::GraphOperator> op) {
+    ContextImpl::singleton().getCurrentRuntime()->addGraphOperator(op);
+}
+
+std::shared_ptr<graph::Graph> stopGraphRecording() {
+    return ContextImpl::singleton().getCurrentRuntime()->stopGraphRecording();
+}
+
 } // namespace context

 } // namespace infinicore
src/infinicore/context/runtime/runtime.cc

@@ -8,12 +8,12 @@
 #include "../allocators/stream_ordered_allocator.hpp"

 namespace infinicore {

-Runtime::Runtime(Device device) : device_(device) {
+Runtime::Runtime(Device device) : device_(device), graph_manager_(std::make_unique<graph::GraphManager>()) {
     activate();
     INFINICORE_CHECK_ERROR(infinirtStreamCreate(&stream_));
     INFINICORE_CHECK_ERROR(infiniopCreateHandle(&infiniop_handle_));
     if (device_.getType() == Device::Type::CPU) {
-        device_memory_allocator_ = std::make_unique<HostAllocator>();
+        device_memory_allocator_ = std::make_unique<PinnableBlockAllocator>(device);
     } else {
         device_memory_allocator_ = std::make_unique<PinnableBlockAllocator>(device);
         pinned_host_memory_allocator_ = std::make_unique<DevicePinnedHostAllocator>(device);
...
@@ -145,6 +145,25 @@ void Runtime::streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) {
     INFINICORE_CHECK_ERROR(infinirtStreamWaitEvent(stream, event));
 }

+bool Runtime::isGraphRecording() const {
+    return graph_manager_->is_recording();
+}
+
+void Runtime::startGraphRecording() {
+    device_memory_allocator_->set_pin_mode(true);
+    return graph_manager_->start_recording();
+}
+
+void Runtime::addGraphOperator(std::shared_ptr<graph::GraphOperator> op) {
+    return graph_manager_->add_operator(op);
+}
+
+std::shared_ptr<graph::Graph> Runtime::stopGraphRecording() {
+    auto graph = graph_manager_->stop_recording();
+    device_memory_allocator_->set_pin_mode(false);
+    return graph;
+}
+
 std::string Runtime::toString() const {
     return fmt::format("Runtime({})", device_.toString());
 }
...
src/infinicore/context/runtime/runtime.hpp

 #pragma once

-#include "../allocators/memory_allocator.hpp"
+#include "../allocators/pinnable_block_allocator.hpp"
 #include "infinicore/context/context.hpp"
+#include "../../graph/graph_manager.hpp"

 #include <infiniop.h>
 #include <infinirt.h>
...
@@ -13,8 +16,9 @@ private:
     Device device_;
     infinirtStream_t stream_;
     infiniopHandle_t infiniop_handle_;
-    std::unique_ptr<MemoryAllocator> device_memory_allocator_;
+    std::unique_ptr<PinnableBlockAllocator> device_memory_allocator_;
     std::unique_ptr<MemoryAllocator> pinned_host_memory_allocator_;
+    std::unique_ptr<graph::GraphManager> graph_manager_;

 protected:
     Runtime(Device device);
...
@@ -48,6 +52,12 @@ public:
     float elapsedTime(infinirtEvent_t start, infinirtEvent_t end);
     void streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event);

+    // Graph
+    bool isGraphRecording() const;
+    void startGraphRecording();
+    void addGraphOperator(std::shared_ptr<graph::GraphOperator> op);
+    std::shared_ptr<graph::Graph> stopGraphRecording();
+
     std::string toString() const;

     friend class ContextImpl;
...
src/infinicore/graph/graph.cc (new file)

#include "graph_manager.hpp"

#include "../utils.hpp"

namespace infinicore::graph {

/* =========================
 * GraphTensor
 * ========================= */

GraphTensor::GraphTensor(const Tensor &tensor) : Tensor(tensor->to_blob()) {
}

/* =========================
 * GraphOperator
 * ========================= */

void GraphOperator::run() const {
    runner_(planned_meta_);
}

GraphOperator::~GraphOperator() {
    if (deleter_) {
        deleter_(&planned_meta_);
    }
}

/* =========================
 * Graph
 * ========================= */

void Graph::run() const {
    for (auto &op : op_list_) {
        op->run();
    }
}

void Graph::add_operator(std::shared_ptr<GraphOperator> op) {
    op_list_.push_back(op);
}

/* =========================
 * GraphManager
 * ========================= */

bool GraphManager::is_recording() const {
    return recording_;
}

void GraphManager::start_recording() {
    recording_ = true;
    graph_ = std::make_shared<Graph>();
}

void GraphManager::add_operator(std::shared_ptr<GraphOperator> op) {
    INFINICORE_ASSERT(recording_);
    graph_->add_operator(op);
}

std::shared_ptr<Graph> GraphManager::stop_recording() {
    recording_ = false;
    return std::exchange(graph_, nullptr);
}

} // namespace infinicore::graph
src/infinicore/graph/graph_manager.hpp (new file)

#pragma once

#include "infinicore/graph/graph.hpp"

#include <memory>
#include <vector>

namespace infinicore::graph {

class GraphManager {
public:
    GraphManager() = default;
    ~GraphManager() = default;

    bool is_recording() const;
    void start_recording();
    void add_operator(std::shared_ptr<GraphOperator> op);
    std::shared_ptr<Graph> stop_recording();

private:
    std::shared_ptr<Graph> graph_;
    bool recording_ = false;
};

} // namespace infinicore::graph
src/infinicore/ops/gemm/gemm.cc

@@ -3,16 +3,15 @@
 #include "../../utils.hpp"

 namespace infinicore::op {

-common::OpDispatcher<Gemm::schema> &Gemm::dispatcher() {
-    static common::OpDispatcher<Gemm::schema> dispatcher_;
-    return dispatcher_;
-};
+INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Gemm);
+
+Gemm::Gemm(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
+    INFINICORE_GRAPH_OP_DISPATCH(c->device().getType(), c, a, b, alpha, beta);
+}

 void Gemm::execute(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
-    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
-    infinicore::context::setDevice(c->device());
-    dispatcher().lookup(c->device().getType())(c, a, b, alpha, beta);
+    INFINICORE_GRAPH_OP_RECORD_OR_RUN(Gemm, c, a, b, alpha, beta);
 }

 Tensor gemm(Tensor a, Tensor b, float alpha, float beta) {
...
src/infinicore/ops/gemm/gemm_infiniop.cc

-#include "../../utils.hpp"
-#include "infinicore/common/hash.hpp"
-#include "infinicore/ops/common/cache.hpp"
+#include "../infiniop_impl.hpp"
 #include "infinicore/ops/gemm.hpp"
-
-#include <infiniop.h>

 namespace infinicore::op::gemm_impl::infiniop {

-thread_local common::OpCache<size_t, infiniopGemmDescriptor_t> caches(
-    100, // capacity
-    [](infiniopGemmDescriptor_t &desc) {
-        if (desc != nullptr) {
-            INFINICORE_CHECK_ERROR(infiniopDestroyGemmDescriptor(desc));
-            desc = nullptr;
-        }
-    });
-
-void calculate(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
-    size_t seed = hash_combine(c, a, b);
-    auto device = context::getDevice();
-    auto &cache = caches.getCache(device);
-    auto desc_opt = cache.get(seed);
-    infiniopGemmDescriptor_t desc = nullptr;
-    if (!desc_opt) {
-        INFINICORE_CHECK_ERROR(infiniopCreateGemmDescriptor(
-            context::getInfiniopHandle(device),
-            &desc,
-            c->desc(), a->desc(), b->desc()));
-        cache.put(seed, desc);
-    } else {
-        desc = *desc_opt;
-    }
-    size_t workspace_size = 0;
-    INFINICORE_CHECK_ERROR(infiniopGetGemmWorkspaceSize(desc, &workspace_size));
-    std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
-    INFINICORE_CHECK_ERROR(infiniopGemm(
-        desc, workspace->data(), workspace_size,
-        c->data(), a->data(), b->data(),
-        alpha, beta,
-        context::getStream()));
-}
+INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Gemm, 100);
+
+struct PlannedMeta {
+    std::shared_ptr<Descriptor> descriptor;
+    graph::GraphTensor workspace, c, a, b;
+    float alpha, beta;
+};
+
+void *plan(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
+    size_t seed = hash_combine(c, b, a, alpha, beta);
+    INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
+        Descriptor, descriptor, Gemm,
+        seed, c->desc(), a->desc(), b->desc());
+    INFINIOP_WORKSPACE_TENSOR(workspace, Gemm, descriptor);
+    auto planned = new PlannedMeta{
+        descriptor,
+        graph::GraphTensor(workspace),
+        graph::GraphTensor(c),
+        graph::GraphTensor(a),
+        graph::GraphTensor(b),
+        alpha, beta};
+    return planned;
+}
+
+void run(void *planned_meta) {
+    auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
+    INFINICORE_CHECK_ERROR(infiniopGemm(
+        planned->descriptor->desc,
+        planned->workspace->data(), planned->workspace->numel(),
+        planned->c->data(), planned->a->data(), planned->b->data(),
+        planned->alpha, planned->beta,
+        context::getStream()));
+}
+
+void cleanup(void **planned_meta_ptr) {
+    delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
+    *planned_meta_ptr = nullptr;
+}

-static bool registered = []() {
-    Gemm::dispatcher().registerAll(&calculate, false);
-    return true;
-}();
+INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Gemm, &plan, &run, &cleanup);

 } // namespace infinicore::op::gemm_impl::infiniop
src/infinicore/ops/infiniop_impl.hpp (new file)

#pragma once
#include "../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include <infiniop.h>
#define INFINIOP_CACHABLE_DESCRIPTOR(__DESC_TYPE__, __OP_NAME__, __SIZE__) \
struct __DESC_TYPE__ { \
infiniop##__OP_NAME__##Descriptor_t desc; \
Descriptor(infiniop##__OP_NAME__##Descriptor_t desc) : desc(desc) {} \
~Descriptor() { \
if (desc != nullptr) { \
infiniopDestroy##__OP_NAME__##Descriptor(desc); \
desc = nullptr; \
} \
} \
}; \
\
thread_local common::OpCache<size_t, std::shared_ptr<__DESC_TYPE__>> \
caches( \
__SIZE__, \
[](std::shared_ptr<__DESC_TYPE__> &desc) { \
desc = nullptr; \
});
#define INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(__DESC_TYPE__, __DESC_NAME__, __INFINIOP_NAME__, __HASH_KEY__, ...) \
std::shared_ptr<__DESC_TYPE__> __DESC_NAME__; \
{ \
auto device__ = context::getDevice(); \
auto &cache__ = caches.getCache(device__); \
__DESC_NAME__ = cache__.get(__HASH_KEY__).value_or(nullptr); \
if (!__DESC_NAME__) { \
__DESC_NAME__ = std::make_shared<__DESC_TYPE__>(nullptr); \
INFINICORE_CHECK_ERROR(infiniopCreate##__INFINIOP_NAME__##Descriptor( \
context::getInfiniopHandle(device__), \
&__DESC_NAME__->desc, \
__VA_ARGS__)); \
cache__.put(__HASH_KEY__, __DESC_NAME__); \
} \
}
#define INFINIOP_WORKSPACE_TENSOR(__TENSOR_NAME__, __INFINIOP_NAME__, __DESC_NAME__) \
Tensor __TENSOR_NAME__; \
{ \
auto device__ = context::getDevice(); \
size_t workspace_size = 0; \
INFINICORE_CHECK_ERROR(infiniopGet##__INFINIOP_NAME__##WorkspaceSize(__DESC_NAME__->desc, &workspace_size)); \
__TENSOR_NAME__ = Tensor::empty({workspace_size}, DataType::U8, device__); \
}
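
For reference, the gemm_infiniop.cc rewrite above relies on these helpers; a mechanical expansion of INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Gemm, 100), reformatted here for readability, is roughly:

// Expansion sketch of INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Gemm, 100).
// An RAII wrapper that owns one infiniop Gemm descriptor...
struct Descriptor {
    infiniopGemmDescriptor_t desc;
    Descriptor(infiniopGemmDescriptor_t desc) : desc(desc) {}
    ~Descriptor() {
        if (desc != nullptr) {
            infiniopDestroyGemmDescriptor(desc);
            desc = nullptr;
        }
    }
};

// ...plus a thread-local cache of shared_ptr-wrapped descriptors with
// capacity 100; eviction simply drops the shared_ptr reference.
thread_local common::OpCache<size_t, std::shared_ptr<Descriptor>>
    caches(100, [](std::shared_ptr<Descriptor> &desc) { desc = nullptr; });

Holding the descriptor behind a shared_ptr means a PlannedMeta recorded into a graph keeps its descriptor alive even if the cache later evicts that entry, which appears to be the motivation for wrapping the raw infiniopGemmDescriptor_t.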