Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
9e30b806
Unverified
Commit
9e30b806
authored
Dec 18, 2025
by
thatPepe
Committed by
GitHub
Dec 18, 2025
Browse files
Merge pull request #799 from InfiniTensor/issue/798
issue/798 - fix operator device handling
parents
fb5e36d2
3720127c
Changes
19
Show whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
59 additions
and
68 deletions
+59
-68
include/infinicore/context/context.hpp
include/infinicore/context/context.hpp
+1
-1
include/infinicore/ops/common/cache.hpp
include/infinicore/ops/common/cache.hpp
+4
-0
python/infinicore/context.py
python/infinicore/context.py
+2
-2
src/infinicore/context/context_impl.cc
src/infinicore/context/context_impl.cc
+3
-7
src/infinicore/context/context_impl.hpp
src/infinicore/context/context_impl.hpp
+1
-1
src/infinicore/ops/add/add_infiniop.cc
src/infinicore/ops/add/add_infiniop.cc
+3
-5
src/infinicore/ops/attention/attention_infiniop.cc
src/infinicore/ops/attention/attention_infiniop.cc
+3
-5
src/infinicore/ops/causal_softmax/causal_softmax_infiniop.cc
src/infinicore/ops/causal_softmax/causal_softmax_infiniop.cc
+3
-5
src/infinicore/ops/gemm/gemm_infiniop.cc
src/infinicore/ops/gemm/gemm_infiniop.cc
+3
-5
src/infinicore/ops/mul/mul_infiniop.cc
src/infinicore/ops/mul/mul_infiniop.cc
+3
-5
src/infinicore/ops/random_sample/random_sample_infiniop.cc
src/infinicore/ops/random_sample/random_sample_infiniop.cc
+3
-5
src/infinicore/ops/rearrange/rearrange_infiniop.cc
src/infinicore/ops/rearrange/rearrange_infiniop.cc
+3
-5
src/infinicore/ops/rms_norm/rms_norm_infiniop.cc
src/infinicore/ops/rms_norm/rms_norm_infiniop.cc
+3
-5
src/infinicore/ops/rope/rope_infiniop.cc
src/infinicore/ops/rope/rope_infiniop.cc
+3
-4
src/infinicore/ops/silu/silu_infiniop.cc
src/infinicore/ops/silu/silu_infiniop.cc
+3
-5
src/infinicore/ops/swiglu/swiglu_infiniop.cc
src/infinicore/ops/swiglu/swiglu_infiniop.cc
+3
-5
src/infinicore/pybind11/context.hpp
src/infinicore/pybind11/context.hpp
+1
-2
src/infinicore/tensor/copy.cc
src/infinicore/tensor/copy.cc
+2
-1
src/infinirt/infinirt.cc
src/infinirt/infinirt.cc
+12
-0
No files found.
include/infinicore/context/context.hpp
View file @
9e30b806
...
...
@@ -11,7 +11,7 @@
namespace
infinicore
{
namespace
context
{
void
setDevice
(
Device
device
,
bool
force_cpu
=
false
);
void
setDevice
(
Device
device
);
Device
getDevice
();
size_t
getDeviceCount
(
Device
::
Type
type
);
...
...
include/infinicore/ops/common/cache.hpp
View file @
9e30b806
...
...
@@ -36,6 +36,10 @@ public:
return
cache_vector
[
device_index
];
}
BaseCache
&
getCache
(
Device
device
)
{
return
getCache
(
device
.
getType
(),
device
.
getIndex
());
}
void
setCapacity
(
size_t
capacity
)
{
capacity_
=
capacity
;
for
(
auto
&
vec
:
caches_
)
{
...
...
python/infinicore/context.py
View file @
9e30b806
...
...
@@ -23,13 +23,13 @@ def get_device_count(device_type):
return
_infinicore
.
get_device_count
(
infinicore
.
device
(
device_type
).
_underlying
.
type
)
def
set_device
(
device
,
force_cpu
=
False
):
def
set_device
(
device
):
"""Set the current active device.
Args:
device: The device to set as active
"""
_infinicore
.
set_device
(
device
.
_underlying
,
force_cpu
)
_infinicore
.
set_device
(
device
.
_underlying
)
def
sync_stream
():
...
...
src/infinicore/context/context_impl.cc
View file @
9e30b806
...
...
@@ -33,15 +33,11 @@ Runtime *ContextImpl::getCpuRuntime() {
return
runtime_table_
[
int
(
Device
::
Type
::
CPU
)][
0
].
get
();
}
void
ContextImpl
::
setDevice
(
Device
device
,
bool
force_cpu
)
{
void
ContextImpl
::
setDevice
(
Device
device
)
{
if
(
device
==
getCurrentRuntime
()
->
device
())
{
// Do nothing if the device is already set.
return
;
}
if
(
device
==
Device
(
Device
::
Type
::
CPU
,
0
)
&&
!
force_cpu
)
{
// if not forced, no need to switch to CPU device runtime
return
;
}
if
(
runtime_table_
[
int
(
device
.
getType
())][
device
.
getIndex
()]
==
nullptr
)
{
// Lazy initialization of runtime if never set before.
...
...
@@ -87,8 +83,8 @@ ContextImpl::ContextImpl() {
namespace
context
{
void
setDevice
(
Device
device
,
bool
force_cpu
)
{
ContextImpl
::
singleton
().
setDevice
(
device
,
force_cpu
);
void
setDevice
(
Device
device
)
{
ContextImpl
::
singleton
().
setDevice
(
device
);
}
Device
getDevice
()
{
...
...
src/infinicore/context/context_impl.hpp
View file @
9e30b806
...
...
@@ -21,7 +21,7 @@ public:
Runtime
*
getCpuRuntime
();
void
setDevice
(
Device
,
bool
force_cpu
=
false
);
void
setDevice
(
Device
);
size_t
getDeviceCount
(
Device
::
Type
type
);
...
...
src/infinicore/ops/add/add_infiniop.cc
View file @
9e30b806
...
...
@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopAddDescriptor_t> caches(
void
calculate
(
Tensor
c
,
Tensor
a
,
Tensor
b
)
{
size_t
seed
=
hash_combine
(
c
,
b
,
a
);
auto
device_type
=
context
::
getDevice
().
getType
();
auto
device_index
=
context
::
getDevice
().
getIndex
();
auto
&
cache
=
caches
.
getCache
(
device_type
,
device_index
);
auto
device
=
context
::
getDevice
();
auto
&
cache
=
caches
.
getCache
(
device
);
auto
desc_opt
=
cache
.
get
(
seed
);
infiniopAddDescriptor_t
desc
=
nullptr
;
if
(
!
desc_opt
)
{
INFINICORE_CHECK_ERROR
(
infiniopCreateAddDescriptor
(
context
::
getInfiniopHandle
(
c
->
device
()
),
&
desc
,
context
::
getInfiniopHandle
(
device
),
&
desc
,
c
->
desc
(),
a
->
desc
(),
b
->
desc
()));
cache
.
put
(
seed
,
desc
);
}
else
{
...
...
src/infinicore/ops/attention/attention_infiniop.cc
View file @
9e30b806
...
...
@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopAttentionDescriptor_t> caches(
void
calculate
(
Tensor
out
,
Tensor
q
,
Tensor
k
,
Tensor
v
,
Tensor
k_cache
,
Tensor
v_cache
,
size_t
pos
)
{
size_t
seed
=
hash_combine
(
out
,
q
,
k
,
v
,
k_cache
,
v_cache
,
pos
);
auto
device_type
=
context
::
getDevice
().
getType
();
auto
device_index
=
context
::
getDevice
().
getIndex
();
auto
&
cache
=
caches
.
getCache
(
device_type
,
device_index
);
auto
device
=
context
::
getDevice
();
auto
&
cache
=
caches
.
getCache
(
device
);
auto
desc_opt
=
cache
.
get
(
seed
);
infiniopAttentionDescriptor_t
desc
=
nullptr
;
if
(
!
desc_opt
)
{
INFINICORE_CHECK_ERROR
(
infiniopCreateAttentionDescriptor
(
context
::
getInfiniopHandle
(
out
->
device
()
),
&
desc
,
context
::
getInfiniopHandle
(
device
),
&
desc
,
out
->
desc
(),
q
->
desc
(),
k
->
desc
(),
v
->
desc
(),
k_cache
->
desc
(),
v_cache
->
desc
(),
pos
));
cache
.
put
(
seed
,
desc
);
...
...
src/infinicore/ops/causal_softmax/causal_softmax_infiniop.cc
View file @
9e30b806
...
...
@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopCausalSoftmaxDescriptor_t> caches(
void
calculate
(
Tensor
output
,
Tensor
input
)
{
size_t
seed
=
hash_combine
(
output
,
input
);
auto
device_type
=
context
::
getDevice
().
getType
();
auto
device_index
=
context
::
getDevice
().
getIndex
();
auto
&
cache
=
caches
.
getCache
(
device_type
,
device_index
);
auto
device
=
context
::
getDevice
();
auto
&
cache
=
caches
.
getCache
(
device
);
auto
desc_opt
=
cache
.
get
(
seed
);
infiniopCausalSoftmaxDescriptor_t
desc
=
nullptr
;
if
(
!
desc_opt
)
{
INFINICORE_CHECK_ERROR
(
infiniopCreateCausalSoftmaxDescriptor
(
context
::
getInfiniopHandle
(
output
->
device
()
),
&
desc
,
context
::
getInfiniopHandle
(
device
),
&
desc
,
output
->
desc
(),
input
->
desc
()));
cache
.
put
(
seed
,
desc
);
}
else
{
...
...
src/infinicore/ops/gemm/gemm_infiniop.cc
View file @
9e30b806
...
...
@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopGemmDescriptor_t> caches(
void
calculate
(
Tensor
c
,
Tensor
a
,
Tensor
b
,
float
alpha
,
float
beta
)
{
size_t
seed
=
hash_combine
(
c
,
b
,
a
,
alpha
,
beta
);
auto
device_type
=
context
::
getDevice
().
getType
();
auto
device_index
=
context
::
getDevice
().
getIndex
();
auto
&
cache
=
caches
.
getCache
(
device_type
,
device_index
);
auto
device
=
context
::
getDevice
();
auto
&
cache
=
caches
.
getCache
(
device
);
auto
desc_opt
=
cache
.
get
(
seed
);
infiniopGemmDescriptor_t
desc
=
nullptr
;
if
(
!
desc_opt
)
{
INFINICORE_CHECK_ERROR
(
infiniopCreateGemmDescriptor
(
context
::
getInfiniopHandle
(
c
->
device
()
),
&
desc
,
context
::
getInfiniopHandle
(
device
),
&
desc
,
c
->
desc
(),
a
->
desc
(),
b
->
desc
()));
cache
.
put
(
seed
,
desc
);
}
else
{
...
...
src/infinicore/ops/mul/mul_infiniop.cc
View file @
9e30b806
...
...
@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopMulDescriptor_t> caches(
void
calculate
(
Tensor
c
,
Tensor
a
,
Tensor
b
)
{
size_t
seed
=
hash_combine
(
c
,
b
,
a
);
auto
device_type
=
context
::
getDevice
().
getType
();
auto
device_index
=
context
::
getDevice
().
getIndex
();
auto
&
cache
=
caches
.
getCache
(
device_type
,
device_index
);
auto
device
=
context
::
getDevice
();
auto
&
cache
=
caches
.
getCache
(
device
);
auto
desc_opt
=
cache
.
get
(
seed
);
infiniopMulDescriptor_t
desc
=
nullptr
;
if
(
!
desc_opt
)
{
INFINICORE_CHECK_ERROR
(
infiniopCreateMulDescriptor
(
context
::
getInfiniopHandle
(
c
->
device
()
),
&
desc
,
context
::
getInfiniopHandle
(
device
),
&
desc
,
c
->
desc
(),
a
->
desc
(),
b
->
desc
()));
cache
.
put
(
seed
,
desc
);
}
else
{
...
...
src/infinicore/ops/random_sample/random_sample_infiniop.cc
View file @
9e30b806
...
...
@@ -25,17 +25,15 @@ static void calculate(
// cache per (result desc + logits desc) on device
size_t
seed
=
hash_combine
(
indices
,
logits
);
auto
device_type
=
context
::
getDevice
().
getType
();
auto
device_index
=
context
::
getDevice
().
getIndex
();
auto
&
cache
=
caches
.
getCache
(
device_type
,
device_index
);
auto
device
=
context
::
getDevice
();
auto
&
cache
=
caches
.
getCache
(
device
);
auto
desc_opt
=
cache
.
get
(
seed
);
infiniopRandomSampleDescriptor_t
desc
=
nullptr
;
if
(
!
desc_opt
)
{
INFINICORE_CHECK_ERROR
(
infiniopCreateRandomSampleDescriptor
(
context
::
getInfiniopHandle
(
indices
->
device
()
),
&
desc
,
context
::
getInfiniopHandle
(
device
),
&
desc
,
indices
->
desc
(),
logits
->
desc
()));
cache
.
put
(
seed
,
desc
);
}
else
{
...
...
src/infinicore/ops/rearrange/rearrange_infiniop.cc
View file @
9e30b806
...
...
@@ -18,16 +18,14 @@ thread_local common::OpCache<size_t, infiniopRearrangeDescriptor_t> caches(
void
calculate
(
Tensor
y
,
Tensor
x
)
{
size_t
seed
=
hash_combine
(
y
,
x
);
auto
device_type
=
y
->
device
().
getType
();
auto
device_index
=
y
->
device
().
getIndex
();
auto
&
cache
=
caches
.
getCache
(
device_type
,
device_index
);
auto
device
=
context
::
getDevice
();
auto
&
cache
=
caches
.
getCache
(
device
);
auto
desc_opt
=
cache
.
get
(
seed
);
infiniopRearrangeDescriptor_t
desc
=
nullptr
;
if
(
!
desc_opt
)
{
INFINICORE_CHECK_ERROR
(
infiniopCreateRearrangeDescriptor
(
context
::
getInfiniopHandle
(
y
->
device
()
),
&
desc
,
y
->
desc
(),
x
->
desc
()));
INFINICORE_CHECK_ERROR
(
infiniopCreateRearrangeDescriptor
(
context
::
getInfiniopHandle
(
device
),
&
desc
,
y
->
desc
(),
x
->
desc
()));
cache
.
put
(
seed
,
desc
);
}
else
{
desc
=
*
desc_opt
;
...
...
src/infinicore/ops/rms_norm/rms_norm_infiniop.cc
View file @
9e30b806
...
...
@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopRMSNormDescriptor_t> caches(
void
calculate
(
Tensor
y
,
Tensor
x
,
Tensor
weight
,
float
epsilon
)
{
size_t
seed
=
hash_combine
(
y
,
x
,
weight
,
epsilon
);
auto
device_type
=
context
::
getDevice
().
getType
();
auto
device_index
=
context
::
getDevice
().
getIndex
();
auto
&
cache
=
caches
.
getCache
(
device_type
,
device_index
);
auto
device
=
context
::
getDevice
();
auto
&
cache
=
caches
.
getCache
(
device
);
auto
desc_opt
=
cache
.
get
(
seed
);
infiniopRMSNormDescriptor_t
desc
=
nullptr
;
if
(
!
desc_opt
)
{
INFINICORE_CHECK_ERROR
(
infiniopCreateRMSNormDescriptor
(
context
::
getInfiniopHandle
(
y
->
device
()
),
&
desc
,
context
::
getInfiniopHandle
(
device
),
&
desc
,
y
->
desc
(),
x
->
desc
(),
weight
->
desc
(),
epsilon
));
cache
.
put
(
seed
,
desc
);
}
else
{
...
...
src/infinicore/ops/rope/rope_infiniop.cc
View file @
9e30b806
...
...
@@ -33,16 +33,15 @@ void calculate(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &s
size_t
key
=
hash_combine
(
x_out
,
x
,
pos
,
sin_cache
,
cos_cache
);
hash_combine
(
key
,
std
::
hash
<
int
>
()(
static_cast
<
int
>
(
infiniop_algo
)));
auto
device_type
=
context
::
getDevice
().
getType
();
auto
device_index
=
context
::
getDevice
().
getIndex
();
auto
&
cache
=
caches
.
getCache
(
device_type
,
device_index
);
auto
device
=
context
::
getDevice
();
auto
&
cache
=
caches
.
getCache
(
device
);
auto
desc_opt
=
cache
.
get
(
key
);
infiniopRoPEDescriptor_t
desc
=
nullptr
;
if
(
!
desc_opt
)
{
INFINICORE_CHECK_ERROR
(
infiniopCreateRoPEDescriptor
(
context
::
getInfiniopHandle
(
x_out
->
device
()
),
&
desc
,
context
::
getInfiniopHandle
(
device
),
&
desc
,
x_out
->
desc
(),
x
->
desc
(),
pos
->
desc
(),
sin_cache
->
desc
(),
cos_cache
->
desc
(),
infiniop_algo
));
...
...
src/infinicore/ops/silu/silu_infiniop.cc
View file @
9e30b806
...
...
@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopSiluDescriptor_t> caches(
void
calculate
(
Tensor
output
,
Tensor
input
)
{
size_t
seed
=
hash_combine
(
output
,
input
);
auto
device_type
=
context
::
getDevice
().
getType
();
auto
device_index
=
context
::
getDevice
().
getIndex
();
auto
&
cache
=
caches
.
getCache
(
device_type
,
device_index
);
auto
device
=
context
::
getDevice
();
auto
&
cache
=
caches
.
getCache
(
device
);
auto
desc_opt
=
cache
.
get
(
seed
);
infiniopSiluDescriptor_t
desc
=
nullptr
;
if
(
!
desc_opt
)
{
INFINICORE_CHECK_ERROR
(
infiniopCreateSiluDescriptor
(
context
::
getInfiniopHandle
(
output
->
device
()
),
&
desc
,
context
::
getInfiniopHandle
(
device
),
&
desc
,
output
->
desc
(),
input
->
desc
()));
cache
.
put
(
seed
,
desc
);
}
else
{
...
...
src/infinicore/ops/swiglu/swiglu_infiniop.cc
View file @
9e30b806
...
...
@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopSwiGLUDescriptor_t> caches(
void
calculate
(
Tensor
c
,
Tensor
a
,
Tensor
b
)
{
size_t
seed
=
hash_combine
(
c
,
b
,
a
);
auto
device_type
=
context
::
getDevice
().
getType
();
auto
device_index
=
context
::
getDevice
().
getIndex
();
auto
&
cache
=
caches
.
getCache
(
device_type
,
device_index
);
auto
device
=
context
::
getDevice
();
auto
&
cache
=
caches
.
getCache
(
device
);
auto
desc_opt
=
cache
.
get
(
seed
);
infiniopSwiGLUDescriptor_t
desc
=
nullptr
;
if
(
!
desc_opt
)
{
INFINICORE_CHECK_ERROR
(
infiniopCreateSwiGLUDescriptor
(
context
::
getInfiniopHandle
(
c
->
device
()
),
&
desc
,
context
::
getInfiniopHandle
(
device
),
&
desc
,
c
->
desc
(),
a
->
desc
(),
b
->
desc
()));
cache
.
put
(
seed
,
desc
);
}
else
{
...
...
src/infinicore/pybind11/context.hpp
View file @
9e30b806
...
...
@@ -16,8 +16,7 @@ inline void bind(py::module &m) {
py
::
arg
(
"device_type"
));
m
.
def
(
"set_device"
,
&
setDevice
,
"Set the current active device"
,
py
::
arg
(
"device"
),
py
::
arg
(
"force_cpu"
));
py
::
arg
(
"device"
));
// Stream and handle management
m
.
def
(
"get_stream"
,
&
getStream
,
"Get the current stream"
);
...
...
src/infinicore/tensor/copy.cc
View file @
9e30b806
...
...
@@ -31,6 +31,7 @@ void TensorImpl::copy_from(Tensor src) {
// Use nbytes() to get the actual tensor size, not the full memory size
size_t
copy_size
=
std
::
min
(
this
->
nbytes
(),
src
->
nbytes
());
if
(
this
->
device
().
getType
()
==
Device
::
Type
::
CPU
)
{
context
::
setDevice
(
src
->
device
());
if
(
this
->
is_contiguous
())
{
context
::
memcpyD2H
(
this
->
data
(),
src
->
data
(),
copy_size
);
}
else
{
...
...
@@ -39,7 +40,7 @@ void TensorImpl::copy_from(Tensor src) {
op
::
rearrange_
(
Tensor
(
const_cast
<
TensorImpl
*>
(
this
)
->
shared_from_this
()),
local_src
);
}
}
else
if
(
src
->
device
().
getType
()
==
Device
::
Type
::
CPU
)
{
context
::
setDevice
(
this
->
device
());
if
(
this
->
is_contiguous
())
{
context
::
memcpyH2D
(
this
->
data
(),
src
->
data
(),
copy_size
);
}
else
{
...
...
src/infinirt/infinirt.cc
View file @
9e30b806
...
...
@@ -10,6 +10,8 @@
thread_local
infiniDevice_t
CURRENT_DEVICE_TYPE
=
INFINI_DEVICE_CPU
;
thread_local
int
CURRENT_DEVICE_ID
=
0
;
thread_local
infiniDevice_t
PREVIOUS_NON_CPU_DEVICE_TYPE
=
INFINI_DEVICE_TYPE_COUNT
;
thread_local
int
PREVIOUS_NON_CPU_DEVICE_ID
=
0
;
__C
infiniStatus_t
infinirtInit
()
{
#ifdef ENABLE_ASCEND_API
...
...
@@ -96,6 +98,16 @@ __C infiniStatus_t infinirtGetDeviceCount(infiniDevice_t device, int *count) {
}
__C
infiniStatus_t
infinirtSetDevice
(
infiniDevice_t
device
,
int
device_id
)
{
bool
skip_set
=
CURRENT_DEVICE_TYPE
==
INFINI_DEVICE_CPU
&&
device
==
PREVIOUS_NON_CPU_DEVICE_TYPE
&&
device_id
==
PREVIOUS_NON_CPU_DEVICE_ID
;
if
(
CURRENT_DEVICE_TYPE
!=
INFINI_DEVICE_CPU
)
{
PREVIOUS_NON_CPU_DEVICE_TYPE
=
CURRENT_DEVICE_TYPE
;
PREVIOUS_NON_CPU_DEVICE_ID
=
CURRENT_DEVICE_ID
;
}
if
(
skip_set
)
{
CURRENT_DEVICE_TYPE
=
device
;
CURRENT_DEVICE_ID
=
device_id
;
return
INFINI_STATUS_SUCCESS
;
}
INFINIRT_CALL_DEVICE_API_AND
(
device
,
setDevice
,
(
device_id
),
{
CURRENT_DEVICE_TYPE
=
device
;
CURRENT_DEVICE_ID
=
device_id
;
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment