Commit 81e5fe94 authored by PanZezhong, committed by wooway777

issue/810 support more ops as graph op

parent 0611cb1b
@@ -15,10 +15,15 @@ public:
};
class GraphOperator {
public:
virtual void run() const = 0;
virtual ~GraphOperator() = default;
};
class DispatchableGraphOperator : public GraphOperator {
public:
void run() const;
~GraphOperator();
void run() const override;
~DispatchableGraphOperator() override;
protected:
using run_schema = void (*)(void *);
@@ -49,7 +54,7 @@ private:
} // namespace infinicore::graph
#define INFINICORE_GRAPH_OP_CLASS(__OP_NAME__, ...) \
class __OP_NAME__ : public graph::GraphOperator { \
class __OP_NAME__ : public graph::DispatchableGraphOperator { \
public: \
using schema = void (*)(__VA_ARGS__); \
using plan_schema = void *(*)(__VA_ARGS__); \
@@ -80,11 +85,11 @@ private:
deleter_ = cleanup_dispatcher().lookup(__DEVICE_TYPE__);
#define INFINICORE_GRAPH_OP_RECORD_OR_RUN(__OP_NAME__, ...) \
auto op = std::make_shared<__OP_NAME__>(__VA_ARGS__); \
auto ___op = std::make_shared<__OP_NAME__>(__VA_ARGS__); \
if (context::isGraphRecording()) { \
context::addGraphOperator(op); \
context::addGraphOperator(___op); \
} else { \
op->run(); \
___op->run(); \
}
#define INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(__OP_NAME__, __PLAN_F__, __RUN_F__, __CLEANUP_F__) \
......
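Taken together, these macros give every op a single entry point that either executes eagerly or is captured for later replay. Hand-expanding INFINICORE_GRAPH_OP_RECORD_OR_RUN for a hypothetical Foo op (Foo is illustrative; the control flow below is exactly the macro body shown above):

```cpp
// INFINICORE_GRAPH_OP_RECORD_OR_RUN(Foo, out, in) expands to roughly:
auto ___op = std::make_shared<Foo>(out, in); // ctor plans the op for the current device
if (context::isGraphRecording()) {
    context::addGraphOperator(___op); // defer: the graph owns the op and replays it later
} else {
    ___op->run(); // eager: execute now; the op (and its planned meta) dies at scope exit
}
```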
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
namespace infinicore::op {
class Add {
public:
using schema = void (*)(Tensor, Tensor, Tensor);
static void execute(Tensor c, Tensor a, Tensor b);
static common::OpDispatcher<schema> &dispatcher();
};
Tensor add(Tensor a, Tensor b);
void add_(Tensor c, Tensor a, Tensor b);
Tensor operator+(Tensor a, Tensor b);
INFINICORE_GRAPH_OP_CLASS(Add, Tensor, const Tensor &, const Tensor &);
Tensor add(const Tensor &a, const Tensor &b);
void add_(Tensor c, const Tensor &a, const Tensor &b);
} // namespace infinicore::op
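For reference, a hand-expansion of INFINICORE_GRAPH_OP_CLASS(Add, Tensor, const Tensor &, const Tensor &). The schema and plan_schema typedefs come from the macro fragment above; the constructor and execute declarations are inferred from add.cpp later in this commit, and the dispatcher accessors are assumed since the macro body is truncated here:

```cpp
// Approximate expansion (sketch, not the exact generated code):
class Add : public graph::DispatchableGraphOperator {
public:
    using schema = void (*)(Tensor, const Tensor &, const Tensor &);
    using plan_schema = void *(*)(Tensor, const Tensor &, const Tensor &);
    Add(Tensor c, const Tensor &a, const Tensor &b);                 // plans for c's device
    static void execute(Tensor c, const Tensor &a, const Tensor &b); // records or runs
    // ...plan/run/cleanup dispatcher accessors from the truncated macro body
};
```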
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
namespace infinicore::op {
class CausalSoftmax {
public:
using schema = void (*)(Tensor, Tensor);
static void execute(Tensor output, Tensor input);
static common::OpDispatcher<schema> &dispatcher();
};
Tensor causal_softmax(Tensor input);
void causal_softmax_(Tensor output, Tensor input);
INFINICORE_GRAPH_OP_CLASS(CausalSoftmax, Tensor, const Tensor &);
Tensor causal_softmax(const Tensor &input);
void causal_softmax_(Tensor output, const Tensor &input);
} // namespace infinicore::op
#pragma once
#include "../../device.hpp"
#include "../../graph/graph.hpp"
#include "../common/op.hpp"
#include <infiniccl.h>
namespace infinicore::op::distributed {
class AllReduce : public graph::GraphOperator {
public:
AllReduce(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator);
~AllReduce();
void run() const override;
static void execute(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator);
private:
void *planned_meta_;
};
Tensor allreduce(const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator);
void allreduce_(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator);
} // namespace infinicore::op::distributed
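Note that AllReduce derives from graph::GraphOperator directly rather than from DispatchableGraphOperator: infiniccl already hides the device differences, so there is no per-device plan/run/cleanup table to dispatch over. A minimal sketch of this hand-written style (MyOp and its body are illustrative):

```cpp
// Sketch: a graph op with a custom run(), in the style of AllReduce above.
class MyOp : public graph::GraphOperator {
public:
    explicit MyOp(const Tensor &t) : t_(t) {}
    void run() const override {
        // launch work on t_->data() via a device-agnostic library,
        // as AllReduce does with infinicclAllReduce
    }
private:
    graph::GraphTensor t_; // pins the underlying blob for the graph's lifetime
};
```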
@@ -6,9 +6,9 @@
namespace infinicore::op {
INFINICORE_GRAPH_OP_CLASS(Gemm, Tensor, Tensor, Tensor, float, float);
INFINICORE_GRAPH_OP_CLASS(Gemm, Tensor, const Tensor &, const Tensor &, float, float);
Tensor gemm(Tensor a, Tensor b, float alpha = 1.0f, float beta = 0.0f);
void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta);
Tensor gemm(const Tensor &a, const Tensor &b, float alpha = 1.0f, float beta = 0.0f);
void gemm_(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta);
} // namespace infinicore::op
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
namespace infinicore::op {
class Mul {
public:
using schema = void (*)(Tensor, Tensor, Tensor);
static void execute(Tensor c, Tensor a, Tensor b);
static common::OpDispatcher<schema> &dispatcher();
};
Tensor mul(Tensor a, Tensor b);
void mul_(Tensor c, Tensor a, Tensor b);
INFINICORE_GRAPH_OP_CLASS(Mul, Tensor, const Tensor &, const Tensor &);
Tensor mul(const Tensor &a, const Tensor &b);
void mul_(Tensor c, const Tensor &a, const Tensor &b);
} // namespace infinicore::op
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
#include <optional>
namespace infinicore::op {
class PagedAttention {
public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);
static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float);
static common::OpDispatcher<schema> &dispatcher();
};
INFINICORE_GRAPH_OP_CLASS(PagedAttention, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, std::optional<Tensor>, float);
Tensor paged_attention(const Tensor &q, const Tensor &k_cache, const Tensor &v_cache,
const Tensor &block_tables, const Tensor &kv_lens,
std::optional<Tensor> alibi_slopes, float scale);
void paged_attention_(Tensor out, const Tensor &q, const Tensor &k_cache, const Tensor &v_cache,
const Tensor &block_tables, const Tensor &kv_lens,
std::optional<Tensor> alibi_slopes, float scale);
Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale);
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale);
} // namespace infinicore::op
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
namespace infinicore::op {
class PagedCaching {
public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor);
static void execute(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping);
static common::OpDispatcher<schema> &dispatcher();
};
INFINICORE_GRAPH_OP_CLASS(PagedCaching, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &);
void paged_caching_(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping);
void paged_caching_(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &slot_mapping);
} // namespace infinicore::op
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
namespace infinicore::op {
class Rearrange {
public:
using schema = void (*)(Tensor, Tensor);
static void execute(Tensor y, Tensor x);
static common::OpDispatcher<schema> &dispatcher();
};
Tensor rearrange(Tensor x);
void rearrange_(Tensor y, Tensor x);
INFINICORE_GRAPH_OP_CLASS(Rearrange, Tensor, const Tensor &);
Tensor rearrange(const Tensor &x);
void rearrange_(Tensor y, const Tensor &x);
} // namespace infinicore::op
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
namespace infinicore::op {
class RMSNorm {
public:
using schema = void (*)(Tensor, Tensor, Tensor, float);
static void execute(Tensor y, Tensor x, Tensor weight, float epsilon = 1e-5f);
static common::OpDispatcher<schema> &dispatcher();
};
Tensor rms_norm(Tensor x, Tensor weight, float epsilon = 1e-5f);
void rms_norm_(Tensor y, Tensor x, Tensor weight, float epsilon = 1e-5f);
INFINICORE_GRAPH_OP_CLASS(RMSNorm, Tensor, const Tensor &, const Tensor &, float);
Tensor rms_norm(const Tensor &x, const Tensor &weight, float epsilon = 1e-5f);
void rms_norm_(Tensor y, const Tensor &x, const Tensor &weight, float epsilon = 1e-5f);
} // namespace infinicore::op
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "../nn/rope.hpp"
#include "../tensor.hpp"
#include "common/op.hpp"
namespace infinicore::op {
class RoPE {
public:
using schema = void (*)(Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, infinicore::nn::RoPE::Algo);
static void execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo);
static common::OpDispatcher<schema> &dispatcher();
};
// Internal function
void rope_(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo);
INFINICORE_GRAPH_OP_CLASS(RoPE, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, infinicore::nn::RoPE::Algo);
// Internal
void rope_(Tensor x_out,
const Tensor &x,
const Tensor &pos,
const Tensor &sin_table,
const Tensor &cos_table,
infinicore::nn::RoPE::Algo algo);
// Public API
Tensor rope(const Tensor &x,
const Tensor &pos,
const Tensor &sin_table,
const Tensor &cos_table,
infinicore::nn::RoPE::Algo algo);
// Public API that uses infinicore::nn::RoPE::Algo
Tensor rope(const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo);
} // namespace infinicore::op
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "../tensor.hpp"
#include "common/op.hpp"
namespace infinicore::op {
class SwiGLU {
public:
using schema = void (*)(Tensor, Tensor, Tensor);
static void execute(Tensor c, Tensor a, Tensor b);
static common::OpDispatcher<schema> &dispatcher();
};
Tensor swiglu(Tensor a, Tensor b);
void swiglu_(Tensor c, Tensor a, Tensor b);
INFINICORE_GRAPH_OP_CLASS(SwiGLU, Tensor, const Tensor &, const Tensor &);
Tensor swiglu(const Tensor &a, const Tensor &b);
void swiglu_(Tensor c, const Tensor &a, const Tensor &b);
} // namespace infinicore::op
@@ -17,11 +17,11 @@ GraphTensor::GraphTensor(const Tensor &tensor) : Tensor(tensor->to_blob_()) {
* GraphOperator
* ========================= */
void GraphOperator::run() const {
void DispatchableGraphOperator::run() const {
runner_(planned_meta_);
}
GraphOperator::~GraphOperator() {
DispatchableGraphOperator::~DispatchableGraphOperator() {
if (deleter_) {
deleter_(&planned_meta_);
}
......
#include "infinicore/nn/linear.hpp"
#include "../utils.hpp"
#include "infinicore/ops.hpp"
#include "infinicore/ops/distributed/allreduce.hpp"
#include "infinicore/ops/linear.hpp"
#include <optional>
#include <spdlog/spdlog.h>
@@ -102,9 +103,6 @@ ColumnParallelLinear::ColumnParallelLinear(size_t in_features, size_t out_featur
} else {
bias_ = Parameter(); // Default constructed empty parameter
}
// SPDLOG_DEBUG("Created ColumnParallelLinear module: in_features={}, out_features={}, bias={}, dtype={}",
// in_features, out_features, bias, static_cast<int>(dtype_));
}
Tensor ColumnParallelLinear::forward(Tensor &input) const {
@@ -138,26 +136,13 @@ RowParallelLinear::RowParallelLinear(size_t in_features, size_t out_features, bo
} else {
bias_ = Parameter(); // Default constructed empty parameter
}
// SPDLOG_DEBUG("Created RowParallelLinear module: in_features={}, out_features={}, bias={}, dtype={}",
// in_features, out_features, bias, static_cast<int>(dtype_));
}
Tensor RowParallelLinear::forward(Tensor &input) const {
auto output = BaseLinear::forward(input);
if ((tp_size_ > 1) && (communicator_ != nullptr)) {
Size count = output->numel();
DataType type = output->dtype();
infinirtStream_t stream = infinicore::context::getStream();
INFINICORE_CHECK_ERROR(infinicclAllReduce(output->data(), output->data(), count, static_cast<infiniDtype_t>(static_cast<int>(type)),
INFINICCL_SUM, communicator_, stream));
INFINICORE_CHECK_ERROR(infinirtStreamSynchronize(stream));
// RUN_INFINI(infinirtStreamSynchronize(stream));
op::distributed::allreduce_(output, output, INFINICCL_SUM, communicator_);
}
return output;
}
......
@@ -3,24 +3,24 @@
namespace infinicore::op {
common::OpDispatcher<Add::schema> &Add::dispatcher() {
static common::OpDispatcher<Add::schema> dispatcher_;
return dispatcher_;
};
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Add);
void Add::execute(Tensor c, Tensor a, Tensor b) {
Add::Add(Tensor c, const Tensor &a, const Tensor &b) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
infinicore::context::setDevice(c->device());
dispatcher().lookup(c->device().getType())(c, a, b);
INFINICORE_GRAPH_OP_DISPATCH(c->device().getType(), c, a, b);
}
Tensor add(Tensor a, Tensor b) {
void Add::execute(Tensor c, const Tensor &a, const Tensor &b) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(Add, c, a, b);
}
Tensor add(const Tensor &a, const Tensor &b) {
auto c = Tensor::empty(a->shape(), a->dtype(), a->device());
add_(c, a, b);
return c;
}
void add_(Tensor c, Tensor a, Tensor b) {
void add_(Tensor c, const Tensor &a, const Tensor &b) {
Add::execute(c, a, b);
}
......
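The constructor's INFINICORE_GRAPH_OP_DISPATCH is not shown in full in this diff. A plausible reconstruction from the visible pieces, where only the deleter_ assignment appears verbatim above and the other two lookups are assumptions by analogy with the run_schema and plan_schema typedefs:

```cpp
// Illustrative reconstruction, not the repository's exact macro:
#define INFINICORE_GRAPH_OP_DISPATCH(__DEVICE_TYPE__, ...)                  \
    planned_meta_ = plan_dispatcher().lookup(__DEVICE_TYPE__)(__VA_ARGS__); \
    runner_ = run_dispatcher().lookup(__DEVICE_TYPE__);                     \
    deleter_ = cleanup_dispatcher().lookup(__DEVICE_TYPE__);
```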
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/add.hpp"
#include "infinicore/ops/common/cache.hpp"
#include <infiniop.h>
#include "../infiniop_impl.hpp"
namespace infinicore::op::add_impl::infiniop {
thread_local common::OpCache<size_t, infiniopAddDescriptor_t> caches(
100, // capacity
[](infiniopAddDescriptor_t &desc) {
if (desc != nullptr) {
INFINICORE_CHECK_ERROR(infiniopDestroyAddDescriptor(desc));
desc = nullptr;
}
});
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Add, 100);
struct PlannedMeta {
std::shared_ptr<Descriptor> descriptor;
graph::GraphTensor workspace, c, a, b;
};
void calculate(Tensor c, Tensor a, Tensor b) {
void *plan(Tensor c, const Tensor &a, const Tensor &b) {
size_t seed = hash_combine(c, b, a);
auto device = context::getDevice();
auto &cache = caches.getCache(device);
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
Descriptor, descriptor, Add,
seed,
c->desc(), a->desc(), b->desc());
auto desc_opt = cache.get(seed);
infiniopAddDescriptor_t desc = nullptr;
INFINIOP_WORKSPACE_TENSOR(workspace, Add, descriptor);
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateAddDescriptor(
context::getInfiniopHandle(device), &desc,
c->desc(), a->desc(), b->desc()));
cache.put(seed, desc);
} else {
desc = *desc_opt;
}
return new PlannedMeta{
descriptor,
graph::GraphTensor(workspace),
graph::GraphTensor(c),
graph::GraphTensor(a),
graph::GraphTensor(b)};
}
size_t workspace_size = 0;
INFINICORE_CHECK_ERROR(infiniopGetAddWorkspaceSize(desc, &workspace_size));
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
void run(void *planned_meta) {
auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
INFINICORE_CHECK_ERROR(infiniopAdd(
desc, workspace->data(), workspace_size,
c->data(), a->data(), b->data(), context::getStream()));
planned->descriptor->desc,
planned->workspace->data(),
planned->workspace->numel(),
planned->c->data(),
planned->a->data(),
planned->b->data(),
context::getStream()));
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
}
static bool registered = []() {
Add::dispatcher().registerAll(&calculate, false);
return true;
}();
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Add, &plan, &run, &cleanup);
} // namespace infinicore::op::add_impl::infiniop
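These three functions are exactly what DispatchableGraphOperator stores behind planned_meta_, runner_, and deleter_. A sketch of the lifecycle driven by hand (illustrative; in the real code the constructor, run(), and destructor do this):

```cpp
// Sketch: the plan/run/cleanup lifecycle for Add (replay-safe because
// PlannedMeta's GraphTensors pin the descriptor's buffers).
void replay_add_twice(Tensor c, Tensor a, Tensor b) {
    using namespace infinicore::op::add_impl::infiniop;
    void *meta = plan(c, a, b); // builds or reuses the cached descriptor
    run(meta);                  // first launch
    run(meta);                  // replay: same descriptor, same pinned buffers
    cleanup(&meta);             // deletes the PlannedMeta and nulls the pointer
}
```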
#include "infinicore/ops/causal_softmax.hpp"
#include "../../utils.hpp"
#include <stdexcept>
namespace infinicore::op {
common::OpDispatcher<CausalSoftmax::schema> &CausalSoftmax::dispatcher() {
static common::OpDispatcher<CausalSoftmax::schema> dispatcher_;
return dispatcher_;
};
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(CausalSoftmax);
void CausalSoftmax::execute(Tensor output, Tensor input) {
CausalSoftmax::CausalSoftmax(Tensor output, const Tensor &input) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(output, input);
infinicore::context::setDevice(output->device());
auto device_type = output->device().getType();
auto func = dispatcher().lookup(device_type);
if (func == nullptr) {
throw std::runtime_error("No CausalSoftmax implementation found for device type: " + std::to_string(static_cast<int>(device_type)));
}
INFINICORE_GRAPH_OP_DISPATCH(output->device().getType(), output, input);
}
func(output, input);
void CausalSoftmax::execute(Tensor output, const Tensor &input) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(CausalSoftmax, output, input);
}
Tensor causal_softmax(Tensor input) {
Shape shape = input->shape();
auto output = Tensor::empty(shape, input->dtype(), input->device());
Tensor causal_softmax(const Tensor &input) {
auto output = Tensor::empty(input->shape(), input->dtype(), input->device());
causal_softmax_(output, input);
return output;
}
void causal_softmax_(Tensor output, Tensor input) {
void causal_softmax_(Tensor output, const Tensor &input) {
CausalSoftmax::execute(output, input);
}
} // namespace infinicore::op
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/causal_softmax.hpp"
#include "infinicore/ops/common/cache.hpp"
#include <infiniop.h>
#include "../infiniop_impl.hpp"
namespace infinicore::op::causal_softmax_impl::infiniop {
thread_local common::OpCache<size_t, infiniopCausalSoftmaxDescriptor_t> caches(
100, // capacity
[](infiniopCausalSoftmaxDescriptor_t &desc) {
if (desc != nullptr) {
INFINICORE_CHECK_ERROR(infiniopDestroyCausalSoftmaxDescriptor(desc));
desc = nullptr;
}
});
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, CausalSoftmax, 100);
struct PlannedMeta {
std::shared_ptr<Descriptor> descriptor;
graph::GraphTensor workspace, output, input;
};
void calculate(Tensor output, Tensor input) {
void *plan(Tensor output, const Tensor &input) {
size_t seed = hash_combine(output, input);
auto device = context::getDevice();
auto &cache = caches.getCache(device);
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
Descriptor, descriptor, CausalSoftmax,
seed, output->desc(), input->desc());
auto desc_opt = cache.get(seed);
infiniopCausalSoftmaxDescriptor_t desc = nullptr;
INFINIOP_WORKSPACE_TENSOR(workspace, CausalSoftmax, descriptor);
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateCausalSoftmaxDescriptor(
context::getInfiniopHandle(device), &desc,
output->desc(), input->desc()));
cache.put(seed, desc);
} else {
desc = *desc_opt;
}
return new PlannedMeta{
descriptor,
graph::GraphTensor(workspace),
graph::GraphTensor(output),
graph::GraphTensor(input)};
}
size_t workspace_size = 0;
INFINICORE_CHECK_ERROR(infiniopGetCausalSoftmaxWorkspaceSize(desc, &workspace_size));
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
void run(void *planned_meta) {
auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
INFINICORE_CHECK_ERROR(infiniopCausalSoftmax(
desc, workspace->data(), workspace_size,
output->data(), input->data(), context::getStream()));
planned->descriptor->desc,
planned->workspace->data(),
planned->workspace->numel(),
planned->output->data(),
planned->input->data(),
context::getStream()));
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
}
static bool registered = []() {
CausalSoftmax::dispatcher().registerAll(&calculate, false);
return true;
}();
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(CausalSoftmax, &plan, &run, &cleanup);
} // namespace infinicore::op::causal_softmax_impl::infiniop
#include "infinicore/ops/distributed/allreduce.hpp"
#include "../../utils.hpp"
namespace infinicore::op::distributed {
struct PlannedMeta {
graph::GraphTensor output, input;
infinicclReduceOp_t op;
infinicclComm_t communicator;
};
AllReduce::AllReduce(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(output, input);
INFINICORE_ASSERT(output->is_contiguous() && input->is_contiguous());
INFINICORE_ASSERT(output->numel() == input->numel());
planned_meta_ = new PlannedMeta{graph::GraphTensor(output), graph::GraphTensor(input), op, communicator};
}
AllReduce::~AllReduce() {
if (planned_meta_) {
PlannedMeta *meta = reinterpret_cast<PlannedMeta *>(planned_meta_);
delete meta;
}
}
void AllReduce::run() const {
PlannedMeta *meta = reinterpret_cast<PlannedMeta *>(planned_meta_);
INFINICORE_CHECK_ERROR(infinicclAllReduce(meta->input->data(),
meta->output->data(),
meta->input->numel(),
static_cast<infiniDtype_t>(static_cast<int>(meta->input->dtype())),
meta->op,
meta->communicator,
infinicore::context::getStream()));
}
void AllReduce::execute(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(AllReduce, output, input, op, communicator);
}
Tensor allreduce(const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator) {
auto output = Tensor::empty(input->shape(), input->dtype(), input->device());
allreduce_(output, input, op, communicator);
return output;
}
void allreduce_(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator) {
AllReduce::execute(output, input, op, communicator);
}
} // namespace infinicore::op::distributed
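With AllReduce wrapped as a graph op, tensor-parallel code calls it like any other op, and the call is captured during graph recording just as Add or Gemm are. A usage sketch (communicator setup elided; sum_across_ranks is illustrative):

```cpp
#include "infinicore/ops/distributed/allreduce.hpp"

// Sketch: in-place sum across ranks, mirroring RowParallelLinear::forward.
void sum_across_ranks(Tensor t, infinicclComm_t comm) {
    using namespace infinicore::op::distributed;
    allreduce_(t, t, INFINICCL_SUM, comm); // in-place, output == input
    // Out-of-place variant allocating a fresh output tensor:
    // Tensor summed = allreduce(t, INFINICCL_SUM, comm);
}
```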
@@ -5,16 +5,16 @@
namespace infinicore::op {
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Gemm);
Gemm::Gemm(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
Gemm::Gemm(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
INFINICORE_GRAPH_OP_DISPATCH(c->device().getType(), c, a, b, alpha, beta);
}
void Gemm::execute(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
void Gemm::execute(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(Gemm, c, a, b, alpha, beta);
}
Tensor gemm(Tensor a, Tensor b, float alpha, float beta) {
Tensor gemm(const Tensor &a, const Tensor &b, float alpha, float beta) {
Shape shape = a->shape();
Size size = a->ndim();
shape[size - 1] = b->size(size - 1);
@@ -23,7 +23,7 @@ Tensor gemm(Tensor a, Tensor b, float alpha, float beta) {
return c;
}
void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
void gemm_(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta) {
Gemm::execute(c, a, b, alpha, beta);
}
......