jerrrrry / infinicore · Commits

Unverified commit 8c16b808, authored Dec 03, 2025 by PanZezhong1725, committed by GitHub on Dec 03, 2025

Merge pull request #701 from InfiniTensor/issue/700

issue/700: set the device from the tensors during operator execution; a no-op when the device is CPU

Parents: 57291db6, fa149d69

Showing 11 changed files with 84 additions and 33 deletions (+84 -33)
src/infinicore/ops/README.md                          +32  -23
src/infinicore/ops/add/add.cc                          +4   -1
src/infinicore/ops/attention/attention.cc              +4   -1
src/infinicore/ops/causal_softmax/causal_softmax.cc    +6   -1
src/infinicore/ops/gemm/gemm.cc                        +5   -1
src/infinicore/ops/mul/mul.cc                          +5   -1
src/infinicore/ops/random_sample/random_sample.cc      +5   -1
src/infinicore/ops/rms_norm/rms_norm.cc                +5   -1
src/infinicore/ops/rope/rope.cc                        +6   -1
src/infinicore/ops/silu/silu.cc                        +6   -1
src/infinicore/ops/swiglu/swiglu.cc                    +6   -1
src/infinicore/ops/README.md (view file @ 8c16b808)

````diff
@@ -14,7 +14,7 @@ The infinicore::ops module contains the interfaces and implementations of all InfiniCore C++ operators
 - execute function: the operator's computation logic.
 - dispatcher: used to register the operator's kernel implementations on different devices. Within a process, each operator has a single global dispatcher, and only one kernel implementation can be registered per device at a time; registering again overwrites the previous implementation. See `include/infinicore/ops/common/dispatcher.hpp` for details.
 
-An example header for the `Matmul` operator:
+An example header for the `Gemm` operator:
 
 ```c++
 #pragma once
...
@@ -23,15 +23,17 @@ The infinicore::ops module contains the interfaces and implementations of all InfiniCore C++ operators
 #include "common/op.hpp"
 
 namespace infinicore::op {
 
-class Matmul {
+class Gemm {
 public:
-    using schema = void (*)(Tensor, Tensor, Tensor);
+    using schema = void (*)(Tensor, Tensor, Tensor, float, float);
-    static void execute(Tensor c, Tensor a, Tensor b);
+    static void execute(Tensor c, Tensor a, Tensor b, float alpha, float beta);
     static common::OpDispatcher<schema> &dispatcher();
 };
 
-Tensor matmul(Tensor a, Tensor b);
+Tensor gemm(Tensor a, Tensor b, float alpha = 1.0f, float beta = 0.0f);
-void matmul_(Tensor c, Tensor a, Tensor b);
+void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta);
 }
 ```
````
````diff
@@ -39,38 +41,46 @@ void matmul_(Tensor c, Tensor a, Tensor b);
 Implement the operator's computation logic in `src/infinicore/ops/*OPNAME*/*OPNAME*.cpp`:
 
-- execute function: uses the operator's dispatcher to call the kernel function on the corresponding hardware.
+- execute function: uses the operator's dispatcher to call the kernel function on the corresponding hardware. The current runtime device type can be changed via `context::setDevice`.
 - computation interfaces: implement the operator interface's computation logic on top of the execute function, in both in-place and out-of-place modes; in-place interface names end with `_` and write the output into a given argument, while out-of-place interfaces create a new Tensor for the output.
 
-An example implementation of the `Matmul` operator:
+An example implementation of the `Gemm` operator:
 
 ```c++
-#include "infinicore/ops/matmul.hpp"
+#include "infinicore/ops/gemm.hpp"
+#include "../../utils.hpp"
 
 namespace infinicore::op {
 
-common::OpDispatcher<Matmul::schema> &Matmul::dispatcher() {
-    static common::OpDispatcher<Matmul::schema> dispatcher_;
+common::OpDispatcher<Gemm::schema> &Gemm::dispatcher() {
+    static common::OpDispatcher<Gemm::schema> dispatcher_;
     return dispatcher_;
 };
 
-void Matmul::execute(Tensor c, Tensor a, Tensor b) {
-    dispatcher().lookup(context::getDevice().getType())(c, a, b);
+void Gemm::execute(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
+    // Check that the tensors are on the same device
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
+    // Set the runtime device to match the tensors; a no-op when the device is CPU
+    infinicore::context::setDevice(c->device());
+    // Select the kernel by the tensors' device type and run the computation
+    dispatcher().lookup(c->device().getType())(c, a, b, alpha, beta);
 }
 
-Tensor matmul(Tensor a, Tensor b) {
+Tensor gemm(Tensor a, Tensor b, float alpha, float beta) {
     Shape shape = a->shape();
     Size size = a->ndim();
     shape[size - 1] = b->size(size - 1);
     auto c = Tensor::empty(shape, a->dtype(), a->device());
-    matmul_(c, a, b);
+    gemm_(c, a, b, alpha, beta);
     return c;
 }
 
-void matmul_(Tensor c, Tensor a, Tensor b) {
-    Matmul::execute(c, a, b);
+void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
+    Gemm::execute(c, a, b, alpha, beta);
 }
 
 }
 ```
````
````diff
 ### 3. Kernel registration
...
@@ -91,7 +101,7 @@ void registerAll(Fn fn, bool override_existing = true);
 Fn lookup(Device::Type device_type) const;
 ```
 
-If you register the same kernel implementation for multiple (or all) devices, you must implement the cross-device dispatch yourself. For example, the InfiniOP operator library in this framework keeps its operator interfaces identical across platforms and dispatches automatically by the current device type, so at registration time it registers the same compute function for all platforms. Taking the `Matmul` operator as an example:
+If you register the same kernel implementation for multiple (or all) devices, you must implement the cross-device dispatch yourself. For example, the InfiniOP operator library in this framework keeps its operator interfaces identical across platforms and dispatches automatically by the current device type, so at registration time it registers the same compute function for all platforms. Taking the `Gemm` operator as an example:
 
 ```c++
 namespace infinicore::op::matmul_impl::infiniop {
...
@@ -107,19 +117,18 @@ thread_local common::OpCache<size_t, infiniopGemmDescriptor_t> caches(
 });
 
 // compute function
-void calculate(Tensor c, Tensor a, Tensor b) {
+void calculate(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
     // ...
     INFINICORE_CHECK_ERROR(infiniopGemm(
         desc, workspace->data(), workspace_size,
-        c->data(), a->data(), b->data(), 1.f, 0.f, context::getStream()));
+        c->data(), a->data(), b->data(), alpha, beta, context::getStream()));
 }
 
 // Register the InfiniOP implementation for all platforms when InfiniCore is loaded
 static bool registered = []() {
-    Matmul::dispatcher().registerAll(&calculate, false);
+    Gemm::dispatcher().registerAll(&calculate, false);
     return true;
 }();
 
 }
 ```
````
src/infinicore/ops/add/add.cc (view file @ 8c16b808)

```diff
 #include "infinicore/ops/add.hpp"
+#include "../../utils.hpp"
 
 namespace infinicore::op {
...
@@ -8,7 +9,9 @@ common::OpDispatcher<Add::schema> &Add::dispatcher() {
 };
 
 void Add::execute(Tensor c, Tensor a, Tensor b) {
-    dispatcher().lookup(context::getDevice().getType())(c, a, b);
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
+    infinicore::context::setDevice(c->device());
+    dispatcher().lookup(c->device().getType())(c, a, b);
 }
 
 Tensor add(Tensor a, Tensor b) {
...
```
src/infinicore/ops/attention/attention.cc (view file @ 8c16b808)

```diff
 #include "infinicore/ops/attention.hpp"
+#include "../../utils.hpp"
 
 namespace infinicore::op {
...
@@ -8,7 +9,9 @@ common::OpDispatcher<Attention::schema> &Attention::dispatcher() {
 };
 
 void Attention::execute(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos) {
-    dispatcher().lookup(context::getDevice().getType())(out, q, k, v, k_cache, v_cache, pos);
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v, k_cache, v_cache);
+    infinicore::context::setDevice(out->device());
+    dispatcher().lookup(out->device().getType())(out, q, k, v, k_cache, v_cache, pos);
 }
 
 Tensor attention(Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos) {
...
```
src/infinicore/ops/causal_softmax/causal_softmax.cc (view file @ 8c16b808)

```diff
 #include "infinicore/ops/causal_softmax.hpp"
+#include "../../utils.hpp"
 
 #include <stdexcept>
 
 namespace infinicore::op {
...
@@ -9,7 +12,9 @@ common::OpDispatcher<CausalSoftmax::schema> &CausalSoftmax::dispatcher() {
 };
 
 void CausalSoftmax::execute(Tensor output, Tensor input) {
-    auto device_type = context::getDevice().getType();
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(output, input);
+    infinicore::context::setDevice(output->device());
+    auto device_type = output->device().getType();
     auto func = dispatcher().lookup(device_type);
     if (func == nullptr) {
...
```
src/infinicore/ops/gemm/gemm.cc (view file @ 8c16b808)

```diff
 #include "infinicore/ops/gemm.hpp"
+#include "../../utils.hpp"
 
 namespace infinicore::op {
 
 common::OpDispatcher<Gemm::schema> &Gemm::dispatcher() {
...
@@ -8,7 +10,9 @@ common::OpDispatcher<Gemm::schema> &Gemm::dispatcher() {
 };
 
 void Gemm::execute(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
-    dispatcher().lookup(context::getDevice().getType())(c, a, b, alpha, beta);
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
+    infinicore::context::setDevice(c->device());
+    dispatcher().lookup(c->device().getType())(c, a, b, alpha, beta);
 }
 
 Tensor gemm(Tensor a, Tensor b, float alpha, float beta) {
...
```
src/infinicore/ops/mul/mul.cc (view file @ 8c16b808)

```diff
 #include "infinicore/ops/mul.hpp"
+#include "../../utils.hpp"
 
 namespace infinicore::op {
 
 common::OpDispatcher<Mul::schema> &Mul::dispatcher() {
...
@@ -8,7 +10,9 @@ common::OpDispatcher<Mul::schema> &Mul::dispatcher() {
 };
 
 void Mul::execute(Tensor c, Tensor a, Tensor b) {
-    dispatcher().lookup(context::getDevice().getType())(c, a, b);
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
+    infinicore::context::setDevice(c->device());
+    dispatcher().lookup(c->device().getType())(c, a, b);
 }
 
 Tensor mul(Tensor a, Tensor b) {
...
```
src/infinicore/ops/random_sample/random_sample.cc (view file @ 8c16b808)

```diff
 #include "infinicore/ops/random_sample.hpp"
+#include "../../utils.hpp"
 
 namespace infinicore::op {
 
 common::OpDispatcher<RandomSample::schema> &RandomSample::dispatcher() {
...
@@ -10,7 +12,9 @@ common::OpDispatcher<RandomSample::schema> &RandomSample::dispatcher() {
 void RandomSample::execute(
     Tensor indices, Tensor logits,
     float random_val, float topp, int topk, float temperature) {
-    dispatcher().lookup(context::getDevice().getType())(
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(indices, logits);
+    infinicore::context::setDevice(logits->device());
+    dispatcher().lookup(logits->device().getType())(
         indices, logits, random_val, topp, topk, temperature);
 }
...
```
src/infinicore/ops/rms_norm/rms_norm.cc (view file @ 8c16b808)

```diff
 #include "infinicore/ops/rms_norm.hpp"
+#include "../../utils.hpp"
 
 namespace infinicore::op {
 
 common::OpDispatcher<RMSNorm::schema> &RMSNorm::dispatcher() {
...
@@ -8,7 +10,9 @@ common::OpDispatcher<RMSNorm::schema> &RMSNorm::dispatcher() {
 };
 
 void RMSNorm::execute(Tensor y, Tensor x, Tensor weight, float epsilon) {
-    dispatcher().lookup(context::getDevice().getType())(y, x, weight, epsilon);
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, x, weight);
+    infinicore::context::setDevice(y->device());
+    dispatcher().lookup(y->device().getType())(y, x, weight, epsilon);
 }
 
 Tensor rms_norm(Tensor x, Tensor weight, float epsilon) {
...
```
src/infinicore/ops/rope/rope.cc (view file @ 8c16b808)

```diff
 #include "infinicore/ops/rope.hpp"
+#include "../../utils.hpp"
 
 #include "infinicore/context/context.hpp"
 
 #include <stdexcept>
 
 namespace infinicore::op {
...
@@ -10,7 +13,9 @@ common::OpDispatcher<RoPE::schema> &RoPE::dispatcher() {
 };
 
 void RoPE::execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo) {
-    auto device_type = context::getDevice().getType();
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(x_out, x, pos, sin_table, cos_table);
+    infinicore::context::setDevice(x_out->device());
+    auto device_type = x_out->device().getType();
     auto func = dispatcher().lookup(device_type);
     if (func == nullptr) {
...
```
src/infinicore/ops/silu/silu.cc (view file @ 8c16b808)

```diff
 #include "infinicore/ops/silu.hpp"
+#include "../../utils.hpp"
 
 #include <stdexcept>
 
 namespace infinicore::op {
...
@@ -9,7 +12,9 @@ common::OpDispatcher<Silu::schema> &Silu::dispatcher() {
 };
 
 void Silu::execute(Tensor output, Tensor input) {
-    auto device_type = context::getDevice().getType();
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(output, input);
+    infinicore::context::setDevice(output->device());
+    auto device_type = output->device().getType();
     auto func = dispatcher().lookup(device_type);
     if (func == nullptr) {
...
```
src/infinicore/ops/swiglu/swiglu.cc (view file @ 8c16b808)

```diff
 #include "infinicore/ops/swiglu.hpp"
+#include "../../utils.hpp"
 
 #include <stdexcept>
 
 namespace infinicore::op {
...
@@ -9,7 +12,9 @@ common::OpDispatcher<SwiGLU::schema> &SwiGLU::dispatcher() {
 };
 
 void SwiGLU::execute(Tensor c, Tensor a, Tensor b) {
-    auto device_type = context::getDevice().getType();
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
+    infinicore::context::setDevice(c->device());
+    auto device_type = c->device().getType();
     auto func = dispatcher().lookup(device_type);
     if (func == nullptr) {
...
```