Unverified Commit 8c16b808 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #701 from InfiniTensor/issue/700

issue/700  算子执行时根据张量 set device, CPU时无操作
parents 57291db6 fa149d69
...@@ -14,7 +14,7 @@ infinicore::ops 模块包含了 InfiniCore 所有 C++ 算子的接口和实现 ...@@ -14,7 +14,7 @@ infinicore::ops 模块包含了 InfiniCore 所有 C++ 算子的接口和实现
- execute 函数,算子的计算逻辑。 - execute 函数,算子的计算逻辑。
- dispatcher 分发器,用于注册算子在不同设备上的 kernel 实现。一个进程中,一种算子只有一个全局分发器,每种设备上只能同时注册一个 kernel 实现,可以多次注册对之前的实现进行覆盖。详细信息请参考 `include/infinicore/ops/common/dispatcher.hpp` - dispatcher 分发器,用于注册算子在不同设备上的 kernel 实现。一个进程中,一种算子只有一个全局分发器,每种设备上只能同时注册一个 kernel 实现,可以多次注册对之前的实现进行覆盖。详细信息请参考 `include/infinicore/ops/common/dispatcher.hpp`
示例 `Matmul` 算子的头文件如下: 示例 `Gemm` 算子的头文件如下:
```c++ ```c++
#pragma once #pragma once
...@@ -23,15 +23,17 @@ infinicore::ops 模块包含了 InfiniCore 所有 C++ 算子的接口和实现 ...@@ -23,15 +23,17 @@ infinicore::ops 模块包含了 InfiniCore 所有 C++ 算子的接口和实现
#include "common/op.hpp" #include "common/op.hpp"
namespace infinicore::op { namespace infinicore::op {
class Matmul {
class Gemm {
public: public:
using schema = void (*)(Tensor, Tensor, Tensor); using schema = void (*)(Tensor, Tensor, Tensor, float, float);
static void execute(Tensor c, Tensor a, Tensor b); static void execute(Tensor c, Tensor a, Tensor b, float alpha, float beta);
static common::OpDispatcher<schema> &dispatcher(); static common::OpDispatcher<schema> &dispatcher();
}; };
Tensor matmul(Tensor a, Tensor b); Tensor gemm(Tensor a, Tensor b, float alpha = 1.0f, float beta = 0.0f);
void matmul_(Tensor c, Tensor a, Tensor b); void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta);
} }
``` ```
...@@ -39,37 +41,45 @@ void matmul_(Tensor c, Tensor a, Tensor b); ...@@ -39,37 +41,45 @@ void matmul_(Tensor c, Tensor a, Tensor b);
`src/infinicore/ops/*OPNAME*/*OPNAME*.cpp` 文件中实现算子的计算逻辑。 `src/infinicore/ops/*OPNAME*/*OPNAME*.cpp` 文件中实现算子的计算逻辑。
- execute 函数,使用算子的分发器,调用对应硬件上的核函数。 - execute 函数,使用算子的分发器,调用对应硬件上的核函数。可以通过 `context::setDevice` 来改变当前运行时的设备种类。
- 计算接口,使用 execute 函数实现算子接口的计算逻辑,包括 in-place 和 out-of-place 两种模式,其中 in-place 模式的接口函数名以 `_` 结尾,将输出结果写入给定的参数中;out-of-place 模式的接口会为输出创建新的 Tensor。 - 计算接口,使用 execute 函数实现算子接口的计算逻辑,包括 in-place 和 out-of-place 两种模式,其中 in-place 模式的接口函数名以 `_` 结尾,将输出结果写入给定的参数中;out-of-place 模式的接口会为输出创建新的 Tensor。
示例 `Matmul` 算子的实现如下: 示例 `Gemm` 算子的实现如下:
```c++ ```c++
#include "infinicore/ops/matmul.hpp" #include "infinicore/ops/gemm.hpp"
#include "../../utils.hpp"
namespace infinicore::op { namespace infinicore::op {
common::OpDispatcher<Matmul::schema> &Matmul::dispatcher() { common::OpDispatcher<Gemm::schema> &Gemm::dispatcher() {
static common::OpDispatcher<Matmul::schema> dispatcher_; static common::OpDispatcher<Gemm::schema> dispatcher_;
return dispatcher_; return dispatcher_;
}; };
void Matmul::execute(Tensor c, Tensor a, Tensor b) { void Gemm::execute(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
dispatcher().lookup(context::getDevice().getType())(c, a, b); // 检查张量设备是否一致
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
// 将运行时设备设置为与张量一致。若设备为 CPU,该接口不会进行任何操作
infinicore::context::setDevice(c->device());
// 根据张量的设备种类选择 kernel,执行计算
dispatcher().lookup(c->device().getType())(c, a, b, alpha, beta);
} }
Tensor matmul(Tensor a, Tensor b) { Tensor gemm(Tensor a, Tensor b, float alpha, float beta) {
Shape shape = a->shape(); Shape shape = a->shape();
Size size = a->ndim(); Size size = a->ndim();
shape[size - 1] = b->size(size - 1); shape[size - 1] = b->size(size - 1);
auto c = Tensor::empty(shape, a->dtype(), a->device()); auto c = Tensor::empty(shape, a->dtype(), a->device());
matmul_(c, a, b); gemm_(c, a, b, alpha, beta);
return c; return c;
} }
void matmul_(Tensor c, Tensor a, Tensor b) { void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
Matmul::execute(c, a, b); Gemm::execute(c, a, b, alpha, beta);
} }
} }
``` ```
...@@ -91,7 +101,7 @@ void registerAll(Fn fn, bool override_existing = true); ...@@ -91,7 +101,7 @@ void registerAll(Fn fn, bool override_existing = true);
Fn lookup(Device::Type device_type) const; Fn lookup(Device::Type device_type) const;
``` ```
如果你为多个(或全部)设备注册了同一个 kernel 实现,那么你需要自行实现不同设备的分发机制。比如本框架中的 InfiniOP 算子库,其算子接口在不同平台都保持了一致,并根据当前设备类型自动分发,因此在注册时会为所有平台注册同一个计算函数。以 Matmul 算子为例: 如果你为多个(或全部)设备注册了同一个 kernel 实现,那么你需要自行实现不同设备的分发机制。比如本框架中的 InfiniOP 算子库,其算子接口在不同平台都保持了一致,并根据当前设备类型自动分发,因此在注册时会为所有平台注册同一个计算函数。以 Gemm 算子为例:
```c++ ```c++
namespace infinicore::op::matmul_impl::infiniop { namespace infinicore::op::matmul_impl::infiniop {
...@@ -107,19 +117,18 @@ thread_local common::OpCache<size_t, infiniopGemmDescriptor_t> caches( ...@@ -107,19 +117,18 @@ thread_local common::OpCache<size_t, infiniopGemmDescriptor_t> caches(
}); });
// 计算函数 // 计算函数
void calculate(Tensor c, Tensor a, Tensor b){ void calculate(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
// ... // ...
INFINICORE_CHECK_ERROR(infiniopGemm( INFINICORE_CHECK_ERROR(infiniopGemm(
desc, workspace->data(), workspace_size, desc, workspace->data(), workspace_size,
c->data(), a->data(), b->data(), 1.f, 0.f, context::getStream())); c->data(), a->data(), b->data(), alpha, beta, context::getStream()));
} }
// 在加载 InfiniCore 时为全平台注册 InfiniOP实现 // 在加载 InfiniCore 时为全平台注册 InfiniOP实现
static bool registered = []() { static bool registered = []() {
Matmul::dispatcher().registerAll(&calculate, false); Gemm::dispatcher().registerAll(&calculate, false);
return true; return true;
}(); }();
} }
``` ```
......
#include "infinicore/ops/add.hpp" #include "infinicore/ops/add.hpp"
#include "../../utils.hpp"
namespace infinicore::op { namespace infinicore::op {
...@@ -8,7 +9,9 @@ common::OpDispatcher<Add::schema> &Add::dispatcher() { ...@@ -8,7 +9,9 @@ common::OpDispatcher<Add::schema> &Add::dispatcher() {
}; };
void Add::execute(Tensor c, Tensor a, Tensor b) { void Add::execute(Tensor c, Tensor a, Tensor b) {
dispatcher().lookup(context::getDevice().getType())(c, a, b); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
infinicore::context::setDevice(c->device());
dispatcher().lookup(c->device().getType())(c, a, b);
} }
Tensor add(Tensor a, Tensor b) { Tensor add(Tensor a, Tensor b) {
......
#include "infinicore/ops/attention.hpp" #include "infinicore/ops/attention.hpp"
#include "../../utils.hpp"
namespace infinicore::op { namespace infinicore::op {
...@@ -8,7 +9,9 @@ common::OpDispatcher<Attention::schema> &Attention::dispatcher() { ...@@ -8,7 +9,9 @@ common::OpDispatcher<Attention::schema> &Attention::dispatcher() {
}; };
void Attention::execute(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos) { void Attention::execute(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos) {
dispatcher().lookup(context::getDevice().getType())(out, q, k, v, k_cache, v_cache, pos); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v, k_cache, v_cache);
infinicore::context::setDevice(out->device());
dispatcher().lookup(out->device().getType())(out, q, k, v, k_cache, v_cache, pos);
} }
Tensor attention(Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos) { Tensor attention(Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos) {
......
#include "infinicore/ops/causal_softmax.hpp" #include "infinicore/ops/causal_softmax.hpp"
#include "../../utils.hpp"
#include <stdexcept> #include <stdexcept>
namespace infinicore::op { namespace infinicore::op {
...@@ -9,7 +12,9 @@ common::OpDispatcher<CausalSoftmax::schema> &CausalSoftmax::dispatcher() { ...@@ -9,7 +12,9 @@ common::OpDispatcher<CausalSoftmax::schema> &CausalSoftmax::dispatcher() {
}; };
void CausalSoftmax::execute(Tensor output, Tensor input) { void CausalSoftmax::execute(Tensor output, Tensor input) {
auto device_type = context::getDevice().getType(); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(output, input);
infinicore::context::setDevice(output->device());
auto device_type = output->device().getType();
auto func = dispatcher().lookup(device_type); auto func = dispatcher().lookup(device_type);
if (func == nullptr) { if (func == nullptr) {
......
#include "infinicore/ops/gemm.hpp" #include "infinicore/ops/gemm.hpp"
#include "../../utils.hpp"
namespace infinicore::op { namespace infinicore::op {
common::OpDispatcher<Gemm::schema> &Gemm::dispatcher() { common::OpDispatcher<Gemm::schema> &Gemm::dispatcher() {
...@@ -8,7 +10,9 @@ common::OpDispatcher<Gemm::schema> &Gemm::dispatcher() { ...@@ -8,7 +10,9 @@ common::OpDispatcher<Gemm::schema> &Gemm::dispatcher() {
}; };
void Gemm::execute(Tensor c, Tensor a, Tensor b, float alpha, float beta) { void Gemm::execute(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
dispatcher().lookup(context::getDevice().getType())(c, a, b, alpha, beta); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
infinicore::context::setDevice(c->device());
dispatcher().lookup(c->device().getType())(c, a, b, alpha, beta);
} }
Tensor gemm(Tensor a, Tensor b, float alpha, float beta) { Tensor gemm(Tensor a, Tensor b, float alpha, float beta) {
......
#include "infinicore/ops/mul.hpp" #include "infinicore/ops/mul.hpp"
#include "../../utils.hpp"
namespace infinicore::op { namespace infinicore::op {
common::OpDispatcher<Mul::schema> &Mul::dispatcher() { common::OpDispatcher<Mul::schema> &Mul::dispatcher() {
...@@ -8,7 +10,9 @@ common::OpDispatcher<Mul::schema> &Mul::dispatcher() { ...@@ -8,7 +10,9 @@ common::OpDispatcher<Mul::schema> &Mul::dispatcher() {
}; };
void Mul::execute(Tensor c, Tensor a, Tensor b) { void Mul::execute(Tensor c, Tensor a, Tensor b) {
dispatcher().lookup(context::getDevice().getType())(c, a, b); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
infinicore::context::setDevice(c->device());
dispatcher().lookup(c->device().getType())(c, a, b);
} }
Tensor mul(Tensor a, Tensor b) { Tensor mul(Tensor a, Tensor b) {
......
#include "infinicore/ops/random_sample.hpp" #include "infinicore/ops/random_sample.hpp"
#include "../../utils.hpp"
namespace infinicore::op { namespace infinicore::op {
common::OpDispatcher<RandomSample::schema> &RandomSample::dispatcher() { common::OpDispatcher<RandomSample::schema> &RandomSample::dispatcher() {
...@@ -10,7 +12,9 @@ common::OpDispatcher<RandomSample::schema> &RandomSample::dispatcher() { ...@@ -10,7 +12,9 @@ common::OpDispatcher<RandomSample::schema> &RandomSample::dispatcher() {
void RandomSample::execute( void RandomSample::execute(
Tensor indices, Tensor logits, Tensor indices, Tensor logits,
float random_val, float topp, int topk, float temperature) { float random_val, float topp, int topk, float temperature) {
dispatcher().lookup(context::getDevice().getType())( INFINICORE_ASSERT_TENSORS_SAME_DEVICE(indices, logits);
infinicore::context::setDevice(logits->device());
dispatcher().lookup(logits->device().getType())(
indices, logits, random_val, topp, topk, temperature); indices, logits, random_val, topp, topk, temperature);
} }
......
#include "infinicore/ops/rms_norm.hpp" #include "infinicore/ops/rms_norm.hpp"
#include "../../utils.hpp"
namespace infinicore::op { namespace infinicore::op {
common::OpDispatcher<RMSNorm::schema> &RMSNorm::dispatcher() { common::OpDispatcher<RMSNorm::schema> &RMSNorm::dispatcher() {
...@@ -8,7 +10,9 @@ common::OpDispatcher<RMSNorm::schema> &RMSNorm::dispatcher() { ...@@ -8,7 +10,9 @@ common::OpDispatcher<RMSNorm::schema> &RMSNorm::dispatcher() {
}; };
void RMSNorm::execute(Tensor y, Tensor x, Tensor weight, float epsilon) { void RMSNorm::execute(Tensor y, Tensor x, Tensor weight, float epsilon) {
dispatcher().lookup(context::getDevice().getType())(y, x, weight, epsilon); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, x, weight);
infinicore::context::setDevice(y->device());
dispatcher().lookup(y->device().getType())(y, x, weight, epsilon);
} }
Tensor rms_norm(Tensor x, Tensor weight, float epsilon) { Tensor rms_norm(Tensor x, Tensor weight, float epsilon) {
......
#include "infinicore/ops/rope.hpp" #include "infinicore/ops/rope.hpp"
#include "../../utils.hpp"
#include "infinicore/context/context.hpp" #include "infinicore/context/context.hpp"
#include <stdexcept> #include <stdexcept>
namespace infinicore::op { namespace infinicore::op {
...@@ -10,7 +13,9 @@ common::OpDispatcher<RoPE::schema> &RoPE::dispatcher() { ...@@ -10,7 +13,9 @@ common::OpDispatcher<RoPE::schema> &RoPE::dispatcher() {
}; };
void RoPE::execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo) { void RoPE::execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo) {
auto device_type = context::getDevice().getType(); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(x_out, x, pos, sin_table, cos_table);
infinicore::context::setDevice(x_out->device());
auto device_type = x_out->device().getType();
auto func = dispatcher().lookup(device_type); auto func = dispatcher().lookup(device_type);
if (func == nullptr) { if (func == nullptr) {
......
#include "infinicore/ops/silu.hpp" #include "infinicore/ops/silu.hpp"
#include "../../utils.hpp"
#include <stdexcept> #include <stdexcept>
namespace infinicore::op { namespace infinicore::op {
...@@ -9,7 +12,9 @@ common::OpDispatcher<Silu::schema> &Silu::dispatcher() { ...@@ -9,7 +12,9 @@ common::OpDispatcher<Silu::schema> &Silu::dispatcher() {
}; };
void Silu::execute(Tensor output, Tensor input) { void Silu::execute(Tensor output, Tensor input) {
auto device_type = context::getDevice().getType(); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(output, input);
infinicore::context::setDevice(output->device());
auto device_type = output->device().getType();
auto func = dispatcher().lookup(device_type); auto func = dispatcher().lookup(device_type);
if (func == nullptr) { if (func == nullptr) {
......
#include "infinicore/ops/swiglu.hpp" #include "infinicore/ops/swiglu.hpp"
#include "../../utils.hpp"
#include <stdexcept> #include <stdexcept>
namespace infinicore::op { namespace infinicore::op {
...@@ -9,7 +12,9 @@ common::OpDispatcher<SwiGLU::schema> &SwiGLU::dispatcher() { ...@@ -9,7 +12,9 @@ common::OpDispatcher<SwiGLU::schema> &SwiGLU::dispatcher() {
}; };
void SwiGLU::execute(Tensor c, Tensor a, Tensor b) { void SwiGLU::execute(Tensor c, Tensor a, Tensor b) {
auto device_type = context::getDevice().getType(); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
infinicore::context::setDevice(c->device());
auto device_type = c->device().getType();
auto func = dispatcher().lookup(device_type); auto func = dispatcher().lookup(device_type);
if (func == nullptr) { if (func == nullptr) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment