Commit 81e5fe94 authored by PanZezhong's avatar PanZezhong Committed by wooway777
Browse files

issue/810: support more ops as graph ops

parent 0611cb1b
...@@ -15,10 +15,15 @@ public: ...@@ -15,10 +15,15 @@ public:
}; };
class GraphOperator { class GraphOperator {
public:
virtual void run() const = 0;
virtual ~GraphOperator() = default;
};
class DispatchableGraphOperator : public GraphOperator {
public: public:
void run() const; void run() const override;
~GraphOperator(); ~DispatchableGraphOperator() override;
protected: protected:
using run_schema = void (*)(void *); using run_schema = void (*)(void *);
...@@ -49,7 +54,7 @@ private: ...@@ -49,7 +54,7 @@ private:
} // namespace infinicore::graph } // namespace infinicore::graph
#define INFINICORE_GRAPH_OP_CLASS(__OP_NAME__, ...) \ #define INFINICORE_GRAPH_OP_CLASS(__OP_NAME__, ...) \
class __OP_NAME__ : public graph::GraphOperator { \ class __OP_NAME__ : public graph::DispatchableGraphOperator { \
public: \ public: \
using schema = void (*)(__VA_ARGS__); \ using schema = void (*)(__VA_ARGS__); \
using plan_schema = void *(*)(__VA_ARGS__); \ using plan_schema = void *(*)(__VA_ARGS__); \
...@@ -79,12 +84,12 @@ private: ...@@ -79,12 +84,12 @@ private:
runner_ = run_dispatcher().lookup(__DEVICE_TYPE__); \ runner_ = run_dispatcher().lookup(__DEVICE_TYPE__); \
deleter_ = cleanup_dispatcher().lookup(__DEVICE_TYPE__); deleter_ = cleanup_dispatcher().lookup(__DEVICE_TYPE__);
#define INFINICORE_GRAPH_OP_RECORD_OR_RUN(__OP_NAME__, ...) \ #define INFINICORE_GRAPH_OP_RECORD_OR_RUN(__OP_NAME__, ...) \
auto op = std::make_shared<__OP_NAME__>(__VA_ARGS__); \ auto ___op = std::make_shared<__OP_NAME__>(__VA_ARGS__); \
if (context::isGraphRecording()) { \ if (context::isGraphRecording()) { \
context::addGraphOperator(op); \ context::addGraphOperator(___op); \
} else { \ } else { \
op->run(); \ ___op->run(); \
} }
#define INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(__OP_NAME__, __PLAN_F__, __RUN_F__, __CLEANUP_F__) \ #define INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(__OP_NAME__, __PLAN_F__, __RUN_F__, __CLEANUP_F__) \
......
#pragma once #pragma once
#include "../device.hpp" #include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp" #include "common/op.hpp"
namespace infinicore::op { namespace infinicore::op {
class Add {
public:
using schema = void (*)(Tensor, Tensor, Tensor);
static void execute(Tensor c, Tensor a, Tensor b);
static common::OpDispatcher<schema> &dispatcher();
};
Tensor add(Tensor a, Tensor b); INFINICORE_GRAPH_OP_CLASS(Add, Tensor, const Tensor &, const Tensor &);
void add_(Tensor c, Tensor a, Tensor b);
Tensor operator+(Tensor a, Tensor b); Tensor add(const Tensor &a, const Tensor &b);
void add_(Tensor c, const Tensor &a, const Tensor &b);
} // namespace infinicore::op } // namespace infinicore::op
#pragma once #pragma once
#include "../device.hpp" #include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp" #include "common/op.hpp"
namespace infinicore::op { namespace infinicore::op {
class CausalSoftmax {
public:
using schema = void (*)(Tensor, Tensor);
static void execute(Tensor output, Tensor input);
static common::OpDispatcher<schema> &dispatcher();
};
Tensor causal_softmax(Tensor input); INFINICORE_GRAPH_OP_CLASS(CausalSoftmax, Tensor, const Tensor &);
void causal_softmax_(Tensor output, Tensor input);
Tensor causal_softmax(const Tensor &input);
void causal_softmax_(Tensor output, const Tensor &input);
} // namespace infinicore::op } // namespace infinicore::op
#pragma once
#include "../../device.hpp"
#include "../../graph/graph.hpp"
#include "../common/op.hpp"
#include <infiniccl.h>
namespace infinicore::op::distributed {
class AllReduce : public graph::GraphOperator {
public:
AllReduce(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator);
~AllReduce();
void run() const override;
static void execute(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator);
private:
void *planned_meta_;
};
/// Allocates a result tensor shaped like `input` and allreduces into it.
Tensor allreduce(const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator);
/// Variant writing into a caller-provided `output` tensor.
void allreduce_(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator);
} // namespace infinicore::op::distributed
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
namespace infinicore::op { namespace infinicore::op {
INFINICORE_GRAPH_OP_CLASS(Gemm, Tensor, Tensor, Tensor, float, float); INFINICORE_GRAPH_OP_CLASS(Gemm, Tensor, const Tensor &, const Tensor &, float, float);
Tensor gemm(Tensor a, Tensor b, float alpha = 1.0f, float beta = 0.0f); Tensor gemm(const Tensor &a, const Tensor &b, float alpha = 1.0f, float beta = 0.0f);
void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta); void gemm_(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta);
} // namespace infinicore::op } // namespace infinicore::op
#pragma once #pragma once
#include "../device.hpp" #include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp" #include "common/op.hpp"
namespace infinicore::op { namespace infinicore::op {
class Mul {
public:
using schema = void (*)(Tensor, Tensor, Tensor);
static void execute(Tensor c, Tensor a, Tensor b);
static common::OpDispatcher<schema> &dispatcher();
};
Tensor mul(Tensor a, Tensor b); INFINICORE_GRAPH_OP_CLASS(Mul, Tensor, const Tensor &, const Tensor &);
void mul_(Tensor c, Tensor a, Tensor b);
Tensor mul(const Tensor &a, const Tensor &b);
void mul_(Tensor c, const Tensor &a, const Tensor &b);
} // namespace infinicore::op } // namespace infinicore::op
#pragma once #pragma once
#include "../device.hpp" #include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp" #include "common/op.hpp"
#include <optional> #include <optional>
namespace infinicore::op { namespace infinicore::op {
class PagedAttention { INFINICORE_GRAPH_OP_CLASS(PagedAttention, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, std::optional<Tensor>, float);
public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float); Tensor paged_attention(const Tensor &q, const Tensor &k_cache, const Tensor &v_cache,
static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float); const Tensor &block_tables, const Tensor &kv_lens,
static common::OpDispatcher<schema> &dispatcher(); std::optional<Tensor> alibi_slopes, float scale);
};
void paged_attention_(Tensor out, const Tensor &q, const Tensor &k_cache, const Tensor &v_cache,
const Tensor &block_tables, const Tensor &kv_lens,
std::optional<Tensor> alibi_slopes, float scale);
Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale);
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale);
} // namespace infinicore::op } // namespace infinicore::op
#pragma once #pragma once
#include "../device.hpp" #include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp" #include "common/op.hpp"
namespace infinicore::op { namespace infinicore::op {
class PagedCaching { INFINICORE_GRAPH_OP_CLASS(PagedCaching, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &);
public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor);
static void execute(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping);
static common::OpDispatcher<schema> &dispatcher();
};
void paged_caching_(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping); void paged_caching_(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &slot_mapping);
} // namespace infinicore::op } // namespace infinicore::op
#pragma once #pragma once
#include "../device.hpp" #include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp" #include "common/op.hpp"
namespace infinicore::op { namespace infinicore::op {
class Rearrange {
public:
using schema = void (*)(Tensor, Tensor);
static void execute(Tensor y, Tensor x);
static common::OpDispatcher<schema> &dispatcher();
};
Tensor rearrange(Tensor x); INFINICORE_GRAPH_OP_CLASS(Rearrange, Tensor, const Tensor &);
void rearrange_(Tensor y, Tensor x);
Tensor rearrange(const Tensor &x);
void rearrange_(Tensor y, const Tensor &x);
} // namespace infinicore::op } // namespace infinicore::op
#pragma once #pragma once
#include "../device.hpp" #include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp" #include "common/op.hpp"
namespace infinicore::op { namespace infinicore::op {
class RMSNorm {
public:
using schema = void (*)(Tensor, Tensor, Tensor, float);
static void execute(Tensor y, Tensor x, Tensor weight, float epsilon = 1e-5f);
static common::OpDispatcher<schema> &dispatcher();
};
Tensor rms_norm(Tensor x, Tensor weight, float epsilon = 1e-5f); INFINICORE_GRAPH_OP_CLASS(RMSNorm, Tensor, const Tensor &, const Tensor &, float);
void rms_norm_(Tensor y, Tensor x, Tensor weight, float epsilon = 1e-5f);
Tensor rms_norm(const Tensor &x, const Tensor &weight, float epsilon = 1e-5f);
void rms_norm_(Tensor y, const Tensor &x, const Tensor &weight, float epsilon = 1e-5f);
} // namespace infinicore::op } // namespace infinicore::op
#pragma once #pragma once
#include "../device.hpp" #include "../device.hpp"
#include "../graph/graph.hpp"
#include "../nn/rope.hpp" #include "../nn/rope.hpp"
#include "../tensor.hpp" #include "../tensor.hpp"
#include "common/op.hpp" #include "common/op.hpp"
namespace infinicore::op { namespace infinicore::op {
class RoPE {
public:
using schema = void (*)(Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, infinicore::nn::RoPE::Algo);
static void execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo);
static common::OpDispatcher<schema> &dispatcher();
};
// Internal function INFINICORE_GRAPH_OP_CLASS(RoPE, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, infinicore::nn::RoPE::Algo);
void rope_(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo);
// Internal
void rope_(Tensor x_out,
const Tensor &x,
const Tensor &pos,
const Tensor &sin_table,
const Tensor &cos_table,
infinicore::nn::RoPE::Algo algo);
// Public API
Tensor rope(const Tensor &x,
const Tensor &pos,
const Tensor &sin_table,
const Tensor &cos_table,
infinicore::nn::RoPE::Algo algo);
// Public API that uses infinicore::nn::RoPE::Algo
Tensor rope(const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo);
} // namespace infinicore::op } // namespace infinicore::op
#pragma once #pragma once
#include "../device.hpp" #include "../device.hpp"
#include "../graph/graph.hpp"
#include "../tensor.hpp"
#include "common/op.hpp" #include "common/op.hpp"
namespace infinicore::op { namespace infinicore::op {
class SwiGLU {
public:
using schema = void (*)(Tensor, Tensor, Tensor);
static void execute(Tensor c, Tensor a, Tensor b);
static common::OpDispatcher<schema> &dispatcher();
};
Tensor swiglu(Tensor a, Tensor b); INFINICORE_GRAPH_OP_CLASS(SwiGLU, Tensor, const Tensor &, const Tensor &);
void swiglu_(Tensor c, Tensor a, Tensor b);
Tensor swiglu(const Tensor &a, const Tensor &b);
void swiglu_(Tensor c, const Tensor &a, const Tensor &b);
} // namespace infinicore::op } // namespace infinicore::op
...@@ -17,11 +17,11 @@ GraphTensor::GraphTensor(const Tensor &tensor) : Tensor(tensor->to_blob_()) { ...@@ -17,11 +17,11 @@ GraphTensor::GraphTensor(const Tensor &tensor) : Tensor(tensor->to_blob_()) {
* GraphOperator * GraphOperator
* ========================= */ * ========================= */
void GraphOperator::run() const { void DispatchableGraphOperator::run() const {
runner_(planned_meta_); runner_(planned_meta_);
} }
GraphOperator::~GraphOperator() { DispatchableGraphOperator::~DispatchableGraphOperator() {
if (deleter_) { if (deleter_) {
deleter_(&planned_meta_); deleter_(&planned_meta_);
} }
......
#include "infinicore/nn/linear.hpp" #include "infinicore/nn/linear.hpp"
#include "../utils.hpp" #include "../utils.hpp"
#include "infinicore/ops.hpp" #include "infinicore/ops.hpp"
#include "infinicore/ops/distributed/allreduce.hpp"
#include "infinicore/ops/linear.hpp" #include "infinicore/ops/linear.hpp"
#include <optional> #include <optional>
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
...@@ -102,9 +103,6 @@ ColumnParallelLinear::ColumnParallelLinear(size_t in_features, size_t out_featur ...@@ -102,9 +103,6 @@ ColumnParallelLinear::ColumnParallelLinear(size_t in_features, size_t out_featur
} else { } else {
bias_ = Parameter(); // Default constructed empty parameter bias_ = Parameter(); // Default constructed empty parameter
} }
// SPDLOG_DEBUG("Created ColumnParallelLinear module: in_features={}, out_features={}, bias={}, dtype={}",
// in_features, out_features, bias, static_cast<int>(dtype_));
} }
Tensor ColumnParallelLinear::forward(Tensor &input) const { Tensor ColumnParallelLinear::forward(Tensor &input) const {
...@@ -138,26 +136,13 @@ RowParallelLinear::RowParallelLinear(size_t in_features, size_t out_features, bo ...@@ -138,26 +136,13 @@ RowParallelLinear::RowParallelLinear(size_t in_features, size_t out_features, bo
} else { } else {
bias_ = Parameter(); // Default constructed empty parameter bias_ = Parameter(); // Default constructed empty parameter
} }
// SPDLOG_DEBUG("Created RowParallelLinear module: in_features={}, out_features={}, bias={}, dtype={}",
// in_features, out_features, bias, static_cast<int>(dtype_));
} }
Tensor RowParallelLinear::forward(Tensor &input) const { Tensor RowParallelLinear::forward(Tensor &input) const {
auto output = BaseLinear::forward(input); auto output = BaseLinear::forward(input);
if ((tp_size_ > 1) && (communicator_ != nullptr)) { if ((tp_size_ > 1) && (communicator_ != nullptr)) {
op::distributed::allreduce_(output, output, INFINICCL_SUM, communicator_);
Size count = output->numel();
DataType type = output->dtype();
infinirtStream_t stream = infinicore::context::getStream();
INFINICORE_CHECK_ERROR(infinicclAllReduce(output->data(), output->data(), count, static_cast<infiniDtype_t>(static_cast<int>(type)),
INFINICCL_SUM, communicator_, stream));
INFINICORE_CHECK_ERROR(infinirtStreamSynchronize(stream));
// RUN_INFINI(infinirtStreamSynchronize(stream));
} }
return output; return output;
} }
......
...@@ -3,24 +3,24 @@ ...@@ -3,24 +3,24 @@
namespace infinicore::op { namespace infinicore::op {
common::OpDispatcher<Add::schema> &Add::dispatcher() { INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Add);
static common::OpDispatcher<Add::schema> dispatcher_;
return dispatcher_;
};
void Add::execute(Tensor c, Tensor a, Tensor b) { Add::Add(Tensor c, const Tensor &a, const Tensor &b) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
infinicore::context::setDevice(c->device()); INFINICORE_GRAPH_OP_DISPATCH(c->device().getType(), c, a, b);
dispatcher().lookup(c->device().getType())(c, a, b);
} }
Tensor add(Tensor a, Tensor b) { void Add::execute(Tensor c, const Tensor &a, const Tensor &b) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(Add, c, a, b);
}
Tensor add(const Tensor &a, const Tensor &b) {
auto c = Tensor::empty(a->shape(), a->dtype(), a->device()); auto c = Tensor::empty(a->shape(), a->dtype(), a->device());
add_(c, a, b); add_(c, a, b);
return c; return c;
} }
void add_(Tensor c, Tensor a, Tensor b) { void add_(Tensor c, const Tensor &a, const Tensor &b) {
Add::execute(c, a, b); Add::execute(c, a, b);
} }
......
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/add.hpp" #include "infinicore/ops/add.hpp"
#include "infinicore/ops/common/cache.hpp"
#include <infiniop.h> #include "../infiniop_impl.hpp"
namespace infinicore::op::add_impl::infiniop { namespace infinicore::op::add_impl::infiniop {
thread_local common::OpCache<size_t, infiniopAddDescriptor_t> caches( INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Add, 100);
100, // capacity
[](infiniopAddDescriptor_t &desc) { struct PlannedMeta {
if (desc != nullptr) { std::shared_ptr<Descriptor> descriptor;
INFINICORE_CHECK_ERROR(infiniopDestroyAddDescriptor(desc)); graph::GraphTensor workspace, c, a, b;
desc = nullptr; };
}
});
void calculate(Tensor c, Tensor a, Tensor b) { void *plan(Tensor c, const Tensor &a, const Tensor &b) {
size_t seed = hash_combine(c, b, a); size_t seed = hash_combine(c, b, a);
auto device = context::getDevice(); INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
auto &cache = caches.getCache(device); Descriptor, descriptor, Add,
seed,
c->desc(), a->desc(), b->desc());
auto desc_opt = cache.get(seed); INFINIOP_WORKSPACE_TENSOR(workspace, Add, descriptor);
infiniopAddDescriptor_t desc = nullptr;
if (!desc_opt) { return new PlannedMeta{
INFINICORE_CHECK_ERROR(infiniopCreateAddDescriptor( descriptor,
context::getInfiniopHandle(device), &desc, graph::GraphTensor(workspace),
c->desc(), a->desc(), b->desc())); graph::GraphTensor(c),
cache.put(seed, desc); graph::GraphTensor(a),
} else { graph::GraphTensor(b)};
desc = *desc_opt; }
}
size_t workspace_size = 0; void run(void *planned_meta) {
INFINICORE_CHECK_ERROR(infiniopGetAddWorkspaceSize(desc, &workspace_size)); auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
INFINICORE_CHECK_ERROR(infiniopAdd( INFINICORE_CHECK_ERROR(infiniopAdd(
desc, workspace->data(), workspace_size, planned->descriptor->desc,
c->data(), a->data(), b->data(), context::getStream())); planned->workspace->data(),
planned->workspace->numel(),
planned->c->data(),
planned->a->data(),
planned->b->data(),
context::getStream()));
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
} }
static bool registered = []() { INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Add, &plan, &run, &cleanup);
Add::dispatcher().registerAll(&calculate, false);
return true;
}();
} // namespace infinicore::op::add_impl::infiniop } // namespace infinicore::op::add_impl::infiniop
#include "infinicore/ops/causal_softmax.hpp" #include "infinicore/ops/causal_softmax.hpp"
#include "../../utils.hpp" #include "../../utils.hpp"
#include <stdexcept>
namespace infinicore::op { namespace infinicore::op {
common::OpDispatcher<CausalSoftmax::schema> &CausalSoftmax::dispatcher() { INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(CausalSoftmax);
static common::OpDispatcher<CausalSoftmax::schema> dispatcher_;
return dispatcher_;
};
void CausalSoftmax::execute(Tensor output, Tensor input) { CausalSoftmax::CausalSoftmax(Tensor output, const Tensor &input) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(output, input); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(output, input);
infinicore::context::setDevice(output->device()); INFINICORE_GRAPH_OP_DISPATCH(output->device().getType(), output, input);
auto device_type = output->device().getType(); }
auto func = dispatcher().lookup(device_type);
if (func == nullptr) {
throw std::runtime_error("No CausalSoftmax implementation found for device type: " + std::to_string(static_cast<int>(device_type)));
}
func(output, input); void CausalSoftmax::execute(Tensor output, const Tensor &input) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(CausalSoftmax, output, input);
} }
Tensor causal_softmax(Tensor input) { Tensor causal_softmax(const Tensor &input) {
Shape shape = input->shape(); auto output = Tensor::empty(input->shape(), input->dtype(), input->device());
auto output = Tensor::empty(shape, input->dtype(), input->device());
causal_softmax_(output, input); causal_softmax_(output, input);
return output; return output;
} }
void causal_softmax_(Tensor output, Tensor input) { void causal_softmax_(Tensor output, const Tensor &input) {
CausalSoftmax::execute(output, input); CausalSoftmax::execute(output, input);
} }
} // namespace infinicore::op } // namespace infinicore::op
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/causal_softmax.hpp" #include "infinicore/ops/causal_softmax.hpp"
#include "infinicore/ops/common/cache.hpp"
#include <infiniop.h> #include "../infiniop_impl.hpp"
namespace infinicore::op::causal_softmax_impl::infiniop { namespace infinicore::op::causal_softmax_impl::infiniop {
thread_local common::OpCache<size_t, infiniopCausalSoftmaxDescriptor_t> caches( INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, CausalSoftmax, 100);
100, // capacity
[](infiniopCausalSoftmaxDescriptor_t &desc) { struct PlannedMeta {
if (desc != nullptr) { std::shared_ptr<Descriptor> descriptor;
INFINICORE_CHECK_ERROR(infiniopDestroyCausalSoftmaxDescriptor(desc)); graph::GraphTensor workspace, output, input;
desc = nullptr; };
}
});
void calculate(Tensor output, Tensor input) { void *plan(Tensor output, const Tensor &input) {
size_t seed = hash_combine(output, input); size_t seed = hash_combine(output, input);
auto device = context::getDevice(); INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
auto &cache = caches.getCache(device); Descriptor, descriptor, CausalSoftmax,
seed, output->desc(), input->desc());
auto desc_opt = cache.get(seed); INFINIOP_WORKSPACE_TENSOR(workspace, CausalSoftmax, descriptor);
infiniopCausalSoftmaxDescriptor_t desc = nullptr;
if (!desc_opt) { return new PlannedMeta{
INFINICORE_CHECK_ERROR(infiniopCreateCausalSoftmaxDescriptor( descriptor,
context::getInfiniopHandle(device), &desc, graph::GraphTensor(workspace),
output->desc(), input->desc())); graph::GraphTensor(output),
cache.put(seed, desc); graph::GraphTensor(input)};
} else { }
desc = *desc_opt;
}
size_t workspace_size = 0; void run(void *planned_meta) {
INFINICORE_CHECK_ERROR(infiniopGetCausalSoftmaxWorkspaceSize(desc, &workspace_size)); auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
INFINICORE_CHECK_ERROR(infiniopCausalSoftmax( INFINICORE_CHECK_ERROR(infiniopCausalSoftmax(
desc, workspace->data(), workspace_size, planned->descriptor->desc,
output->data(), input->data(), context::getStream())); planned->workspace->data(),
planned->workspace->numel(),
planned->output->data(),
planned->input->data(),
context::getStream()));
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
} }
static bool registered = []() { INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(CausalSoftmax, &plan, &run, &cleanup);
CausalSoftmax::dispatcher().registerAll(&calculate, false);
return true;
}();
} // namespace infinicore::op::causal_softmax_impl::infiniop } // namespace infinicore::op::causal_softmax_impl::infiniop
#include "infinicore/ops/distributed/allreduce.hpp"
#include "../../utils.hpp"
namespace infinicore::op::distributed {
// Everything run() needs, captured once at construction so the op can be
// replayed during graph execution. The tensors are stored as GraphTensor
// (a Tensor wrapper — presumably pins the underlying buffer for replay;
// see graph.hpp to confirm).
struct PlannedMeta {
    graph::GraphTensor output, input;
    infinicclReduceOp_t op;
    infinicclComm_t communicator;
};
// Validates the pair and builds the plan eagerly. The collective requires
// contiguous buffers on the same device with matching element counts; the
// asserts run before any allocation so a failed check leaks nothing.
AllReduce::AllReduce(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator) {
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(output, input);
    INFINICORE_ASSERT(output->is_contiguous() && input->is_contiguous());
    INFINICORE_ASSERT(output->numel() == input->numel());
    planned_meta_ = new PlannedMeta{graph::GraphTensor(output), graph::GraphTensor(input), op, communicator};
}
// Releases the plan built in the constructor.
AllReduce::~AllReduce() {
    // `delete` on a null pointer is a no-op, so no explicit guard is needed;
    // static_cast is the idiomatic named cast back from void*.
    delete static_cast<PlannedMeta *>(planned_meta_);
}
// Launches the collective on the current stream using the immutable plan.
// Safe to call repeatedly (e.g. on every graph replay).
void AllReduce::run() const {
    PlannedMeta *meta = reinterpret_cast<PlannedMeta *>(planned_meta_);
    // Argument order is (send, recv, count, dtype, op, comm, stream); the
    // project DataType enum is assumed value-compatible with infiniDtype_t
    // (same double-cast appears elsewhere in this commit) — TODO confirm.
    INFINICORE_CHECK_ERROR(infinicclAllReduce(meta->input->data(),
                                              meta->output->data(),
                                              meta->input->numel(),
                                              static_cast<infiniDtype_t>(static_cast<int>(meta->input->dtype())),
                                              meta->op,
                                              meta->communicator,
                                              infinicore::context::getStream()));
}
// Graph-aware entry point: the macro constructs an AllReduce op, then either
// records it (when context::isGraphRecording()) or runs it immediately.
void AllReduce::execute(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator) {
    INFINICORE_GRAPH_OP_RECORD_OR_RUN(AllReduce, output, input, op, communicator);
}
// Out-of-place allreduce: allocate a result tensor shaped/typed/placed like
// `input`, then delegate to the in-place variant.
Tensor allreduce(const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator) {
    auto result = Tensor::empty(input->shape(), input->dtype(), input->device());
    allreduce_(result, input, op, communicator);
    return result;
}
// Thin forwarding wrapper over the graph-aware static entry point.
void allreduce_(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator) {
    AllReduce::execute(output, input, op, communicator);
}
} // namespace infinicore::op::distributed
...@@ -5,16 +5,16 @@ ...@@ -5,16 +5,16 @@
namespace infinicore::op { namespace infinicore::op {
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Gemm); INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Gemm);
Gemm::Gemm(Tensor c, Tensor a, Tensor b, float alpha, float beta) { Gemm::Gemm(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
INFINICORE_GRAPH_OP_DISPATCH(c->device().getType(), c, a, b, alpha, beta); INFINICORE_GRAPH_OP_DISPATCH(c->device().getType(), c, a, b, alpha, beta);
} }
void Gemm::execute(Tensor c, Tensor a, Tensor b, float alpha, float beta) { void Gemm::execute(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(Gemm, c, a, b, alpha, beta); INFINICORE_GRAPH_OP_RECORD_OR_RUN(Gemm, c, a, b, alpha, beta);
} }
Tensor gemm(Tensor a, Tensor b, float alpha, float beta) { Tensor gemm(const Tensor &a, const Tensor &b, float alpha, float beta) {
Shape shape = a->shape(); Shape shape = a->shape();
Size size = a->ndim(); Size size = a->ndim();
shape[size - 1] = b->size(size - 1); shape[size - 1] = b->size(size - 1);
...@@ -23,7 +23,7 @@ Tensor gemm(Tensor a, Tensor b, float alpha, float beta) { ...@@ -23,7 +23,7 @@ Tensor gemm(Tensor a, Tensor b, float alpha, float beta) {
return c; return c;
} }
void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta) { void gemm_(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta) {
Gemm::execute(c, a, b, alpha, beta); Gemm::execute(c, a, b, alpha, beta);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment