Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
006d530c
Commit
006d530c
authored
Jan 06, 2026
by
PanZezhong
Browse files
issue/810 static compute graph infra
parent
caa61e9e
Changes
23
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
380 additions
and
30 deletions
+380
-30
include/infinicore/context/context.hpp
include/infinicore/context/context.hpp
+8
-0
include/infinicore/graph/graph.hpp
include/infinicore/graph/graph.hpp
+45
-0
include/infinicore/ops/gemm.hpp
include/infinicore/ops/gemm.hpp
+10
-1
include/infinicore/tensor.hpp
include/infinicore/tensor.hpp
+3
-1
python/infinicore/__init__.py
python/infinicore/__init__.py
+6
-0
python/infinicore/context.py
python/infinicore/context.py
+22
-0
python/infinicore/graph.py
python/infinicore/graph.py
+18
-0
src/infinicore/context/allocators/pinnable_block_allocator.cc
...infinicore/context/allocators/pinnable_block_allocator.cc
+2
-0
src/infinicore/context/allocators/pinnable_block_allocator.hpp
...nfinicore/context/allocators/pinnable_block_allocator.hpp
+1
-3
src/infinicore/context/context_impl.cc
src/infinicore/context/context_impl.cc
+19
-0
src/infinicore/context/runtime/runtime.cc
src/infinicore/context/runtime/runtime.cc
+21
-2
src/infinicore/context/runtime/runtime.hpp
src/infinicore/context/runtime/runtime.hpp
+12
-2
src/infinicore/graph/graph.cc
src/infinicore/graph/graph.cc
+67
-0
src/infinicore/graph/graph_manager.hpp
src/infinicore/graph/graph_manager.hpp
+25
-0
src/infinicore/ops/gemm/gemm.cc
src/infinicore/ops/gemm/gemm.cc
+27
-3
src/infinicore/ops/gemm/gemm_infiniop.cc
src/infinicore/ops/gemm/gemm_infiniop.cc
+62
-18
src/infinicore/pybind11/context.hpp
src/infinicore/pybind11/context.hpp
+5
-0
src/infinicore/pybind11/graph.hpp
src/infinicore/pybind11/graph.hpp
+17
-0
src/infinicore/pybind11/infinicore.cc
src/infinicore/pybind11/infinicore.cc
+2
-0
src/infinicore/tensor/tensor.cc
src/infinicore/tensor/tensor.cc
+8
-0
No files found.
include/infinicore/context/context.hpp
View file @
006d530c
...
...
@@ -3,6 +3,8 @@
#include "../device.hpp"
#include "../memory.hpp"
#include "../graph/graph.hpp"
#include <infiniop.h>
#include <infinirt.h>
...
...
@@ -40,6 +42,12 @@ void destroyEvent(infinirtEvent_t event);
float
elapsedTime
(
infinirtEvent_t
start
,
infinirtEvent_t
end
);
void
streamWaitEvent
(
infinirtStream_t
stream
,
infinirtEvent_t
event
);
// Graph recording APIs
bool
isGraphRecording
();
void
startGraphRecording
();
void
addGraphOperator
(
std
::
shared_ptr
<
graph
::
GraphOperator
>
op
);
std
::
shared_ptr
<
graph
::
Graph
>
stopGraphRecording
();
}
// namespace context
}
// namespace infinicore
include/infinicore/graph/graph.hpp
0 → 100644
View file @
006d530c
#pragma once

#include <memory>
#include <vector>

#include "../tensor.hpp"

namespace infinicore::graph {

// Forward declarations
class GraphManager;

/// A tensor captured into a static graph.
///
/// Construction re-wraps the source tensor via Tensor::to_blob() (see
/// graph.cc), so the graph holds a stable view of the recorded buffer
/// rather than the original tensor object.
class GraphTensor : public Tensor {
public:
    GraphTensor(const Tensor &);
};

/// One recorded operation of a static compute graph.
///
/// Concrete ops (e.g. op::Gemm) derive from this class and fill in the three
/// protected members: an opaque "planned" blob plus the function pointers
/// that run and destroy it. The destructor releases planned_meta_ through
/// deleter_, so this class owns that blob.
class GraphOperator {
public:
    GraphOperator() = default;

    // Owns planned_meta_ via deleter_: a compiler-generated copy would
    // duplicate the raw pointer and double-free it in the destructor
    // (rule of three), so copying is disabled.
    GraphOperator(const GraphOperator &) = delete;
    GraphOperator &operator=(const GraphOperator &) = delete;

    /// Execute the recorded operation on the pre-planned metadata.
    void run() const;

    ~GraphOperator();

protected:
    using run_schema = void (*)(void *);
    using cleanup_schema = void (*)(void **);

    // Zero-initialize so a partially-constructed operator (e.g. when a
    // derived constructor throws before assigning these) still destructs
    // safely: deleter_ is null-checked before being invoked.
    void *planned_meta_ = nullptr;
    run_schema runner_ = nullptr;
    cleanup_schema deleter_ = nullptr;
};

/// An ordered list of recorded operators; run() replays them in order.
class Graph {
public:
    Graph() = default;
    ~Graph() = default;

    /// Replay every recorded operator in insertion order.
    void run() const;

protected:
    // Only GraphManager appends operators while recording.
    void add_operator(std::shared_ptr<GraphOperator> op);

    std::vector<std::shared_ptr<GraphOperator>> op_list_;

    friend class GraphManager;
};

} // namespace infinicore::graph
include/infinicore/ops/gemm.hpp
View file @
006d530c
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
namespace
infinicore
::
op
{
class
Gemm
{
class
Gemm
:
public
graph
::
GraphOperator
{
public:
using
schema
=
void
(
*
)(
Tensor
,
Tensor
,
Tensor
,
float
,
float
);
using
plan_schema
=
void
*
(
*
)(
Tensor
,
Tensor
,
Tensor
,
float
,
float
);
Gemm
(
Tensor
c
,
Tensor
a
,
Tensor
b
,
float
alpha
,
float
beta
);
static
void
execute
(
Tensor
c
,
Tensor
a
,
Tensor
b
,
float
alpha
,
float
beta
);
static
common
::
OpDispatcher
<
schema
>
&
dispatcher
();
static
common
::
OpDispatcher
<
plan_schema
>
&
plan_dispatcher
();
static
common
::
OpDispatcher
<
run_schema
>
&
run_dispatcher
();
static
common
::
OpDispatcher
<
cleanup_schema
>
&
cleanup_dispatcher
();
};
Tensor
gemm
(
Tensor
a
,
Tensor
b
,
float
alpha
=
1.0
f
,
float
beta
=
0.0
f
);
...
...
include/infinicore/tensor.hpp
View file @
006d530c
...
...
@@ -133,6 +133,8 @@ public:
void
debug
()
const
;
Tensor
to_blob
()
const
;
///
/// Data Transfer APIs
///
...
...
@@ -294,7 +296,7 @@ protected:
friend
class
Tensor
;
pr
ivate
:
pr
otected
:
TensorMetaData
meta_
;
TensorData
data_
;
};
...
...
python/infinicore/__init__.py
View file @
006d530c
...
...
@@ -8,7 +8,10 @@ from infinicore.context import (
get_device
,
get_device_count
,
get_stream
,
is_graph_recording
,
set_device
,
start_graph_recording
,
stop_graph_recording
,
sync_device
,
sync_stream
,
)
...
...
@@ -80,6 +83,9 @@ __all__ = [
"set_device"
,
"sync_device"
,
"sync_stream"
,
"is_graph_recording"
,
"start_graph_recording"
,
"stop_graph_recording"
,
# Data Types.
"bfloat16"
,
"bool"
,
...
...
python/infinicore/context.py
View file @
006d530c
import
infinicore.device
from
infinicore.graph
import
Graph
from
infinicore.lib
import
_infinicore
...
...
@@ -49,3 +50,24 @@ def get_stream():
stream: The current stream object
"""
return
_infinicore
.
get_stream
()
def is_graph_recording():
    """Check if the current graph is recording.

    Returns:
        bool: True if the current graph is recording, False otherwise
    """
    return _infinicore.is_graph_recording()


def start_graph_recording(device=None):
    """Start recording the current graph."""
    # Switch device first when one is given: recording state lives on the
    # per-device runtime, so the switch must happen before capture starts.
    if device is not None:
        set_device(device)
    _infinicore.start_graph_recording()


def stop_graph_recording():
    """Stop recording the current graph."""
    # Wrap the native graph handle in the Python-level Graph wrapper.
    return Graph(_infinicore.stop_graph_recording())
python/infinicore/graph.py
0 → 100644
View file @
006d530c
from
infinicore.lib
import
_infinicore
class Graph:
    """Thin Python-side handle over a native ``_infinicore.Graph``."""

    def __init__(self, graph: _infinicore.Graph):
        """Wrap ``graph``, rejecting anything that is not a native Graph."""
        if isinstance(graph, _infinicore.Graph):
            self._graph = graph
        else:
            raise TypeError("Expected _infinicore.Graph")

    def run(self):
        """Replay the recorded graph on the native side."""
        return self._graph.run()

    def __repr__(self):
        return f"<Graph wrapper of {self._graph!r}>"
src/infinicore/context/allocators/pinnable_block_allocator.cc
View file @
006d530c
#include "pinnable_block_allocator.hpp"
#include "../context_impl.hpp"
#include "../../utils.hpp"
#include <algorithm>
...
...
src/infinicore/context/allocators/pinnable_block_allocator.hpp
View file @
006d530c
...
...
@@ -2,8 +2,6 @@
#include "memory_allocator.hpp"
#include "../context_impl.hpp"
#include <mutex>
#include <unordered_map>
#include <vector>
...
...
@@ -25,7 +23,7 @@ class PinnableBlockAllocator : public MemoryAllocator {
};
public:
explicit
PinnableBlockAllocator
(
Device
device
);
PinnableBlockAllocator
(
Device
device
);
~
PinnableBlockAllocator
();
std
::
byte
*
allocate
(
size_t
size
)
override
;
...
...
src/infinicore/context/context_impl.cc
View file @
006d530c
...
...
@@ -39,6 +39,10 @@ void ContextImpl::setDevice(Device device) {
return
;
}
if
(
getCurrentRuntime
()
->
isGraphRecording
())
{
spdlog
::
warn
(
"Switching device runtime during graph recording may break the graph!"
);
}
if
(
runtime_table_
[
int
(
device
.
getType
())][
device
.
getIndex
()]
==
nullptr
)
{
// Lazy initialization of runtime if never set before.
runtime_table_
[
int
(
device
.
getType
())][
device
.
getIndex
()]
=
std
::
unique_ptr
<
Runtime
>
(
new
Runtime
(
device
));
...
...
@@ -178,6 +182,21 @@ void streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) {
ContextImpl
::
singleton
().
getCurrentRuntime
()
->
streamWaitEvent
(
stream
,
event
);
}
// Returns true when the current device's runtime is capturing ops.
bool isGraphRecording() {
    return ContextImpl::singleton().getCurrentRuntime()->isGraphRecording();
}

// Begin capturing subsequently executed ops into a new graph on the
// current device's runtime.
void startGraphRecording() {
    ContextImpl::singleton().getCurrentRuntime()->startGraphRecording();
}

// Append a recorded operator to the graph currently being captured.
void addGraphOperator(std::shared_ptr<graph::GraphOperator> op) {
    ContextImpl::singleton().getCurrentRuntime()->addGraphOperator(op);
}

// End capture and hand the finished graph back to the caller.
std::shared_ptr<graph::Graph> stopGraphRecording() {
    return ContextImpl::singleton().getCurrentRuntime()->stopGraphRecording();
}
}
// namespace context
}
// namespace infinicore
src/infinicore/context/runtime/runtime.cc
View file @
006d530c
...
...
@@ -8,12 +8,12 @@
#include "../allocators/stream_ordered_allocator.hpp"
namespace
infinicore
{
Runtime
::
Runtime
(
Device
device
)
:
device_
(
device
)
{
Runtime
::
Runtime
(
Device
device
)
:
device_
(
device
)
,
graph_manager_
(
std
::
make_unique
<
graph
::
GraphManager
>
())
{
activate
();
INFINICORE_CHECK_ERROR
(
infinirtStreamCreate
(
&
stream_
));
INFINICORE_CHECK_ERROR
(
infiniopCreateHandle
(
&
infiniop_handle_
));
if
(
device_
.
getType
()
==
Device
::
Type
::
CPU
)
{
device_memory_allocator_
=
std
::
make_unique
<
Host
Allocator
>
();
device_memory_allocator_
=
std
::
make_unique
<
PinnableBlock
Allocator
>
(
device
);
}
else
{
device_memory_allocator_
=
std
::
make_unique
<
PinnableBlockAllocator
>
(
device
);
pinned_host_memory_allocator_
=
std
::
make_unique
<
DevicePinnedHostAllocator
>
(
device
);
...
...
@@ -145,6 +145,25 @@ void Runtime::streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) {
INFINICORE_CHECK_ERROR
(
infinirtStreamWaitEvent
(
stream
,
event
));
}
// True while the graph manager is capturing ops on this runtime.
bool Runtime::isGraphRecording() const {
    return graph_manager_->is_recording();
}

// Begin a capture. The allocator is switched to pin mode first —
// presumably so buffers allocated during recording stay alive for later
// graph replays (TODO confirm against PinnableBlockAllocator semantics) —
// then a fresh graph is opened in the manager.
void Runtime::startGraphRecording() {
    device_memory_allocator_->set_pin_mode(true);
    return graph_manager_->start_recording();
}

// Forward a recorded operator to the active graph (asserts if not recording).
void Runtime::addGraphOperator(std::shared_ptr<graph::GraphOperator> op) {
    return graph_manager_->add_operator(op);
}

// Close the capture and restore normal (unpinned) allocation before
// returning the finished graph.
std::shared_ptr<graph::Graph> Runtime::stopGraphRecording() {
    auto graph = graph_manager_->stop_recording();
    device_memory_allocator_->set_pin_mode(false);
    return graph;
}
std
::
string
Runtime
::
toString
()
const
{
return
fmt
::
format
(
"Runtime({})"
,
device_
.
toString
());
}
...
...
src/infinicore/context/runtime/runtime.hpp
View file @
006d530c
#pragma once
#include "../allocators/memory_allocator.hpp"
#include "../allocators/pinnable_block_allocator.hpp"
#include "infinicore/context/context.hpp"
#include "../../graph/graph_manager.hpp"
#include <infiniop.h>
#include <infinirt.h>
...
...
@@ -13,8 +16,9 @@ private:
Device
device_
;
infinirtStream_t
stream_
;
infiniopHandle_t
infiniop_handle_
;
std
::
unique_ptr
<
Memory
Allocator
>
device_memory_allocator_
;
std
::
unique_ptr
<
PinnableBlock
Allocator
>
device_memory_allocator_
;
std
::
unique_ptr
<
MemoryAllocator
>
pinned_host_memory_allocator_
;
std
::
unique_ptr
<
graph
::
GraphManager
>
graph_manager_
;
protected:
Runtime
(
Device
device
);
...
...
@@ -48,6 +52,12 @@ public:
float
elapsedTime
(
infinirtEvent_t
start
,
infinirtEvent_t
end
);
void
streamWaitEvent
(
infinirtStream_t
stream
,
infinirtEvent_t
event
);
// Graph
bool
isGraphRecording
()
const
;
void
startGraphRecording
();
void
addGraphOperator
(
std
::
shared_ptr
<
graph
::
GraphOperator
>
op
);
std
::
shared_ptr
<
graph
::
Graph
>
stopGraphRecording
();
std
::
string
toString
()
const
;
friend
class
ContextImpl
;
...
...
src/infinicore/graph/graph.cc
0 → 100644
View file @
006d530c
#include "graph_manager.hpp"

#include "../utils.hpp"

#include <utility> // std::exchange — was relied on transitively before

namespace infinicore::graph {

/* =========================
 * GraphTensor
 * ========================= */

// Capture a stable view of `tensor`: to_blob() re-wraps the underlying
// buffer so the graph does not hold on to the original tensor object.
GraphTensor::GraphTensor(const Tensor &tensor)
    : Tensor(tensor->to_blob()) {}

/* =========================
 * GraphOperator
 * ========================= */

// Replay this operator on its pre-planned metadata.
void GraphOperator::run() const {
    runner_(planned_meta_);
}

// Release the planned metadata through the op-specific deleter, if any.
GraphOperator::~GraphOperator() {
    if (deleter_) {
        deleter_(&planned_meta_);
    }
}

/* =========================
 * Graph
 * ========================= */

// Replay all recorded operators in insertion order.
void Graph::run() const {
    for (const auto &op : op_list_) {
        op->run();
    }
}

void Graph::add_operator(std::shared_ptr<GraphOperator> op) {
    // Move instead of copy: avoids an extra atomic refcount bump.
    op_list_.push_back(std::move(op));
}

/* =========================
 * GraphManager
 * ========================= */

bool GraphManager::is_recording() const {
    return recording_;
}

// Open a fresh graph and begin capturing operators into it.
void GraphManager::start_recording() {
    recording_ = true;
    graph_ = std::make_shared<Graph>();
}

// Append an operator to the graph under capture; recording must be active.
void GraphManager::add_operator(std::shared_ptr<GraphOperator> op) {
    INFINICORE_ASSERT(recording_);
    graph_->add_operator(std::move(op));
}

// Stop capturing and transfer the finished graph to the caller, leaving
// the manager empty for the next recording session.
std::shared_ptr<Graph> GraphManager::stop_recording() {
    recording_ = false;
    return std::exchange(graph_, nullptr);
}

} // namespace infinicore::graph
src/infinicore/graph/graph_manager.hpp
0 → 100644
View file @
006d530c
#pragma once

#include "infinicore/graph/graph.hpp"

#include <memory>
#include <vector>

namespace infinicore::graph {

// Owns the graph currently being recorded (one manager per Runtime) and
// tracks whether capture is active.
class GraphManager {
public:
    GraphManager() = default;
    ~GraphManager() = default;

    // True between start_recording() and the matching stop_recording().
    bool is_recording() const;

    // Open a fresh graph and begin capturing operators into it.
    void start_recording();

    // Append an operator to the graph under capture (asserts if not
    // recording — see graph.cc).
    void add_operator(std::shared_ptr<GraphOperator> op);

    // End capture and transfer ownership of the finished graph to the
    // caller, resetting the internal graph to null.
    std::shared_ptr<Graph> stop_recording();

private:
    std::shared_ptr<Graph> graph_;   // graph under capture; null when idle
    bool recording_ = false;
};

} // namespace infinicore::graph
src/infinicore/ops/gemm/gemm.cc
View file @
006d530c
...
...
@@ -9,10 +9,34 @@ common::OpDispatcher<Gemm::schema> &Gemm::dispatcher() {
return
dispatcher_
;
};
void
Gemm
::
execute
(
Tensor
c
,
Tensor
a
,
Tensor
b
,
float
alpha
,
float
beta
)
{
common
::
OpDispatcher
<
Gemm
::
plan_schema
>
&
Gemm
::
plan_dispatcher
()
{
static
common
::
OpDispatcher
<
Gemm
::
plan_schema
>
dispatcher_
;
return
dispatcher_
;
}
common
::
OpDispatcher
<
Gemm
::
run_schema
>
&
Gemm
::
run_dispatcher
()
{
static
common
::
OpDispatcher
<
Gemm
::
run_schema
>
dispatcher_
;
return
dispatcher_
;
}
common
::
OpDispatcher
<
Gemm
::
cleanup_schema
>
&
Gemm
::
cleanup_dispatcher
()
{
static
common
::
OpDispatcher
<
Gemm
::
cleanup_schema
>
dispatcher_
;
return
dispatcher_
;
}
Gemm
::
Gemm
(
Tensor
c
,
Tensor
a
,
Tensor
b
,
float
alpha
,
float
beta
)
{
INFINICORE_ASSERT_TENSORS_SAME_DEVICE
(
c
,
a
,
b
);
infinicore
::
context
::
setDevice
(
c
->
device
());
dispatcher
().
lookup
(
c
->
device
().
getType
())(
c
,
a
,
b
,
alpha
,
beta
);
planned_meta_
=
plan_dispatcher
().
lookup
(
c
->
device
().
getType
())(
c
,
a
,
b
,
alpha
,
beta
);
runner_
=
run_dispatcher
().
lookup
(
c
->
device
().
getType
());
deleter_
=
cleanup_dispatcher
().
lookup
(
c
->
device
().
getType
());
}
void
Gemm
::
execute
(
Tensor
c
,
Tensor
a
,
Tensor
b
,
float
alpha
,
float
beta
)
{
auto
op
=
std
::
make_shared
<
Gemm
>
(
c
,
a
,
b
,
alpha
,
beta
);
if
(
context
::
isGraphRecording
())
{
context
::
addGraphOperator
(
op
);
}
else
{
op
->
run
();
}
}
Tensor
gemm
(
Tensor
a
,
Tensor
b
,
float
alpha
,
float
beta
)
{
...
...
src/infinicore/ops/gemm/gemm_infiniop.cc
View file @
006d530c
...
...
@@ -5,45 +5,89 @@
#include <infiniop.h>
namespace
infinicore
::
op
::
gemm_impl
::
infiniop
{
thread_local
common
::
OpCache
<
size_t
,
infiniopGemmDescriptor_t
>
caches
(
100
,
// capacity
[](
infiniopGemmDescriptor_t
&
desc
)
{
// A desc holder to make it a shared pointer that can auto clean-up
struct
Descriptor
{
infiniopGemmDescriptor_t
desc
;
Descriptor
(
infiniopGemmDescriptor_t
desc
)
:
desc
(
desc
)
{}
~
Descriptor
()
{
if
(
desc
!=
nullptr
)
{
INFINICORE_CHECK_ERROR
(
infiniopDestroyGemmDescriptor
(
desc
)
)
;
infiniopDestroyGemmDescriptor
(
desc
);
desc
=
nullptr
;
}
});
}
};
void
calculate
(
Tensor
c
,
Tensor
a
,
Tensor
b
,
float
alpha
,
float
beta
)
{
thread_local
common
::
OpCache
<
size_t
,
std
::
shared_ptr
<
Descriptor
>>
caches
(
// capacity
100
,
// on evict
[](
std
::
shared_ptr
<
Descriptor
>
&
desc
)
{
desc
=
nullptr
;
});
struct
PlannedMeta
{
std
::
shared_ptr
<
Descriptor
>
descriptor
;
graph
::
GraphTensor
workspace
,
c
,
a
,
b
;
float
alpha
,
beta
;
};
void
*
plan
(
Tensor
c
,
Tensor
a
,
Tensor
b
,
float
alpha
,
float
beta
)
{
size_t
seed
=
hash_combine
(
c
,
b
,
a
,
alpha
,
beta
);
auto
device
=
context
::
getDevice
();
auto
&
cache
=
caches
.
getCache
(
device
);
auto
desc_opt
=
cache
.
get
(
seed
);
infiniopGemmDescriptor_t
desc
=
nullptr
;
auto
descriptor
=
cache
.
get
(
seed
).
value_or
(
nullptr
);
if
(
!
desc_opt
)
{
if
(
!
descriptor
)
{
descriptor
=
std
::
make_shared
<
Descriptor
>
(
nullptr
);
INFINICORE_CHECK_ERROR
(
infiniopCreateGemmDescriptor
(
context
::
getInfiniopHandle
(
device
),
&
desc
,
context
::
getInfiniopHandle
(
device
),
&
descriptor
->
desc
,
c
->
desc
(),
a
->
desc
(),
b
->
desc
()));
cache
.
put
(
seed
,
desc
);
}
else
{
desc
=
*
desc_opt
;
cache
.
put
(
seed
,
descriptor
);
}
size_t
workspace_size
=
0
;
INFINICORE_CHECK_ERROR
(
infiniopGetGemmWorkspaceSize
(
desc
,
&
workspace_size
));
std
::
shared_ptr
<
Memory
>
workspace
=
context
::
allocateMemory
(
workspace_size
);
INFINICORE_CHECK_ERROR
(
infiniopGetGemmWorkspaceSize
(
descriptor
->
desc
,
&
workspace_size
));
Tensor
workspace
=
Tensor
::
empty
({
workspace_size
},
DataType
::
U8
,
device
);
auto
planned
=
new
PlannedMeta
{
descriptor
,
graph
::
GraphTensor
(
workspace
),
graph
::
GraphTensor
(
c
),
graph
::
GraphTensor
(
a
),
graph
::
GraphTensor
(
b
),
alpha
,
beta
};
return
planned
;
}
void
run
(
void
*
planned_meta
)
{
auto
planned
=
reinterpret_cast
<
PlannedMeta
*>
(
planned_meta
);
INFINICORE_CHECK_ERROR
(
infiniopGemm
(
desc
,
workspace
->
data
(),
workspace_size
,
c
->
data
(),
a
->
data
(),
b
->
data
(),
alpha
,
beta
,
context
::
getStream
()));
planned
->
descriptor
->
desc
,
planned
->
workspace
->
data
(),
planned
->
workspace
->
numel
(),
planned
->
c
->
data
(),
planned
->
a
->
data
(),
planned
->
b
->
data
(),
planned
->
alpha
,
planned
->
beta
,
context
::
getStream
()));
}
void
cleanup
(
void
**
planned_meta_ptr
)
{
delete
*
reinterpret_cast
<
PlannedMeta
**>
(
planned_meta_ptr
);
*
planned_meta_ptr
=
nullptr
;
}
void
calculate
(
Tensor
c
,
Tensor
a
,
Tensor
b
,
float
alpha
,
float
beta
)
{
auto
planned
=
plan
(
c
,
a
,
b
,
alpha
,
beta
);
run
(
planned
);
cleanup
(
&
planned
);
}
static
bool
registered
=
[]()
{
Gemm
::
dispatcher
().
registerAll
(
&
calculate
,
false
);
Gemm
::
plan_dispatcher
().
registerAll
(
&
plan
,
false
);
Gemm
::
run_dispatcher
().
registerAll
(
&
run
,
false
);
Gemm
::
cleanup_dispatcher
().
registerAll
(
&
cleanup
,
false
);
return
true
;
}();
...
...
src/infinicore/pybind11/context.hpp
View file @
006d530c
...
...
@@ -24,6 +24,11 @@ inline void bind(py::module &m) {
// Synchronization
m
.
def
(
"sync_stream"
,
&
syncStream
,
"Synchronize the current stream"
);
m
.
def
(
"sync_device"
,
&
syncDevice
,
"Synchronize the current device"
);
// Graph
m
.
def
(
"is_graph_recording"
,
&
isGraphRecording
,
"Check if graph recording is turned on"
);
m
.
def
(
"start_graph_recording"
,
&
startGraphRecording
,
"Start graph recording"
);
m
.
def
(
"stop_graph_recording"
,
&
stopGraphRecording
,
"Stop graph recording and return the graph"
);
}
}
// namespace infinicore::context
src/infinicore/pybind11/graph.hpp
0 → 100644
View file @
006d530c
#pragma once

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "infinicore.hpp"

namespace py = pybind11;

namespace infinicore::graph {

// Register the Graph type with the Python module. Held by shared_ptr so
// the same graph instance returned from stop_graph_recording() can be
// shared between C++ and Python without copying.
inline void bind(py::module_ &m) {
    py::class_<infinicore::graph::Graph, std::shared_ptr<infinicore::graph::Graph>>(m, "Graph")
        .def(py::init<>()) // allow construction
        .def("run", &infinicore::graph::Graph::run);
}

} // namespace infinicore::graph
src/infinicore/pybind11/infinicore.cc
View file @
006d530c
...
...
@@ -6,6 +6,7 @@
#include "device.hpp"
#include "device_event.hpp"
#include "dtype.hpp"
#include "graph.hpp"
#include "ops.hpp"
#include "tensor.hpp"
...
...
@@ -18,6 +19,7 @@ PYBIND11_MODULE(_infinicore, m) {
dtype
::
bind
(
m
);
ops
::
bind
(
m
);
tensor
::
bind
(
m
);
graph
::
bind
(
m
);
}
}
// namespace infinicore
src/infinicore/tensor/tensor.cc
View file @
006d530c
...
...
@@ -275,4 +275,12 @@ std::shared_ptr<TensorImpl> TensorImpl::strided_from_blob(
return
t
;
}
// Create a detached alias of this tensor: identical shape/strides/dtype and
// byte offset, but backed by a fresh Memory object wrapping the same raw
// buffer with a null deleter (non-owning view).
//
// NOTE(review): the blob does NOT own or ref-count the underlying buffer —
// it dangles if the original allocation is freed. Presumably callers (graph
// recording via GraphTensor) rely on the allocator's pin mode to keep the
// buffer alive across replays; confirm before using this elsewhere.
Tensor TensorImpl::to_blob() const {
    auto t = std::shared_ptr<TensorImpl>(new TensorImpl(shape(), strides(), dtype()));
    t->data_.offset = this->data_.offset;
    t->data_.memory = std::make_shared<Memory>(
        this->data_.memory->data(),
        this->data_.memory->size(),
        this->data_.memory->device(),
        nullptr); // null deleter: the new Memory never frees the buffer
    return Tensor{t};
}
}
// namespace infinicore
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment