Unverified commit 784139b9 authored by thatPepe, committed by GitHub

Merge pull request #990 from InfiniTensor/demo131

Demo-131 CUDA graph with optimized paged attention
parents 3c8fb3c0 1d6527cb
@@ -42,6 +42,7 @@ void printUsage() {
              << " qy" << std::endl
              << " kunlun" << std::endl
              << " hygon" << std::endl
+             << " ali" << std::endl
              << std::endl
              << "Available tests:" << std::endl
              << " basic - Basic memory allocation and deallocation tests" << std::endl
@@ -84,6 +85,8 @@ ParsedArgs parseArgs(int argc, char *argv[]) {
            args.device_type = INFINI_DEVICE_KUNLUN;
        } else if (arg == "--hygon") {
            args.device_type = INFINI_DEVICE_HYGON;
+       } else if (arg == "--ali") {
+           args.device_type = INFINI_DEVICE_ALI;
        } else if (arg == "--test") {
            if (i + 1 >= argc) {
                std::cerr << "Error: --test requires a test name" << std::endl;
...
@@ -41,6 +41,8 @@ std::string Device::toString(const Type &type) {
        return "KUNLUN";
    case Type::HYGON:
        return "HYGON";
+   case Type::ALI:
+       return "ALI";
    case Type::COUNT:
        return "COUNT";
    default:
...
@@ -17,11 +17,11 @@ GraphTensor::GraphTensor(const Tensor &tensor) : Tensor(tensor->to_blob_()) {
 * GraphOperator
 * ========================= */
-void GraphOperator::run() const {
+void DispatchableGraphOperator::run() const {
    runner_(planned_meta_);
}

-GraphOperator::~GraphOperator() {
+DispatchableGraphOperator::~DispatchableGraphOperator() {
    if (deleter_) {
        deleter_(&planned_meta_);
    }
@@ -84,7 +84,7 @@ void Graph::instantiate() {
    if (infinirtStreamBeginCapture(
            context::getStream(),
-           INFINIRT_STREAM_CAPTURE_MODE_GLOBAL)
+           INFINIRT_STREAM_CAPTURE_MODE_RELAXED)
        != INFINI_STATUS_SUCCESS) {
        return;
    }
@@ -144,7 +144,9 @@ std::shared_ptr<Graph> GraphManager::stop_recording() {
        return nullptr;
    }
    recording_ = false;
+#ifdef USE_INFINIRT_GRAPH
    graph_->instantiate();
+#endif
    return std::exchange(graph_, nullptr);
}
...
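Note: the capture-mode change above (GLOBAL to RELAXED) matches the semantics of CUDA's stream-capture modes, where relaxed capture tolerates unrelated, non-capturable work from other threads while a graph is being recorded. The plain-CUDA sketch below shows the capture / instantiate / replay cycle that Graph::instantiate() and the guarded stop_recording() path wrap; it is illustrative only and uses standard CUDA runtime calls, not the InfiniRT wrappers.

// Plain-CUDA sketch of graph capture with relaxed mode (illustrative, not InfiniRT).
#include <cuda_runtime.h>
#include <cstdio>

__global__ void addOne(float *x, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] += 1.0f;
}

int main() {
    const int n = 1024;
    float *d = nullptr;
    cudaMalloc(&d, n * sizeof(float)); // allocate before capture begins
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    cudaGraph_t graph;
    cudaGraphExec_t exec;
    // Relaxed mode places the fewest restrictions on what other threads may do
    // during capture, unlike the global mode used before this change.
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed);
    addOne<<<(n + 255) / 256, 256, 0, stream>>>(d, n); // recorded, not executed
    cudaStreamEndCapture(stream, &graph);
    cudaGraphInstantiate(&exec, graph, 0); // CUDA 12 signature; older toolkits take extra error-node/log args

    for (int step = 0; step < 3; ++step) {
        cudaGraphLaunch(exec, stream); // cheap replay, no per-kernel launch overhead
    }
    cudaStreamSynchronize(stream);

    cudaGraphExecDestroy(exec);
    cudaGraphDestroy(graph);
    cudaFree(d);
    cudaStreamDestroy(stream);
    std::printf("done\n");
    return 0;
}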
@@ -43,6 +43,13 @@ Embedding::Embedding(size_t num_embeddings,
}

Tensor Embedding::forward(const Tensor &indices) const {
+   // TODO: Implement on-device embedding for all devices, then remove the condition and the classic approach
+   auto device_type = device_.getType();
+   if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ILUVATAR || device_type == Device::Type::METAX || device_type == Device::Type::MOORE || device_type == Device::Type::ALI) {
+       // Use op::embedding which supports device-side input and batch dimension
+       return op::embedding(indices->contiguous()->to(device_), weight_);
+   }
    // Get the shape of indices
    auto indices_shape = indices->shape();
...
#include "infinicore/nn/linear.hpp" #include "infinicore/nn/linear.hpp"
#include "../utils.hpp" #include "../utils.hpp"
#include "infinicore/ops.hpp" #include "infinicore/ops.hpp"
#include "infinicore/ops/distributed/allreduce.hpp"
#include "infinicore/ops/linear.hpp" #include "infinicore/ops/linear.hpp"
#include "infinicore/ops/linear_w8a8i8.hpp"
#include <optional> #include <optional>
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
...@@ -17,21 +19,46 @@ BaseLinear::BaseLinear(size_t in_features, size_t out_features, bool bias, ...@@ -17,21 +19,46 @@ BaseLinear::BaseLinear(size_t in_features, size_t out_features, bool bias,
device_ = device; device_ = device;
} }
Tensor BaseLinear::compute_linear(Tensor &input) const { BaseLinear::BaseLinear(size_t in_features, size_t out_features, std::shared_ptr<infinicore::quantization::BaseQuantization> quantization, bool bias,
const DataType &dtype, const Device &device)
: in_features_(in_features),
out_features_(out_features),
quantization_(quantization),
has_bias_(bias),
dtype_(dtype) {
device_ = device;
}
// Ensure input is contiguous before creating views (required for matmul) Tensor BaseLinear::compute_linear(Tensor &input) const {
// This prevents hanging when input tensor has non-contiguous memory layout switch (this->quantization_->get_quant_scheme()) {
Tensor input_contiguous = input->is_contiguous() ? input : input->contiguous(); case infinicore::quantization::QuantScheme::COMPRESSED_TENSOR_W8A8I8: {
Tensor input_contiguous = input->is_contiguous() ? input : input->contiguous();
// Use ops::linear_ directly to match Python backend's exact code path Tensor weight_packed_tensor = static_cast<const Tensor &>(weight_);
// This ensures identical computation and numerical results Tensor weight_scale_tensor = static_cast<const Tensor &>(weight_scale_);
// Parameter inherits from Tensor, so we cast to Tensor explicitly // weight_packed should be transposed and non-contiguous.
Tensor weight_tensor = static_cast<const Tensor &>(weight_); std::optional<Tensor> bias_opt = has_bias_ ? std::make_optional<Tensor>(static_cast<const Tensor &>(bias_)) : std::nullopt;
std::optional<Tensor> bias_opt = has_bias_ ? std::make_optional<Tensor>(static_cast<const Tensor &>(bias_)) : std::nullopt;
auto output = infinicore::op::linear(input_contiguous->contiguous(), weight_tensor->contiguous(), bias_opt); auto output = infinicore::op::linear_w8a8i8(input_contiguous->contiguous(), weight_packed_tensor, weight_scale_tensor, bias_opt);
return output; return output;
} }
default: {
// Ensure input is contiguous before creating views (required for matmul)
// This prevents hanging when input tensor has non-contiguous memory layout
Tensor input_contiguous = input->is_contiguous() ? input : input->contiguous();
// Use ops::linear_ directly to match Python backend's exact code path
// This ensures identical computation and numerical results
// Parameter inherits from Tensor, so we cast to Tensor explicitly
Tensor weight_tensor = static_cast<const Tensor &>(weight_);
std::optional<Tensor> bias_opt = has_bias_ ? std::make_optional<Tensor>(static_cast<const Tensor &>(bias_)) : std::nullopt;
auto output = infinicore::op::linear(input_contiguous->contiguous(), weight_tensor->contiguous(), bias_opt);
return output;
}
}
} // namespace infinicore::nn
Tensor BaseLinear::forward(Tensor &input) const { Tensor BaseLinear::forward(Tensor &input) const {
return compute_linear(input); return compute_linear(input);
...@@ -70,6 +97,43 @@ Linear::Linear(size_t in_features, size_t out_features, bool bias, ...@@ -70,6 +97,43 @@ Linear::Linear(size_t in_features, size_t out_features, bool bias,
// in_features, out_features, bias, static_cast<int>(dtype_)); // in_features, out_features, bias, static_cast<int>(dtype_));
} }
Linear::Linear(size_t in_features, size_t out_features,
std::shared_ptr<infinicore::quantization::BaseQuantization> quantization, bool bias,
const DataType &dtype, const Device &device)
: BaseLinear(in_features, out_features, quantization, bias, dtype, device_) {
device_ = device;
switch (this->quantization_->get_quant_scheme()) {
case infinicore::quantization::QuantScheme::COMPRESSED_TENSOR_W8A8I8: {
INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, infinicore::DataType::I8, device));
INFINICORE_NN_PARAMETER_INIT(weight_scale, ({out_features, 1}, infinicore::DataType::F32, device));
if (bias) {
INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device));
} else {
bias_ = Parameter();
}
break;
}
default: {
// Initialize parameters using macro
INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device));
// Register bias parameter if requested
if (bias) {
INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device));
} else {
bias_ = Parameter(); // Default constructed empty parameter
}
// SPDLOG_DEBUG("Created Linear module: in_features={}, out_features={}, bias={}, dtype={}",
// in_features, out_features, bias, static_cast<int>(dtype_));
break;
}
}
}
Tensor Linear::forward(Tensor &input) const { Tensor Linear::forward(Tensor &input) const {
return BaseLinear::forward(input); return BaseLinear::forward(input);
} }
...@@ -102,9 +166,45 @@ ColumnParallelLinear::ColumnParallelLinear(size_t in_features, size_t out_featur ...@@ -102,9 +166,45 @@ ColumnParallelLinear::ColumnParallelLinear(size_t in_features, size_t out_featur
} else { } else {
bias_ = Parameter(); // Default constructed empty parameter bias_ = Parameter(); // Default constructed empty parameter
} }
}
// SPDLOG_DEBUG("Created ColumnParallelLinear module: in_features={}, out_features={}, bias={}, dtype={}", ColumnParallelLinear::ColumnParallelLinear(size_t in_features, size_t out_features, std::shared_ptr<infinicore::quantization::BaseQuantization> quantization, bool bias,
// in_features, out_features, bias, static_cast<int>(dtype_)); const DataType &dtype, const Device &device,
Size tp_rank, Size tp_size)
: BaseLinear(in_features, out_features, quantization, bias, dtype, device_),
tp_rank_(tp_rank),
tp_size_(tp_size) {
device_ = device;
switch (this->quantization_->get_quant_scheme()) {
case infinicore::quantization::QuantScheme::COMPRESSED_TENSOR_W8A8I8: {
INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, infinicore::DataType::I8, device, 0, tp_rank_, tp_size_));
INFINICORE_NN_PARAMETER_INIT(weight_scale, ({out_features, 1}, infinicore::DataType::F32, device, 0, tp_rank_, tp_size_));
if (bias) {
INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device, 0, 0, 1));
} else {
bias_ = Parameter();
}
break;
}
default: {
// Initialize parameters using macro
INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device,
0, tp_rank_, tp_size_));
// Register bias parameter if requested
if (bias) {
INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device,
0, tp_rank_, tp_size_));
} else {
bias_ = Parameter(); // Default constructed empty parameter
}
break;
}
}
} }
Tensor ColumnParallelLinear::forward(Tensor &input) const { Tensor ColumnParallelLinear::forward(Tensor &input) const {
...@@ -138,26 +238,53 @@ RowParallelLinear::RowParallelLinear(size_t in_features, size_t out_features, bo ...@@ -138,26 +238,53 @@ RowParallelLinear::RowParallelLinear(size_t in_features, size_t out_features, bo
} else { } else {
bias_ = Parameter(); // Default constructed empty parameter bias_ = Parameter(); // Default constructed empty parameter
} }
}
// SPDLOG_DEBUG("Created RowParallelLinear module: in_features={}, out_features={}, bias={}, dtype={}", RowParallelLinear::RowParallelLinear(size_t in_features, size_t out_features, std::shared_ptr<infinicore::quantization::BaseQuantization> quantization, bool bias,
// in_features, out_features, bias, static_cast<int>(dtype_)); const DataType &dtype, const Device &device,
Size tp_rank, Size tp_size, infinicclComm_t communicator)
: BaseLinear(in_features, out_features, quantization, bias, dtype, device_),
tp_rank_(tp_rank),
tp_size_(tp_size), communicator_(communicator) {
device_ = device;
switch (this->quantization_->get_quant_scheme()) {
case infinicore::quantization::QuantScheme::COMPRESSED_TENSOR_W8A8I8: {
INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, infinicore::DataType::I8, device, 1, tp_rank_, tp_size_));
INFINICORE_NN_PARAMETER_INIT(weight_scale, ({out_features, 1}, infinicore::DataType::F32, device, 0, 0, 1));
if (bias) {
INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device, 0, tp_rank_, tp_size_));
} else {
bias_ = Parameter();
}
break;
}
default: {
// Initialize parameters using macro
INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device,
1, tp_rank_, tp_size_));
// Register bias parameter if requested
if (bias && (0 == tp_rank_)) {
INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device, 0, 0, 1));
} else {
bias_ = Parameter(); // Default constructed empty parameter
}
// SPDLOG_DEBUG("Created RowParallelLinear module: in_features={}, out_features={}, bias={}, dtype={}",
// in_features, out_features, bias, static_cast<int>(dtype_));
break;
}
}
} }
Tensor RowParallelLinear::forward(Tensor &input) const { Tensor RowParallelLinear::forward(Tensor &input) const {
auto output = BaseLinear::forward(input); auto output = BaseLinear::forward(input);
if ((tp_size_ > 1) && (communicator_ != nullptr)) { if ((tp_size_ > 1) && (communicator_ != nullptr)) {
op::distributed::allreduce_(output, output, INFINICCL_SUM, communicator_);
Size count = output->numel();
DataType type = output->dtype();
infinirtStream_t stream = infinicore::context::getStream();
INFINICORE_CHECK_ERROR(infinicclAllReduce(output->data(), output->data(), count, static_cast<infiniDtype_t>(static_cast<int>(type)),
INFINICCL_SUM, communicator_, stream));
INFINICORE_CHECK_ERROR(infinirtStreamSynchronize(stream));
// RUN_INFINI(infinirtStreamSynchronize(stream));
} }
return output; return output;
} }
......
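Note: in the W8A8I8 path above, the weight is stored as int8 with shape {out_features, in_features} plus a per-output-channel F32 scale of shape {out_features, 1}. A minimal host-side sketch of the weight-dequantization semantics that layout implies (illustrative only; op::linear_w8a8i8 also handles activation quantization and runs entirely on device):

// Reference semantics for the per-output-channel int8 weight + F32 scale layout.
#include <cstdint>
#include <vector>

std::vector<float> dequantize_weight(const std::vector<int8_t> &w_q,
                                     const std::vector<float> &scale,
                                     size_t out_features, size_t in_features) {
    std::vector<float> w(out_features * in_features);
    for (size_t o = 0; o < out_features; ++o) {
        for (size_t i = 0; i < in_features; ++i) {
            // Every element of output channel o shares the single scale factor scale[o].
            w[o * in_features + i] = static_cast<float>(w_q[o * in_features + i]) * scale[o];
        }
    }
    return w;
}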
@@ -21,6 +21,25 @@ Tensor RMSNorm::forward(const Tensor &x) const {
    return op::rms_norm(x, weight_, static_cast<float>(eps_));
}

+void RMSNorm::forward_inplace(Tensor &x, Tensor &residual) const {
+    if (!residual) {
+        residual = x;
+        x = op::rms_norm(x, weight_, static_cast<float>(eps_));
+    } else {
+        if (device_.getType() == Device::Type::CPU
+            || device_.getType() == Device::Type::NVIDIA
+            || device_.getType() == Device::Type::ILUVATAR
+            || device_.getType() == Device::Type::METAX
+            || device_.getType() == Device::Type::MOORE
+            || device_.getType() == Device::Type::ALI) {
+            op::add_rms_norm_inplace(x, residual, weight_, static_cast<float>(eps_));
+        } else {
+            op::add_(residual, x, residual);
+            op::rms_norm_(x, residual, weight_, static_cast<float>(eps_));
+        }
+    }
+}
+
std::string RMSNorm::extra_repr() const {
    return "RMSNorm(normalized_shape=" + std::to_string(normalized_shape_) + ", eps=" + std::to_string(eps_) + ", dtype=" + std::to_string(static_cast<int>(dtype_)) + ")";
}
...
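Note: the fallback branch above spells out the contract of the fused kernel: residual accumulates x, then x is replaced by the RMS-normalized residual scaled by the weight. A self-contained reference of that semantics on plain vectors (illustrative, not the device kernel):

// Reference semantics of add_rms_norm_inplace on host vectors.
#include <cmath>
#include <vector>

void add_rms_norm_inplace_ref(std::vector<float> &x, std::vector<float> &residual,
                              const std::vector<float> &weight, float eps) {
    float sum_sq = 0.f;
    for (size_t i = 0; i < x.size(); ++i) {
        residual[i] += x[i];                 // residual_out = a + b
        sum_sq += residual[i] * residual[i];
    }
    float inv_rms = 1.f / std::sqrt(sum_sq / static_cast<float>(x.size()) + eps);
    for (size_t i = 0; i < x.size(); ++i) {
        x[i] = residual[i] * inv_rms * weight[i]; // y = rms_norm(residual_out) * weight
    }
}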
@@ -3,24 +3,24 @@
namespace infinicore::op {

-common::OpDispatcher<Add::schema> &Add::dispatcher() {
-    static common::OpDispatcher<Add::schema> dispatcher_;
-    return dispatcher_;
-};
+INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Add);

-void Add::execute(Tensor c, Tensor a, Tensor b) {
+Add::Add(Tensor c, const Tensor &a, const Tensor &b) {
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
-    infinicore::context::setDevice(c->device());
-    dispatcher().lookup(c->device().getType())(c, a, b);
+    INFINICORE_GRAPH_OP_DISPATCH(c->device().getType(), c, a, b);
}

-Tensor add(Tensor a, Tensor b) {
+void Add::execute(Tensor c, const Tensor &a, const Tensor &b) {
+    INFINICORE_GRAPH_OP_RECORD_OR_RUN(Add, c, a, b);
+}
+
+Tensor add(const Tensor &a, const Tensor &b) {
    auto c = Tensor::empty(a->shape(), a->dtype(), a->device());
    add_(c, a, b);
    return c;
}

-void add_(Tensor c, Tensor a, Tensor b) {
+void add_(Tensor c, const Tensor &a, const Tensor &b) {
    Add::execute(c, a, b);
}
...
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/add.hpp" #include "infinicore/ops/add.hpp"
#include "infinicore/ops/common/cache.hpp"
#include <infiniop.h> #include "../infiniop_impl.hpp"
namespace infinicore::op::add_impl::infiniop { namespace infinicore::op::add_impl::infiniop {
thread_local common::OpCache<size_t, infiniopAddDescriptor_t> caches( INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Add, 100);
100, // capacity
[](infiniopAddDescriptor_t &desc) { struct PlannedMeta {
if (desc != nullptr) { std::shared_ptr<Descriptor> descriptor;
INFINICORE_CHECK_ERROR(infiniopDestroyAddDescriptor(desc)); graph::GraphTensor workspace, c, a, b;
desc = nullptr; };
}
});
void calculate(Tensor c, Tensor a, Tensor b) { void *plan(Tensor c, const Tensor &a, const Tensor &b) {
size_t seed = hash_combine(c, b, a); size_t seed = hash_combine(c, b, a);
auto device = context::getDevice(); INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
auto &cache = caches.getCache(device); Descriptor, descriptor, Add,
seed,
c->desc(), a->desc(), b->desc());
auto desc_opt = cache.get(seed); INFINIOP_WORKSPACE_TENSOR(workspace, Add, descriptor);
infiniopAddDescriptor_t desc = nullptr;
if (!desc_opt) { return new PlannedMeta{
INFINICORE_CHECK_ERROR(infiniopCreateAddDescriptor( descriptor,
context::getInfiniopHandle(device), &desc, graph::GraphTensor(workspace),
c->desc(), a->desc(), b->desc())); graph::GraphTensor(c),
cache.put(seed, desc); graph::GraphTensor(a),
} else { graph::GraphTensor(b)};
desc = *desc_opt; }
}
size_t workspace_size = 0; void run(void *planned_meta) {
INFINICORE_CHECK_ERROR(infiniopGetAddWorkspaceSize(desc, &workspace_size)); auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
INFINICORE_CHECK_ERROR(infiniopAdd( INFINICORE_CHECK_ERROR(infiniopAdd(
desc, workspace->data(), workspace_size, planned->descriptor->desc,
c->data(), a->data(), b->data(), context::getStream())); planned->workspace->data(),
planned->workspace->numel(),
planned->c->data(),
planned->a->data(),
planned->b->data(),
context::getStream()));
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
} }
static bool registered = []() { INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Add, &plan, &run, &cleanup);
Add::dispatcher().registerAll(&calculate, false);
return true;
}();
} // namespace infinicore::op::add_impl::infiniop } // namespace infinicore::op::add_impl::infiniop
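Note: every infiniop-backed op touched by this change follows the same three-phase shape: plan() resolves the cached descriptor, workspace and graph tensors once, run() only launches with pre-resolved pointers, and cleanup() frees the planned state. That split is what lets the recording path replay the op inside a captured graph without allocations or cache lookups. A standalone sketch of the lifecycle with made-up names (not the real InfiniCore macros):

// Minimal plan/run/cleanup lifecycle sketch (illustrative names only).
#include <cstdio>

struct OpImpl {
    void *(*plan)(int);       // build descriptors and capture-stable state once
    void (*run)(void *);      // replayable launch; no allocation or lookup here
    void (*cleanup)(void **); // release the planned state
};

void *plan(int x) { return new int(x); }
void run(void *meta) { std::printf("run with %d\n", *static_cast<int *>(meta)); }
void cleanup(void **meta) { delete static_cast<int *>(*meta); *meta = nullptr; }

int main() {
    OpImpl add{&plan, &run, &cleanup};
    void *planned = add.plan(42); // eager path would plan+run+cleanup in one go
    add.run(planned);             // graph path keeps `planned` and replays run()
    add.run(planned);
    add.cleanup(&planned);
    return 0;
}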
@@ -4,26 +4,30 @@
namespace infinicore::op {

-common::OpDispatcher<AddRMSNorm::schema> &AddRMSNorm::dispatcher() {
-    static common::OpDispatcher<AddRMSNorm::schema> dispatcher_;
-    return dispatcher_;
-};
+INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(AddRMSNorm);

-void AddRMSNorm::execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
+AddRMSNorm::AddRMSNorm(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, residual_out, a, b, weight);
-    infinicore::context::setDevice(y->device());
-    dispatcher().lookup(y->device().getType())(y, residual_out, a, b, weight, epsilon);
+    INFINICORE_GRAPH_OP_DISPATCH(y->device().getType(), y, residual_out, a, b, weight, epsilon);
}

-std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon) {
+void AddRMSNorm::execute(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
+    INFINICORE_GRAPH_OP_RECORD_OR_RUN(AddRMSNorm, y, residual_out, a, b, weight, epsilon);
+}
+
+std::pair<Tensor, Tensor> add_rms_norm(const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
    auto y = Tensor::empty(a->shape(), a->dtype(), a->device());
    auto residual_out = Tensor::empty(a->shape(), a->dtype(), a->device());
    add_rms_norm_(y, residual_out, a, b, weight, epsilon);
    return std::make_pair(y, residual_out);
}

-void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
-    AddRMSNorm::execute(y, residual_out, a, b, weight, epsilon);
+void add_rms_norm_(Tensor out, Tensor residual, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
+    AddRMSNorm::execute(out, residual, a, b, weight, epsilon);
+}
+
+void add_rms_norm_inplace(Tensor input, Tensor residual, const Tensor &weight, float epsilon) {
+    add_rms_norm_(input, residual, input, residual, weight, epsilon);
}

} // namespace infinicore::op
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/add_rms_norm.hpp" #include "infinicore/ops/add_rms_norm.hpp"
#include "infinicore/ops/common/cache.hpp"
#include <infiniop.h> #include "../infiniop_impl.hpp"
namespace infinicore::op::add_rms_norm_impl::infiniop { namespace infinicore::op::add_rms_norm_impl::infiniop {
thread_local common::OpCache<size_t, infiniopAddRMSNormDescriptor_t> caches( INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, AddRMSNorm, 100);
100, // capacity
[](infiniopAddRMSNormDescriptor_t &desc) { struct PlannedMeta {
if (desc != nullptr) { std::shared_ptr<Descriptor> descriptor;
INFINICORE_CHECK_ERROR(infiniopDestroyAddRMSNormDescriptor(desc)); graph::GraphTensor workspace, out, residual, a, b, weight;
desc = nullptr; float epsilon;
} };
});
void calculate(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) { void *plan(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
size_t seed = hash_combine(y, residual_out, a, b, weight, epsilon); size_t seed = hash_combine(y, residual_out, a, b, weight, epsilon);
auto device = context::getDevice(); INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
auto &cache = caches.getCache(device); Descriptor, descriptor, AddRMSNorm,
seed, y->desc(), residual_out->desc(),
a->desc(), b->desc(), weight->desc(), epsilon);
INFINIOP_WORKSPACE_TENSOR(workspace, AddRMSNorm, descriptor);
auto desc_opt = cache.get(seed); auto planned = new PlannedMeta{
infiniopAddRMSNormDescriptor_t desc = nullptr; descriptor,
graph::GraphTensor(workspace),
graph::GraphTensor(y),
graph::GraphTensor(residual_out),
graph::GraphTensor(a),
graph::GraphTensor(b),
graph::GraphTensor(weight),
epsilon};
if (!desc_opt) { return planned;
INFINICORE_CHECK_ERROR(infiniopCreateAddRMSNormDescriptor( }
context::getInfiniopHandle(device), &desc,
y->desc(), a->desc(), b->desc(), weight->desc(), epsilon, residual_out->desc()));
cache.put(seed, desc);
} else {
desc = *desc_opt;
}
size_t workspace_size = 0; void run(void *planned_meta) {
INFINICORE_CHECK_ERROR(infiniopGetAddRMSNormWorkspaceSize(desc, &workspace_size)); auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
INFINICORE_CHECK_ERROR(infiniopAddRMSNorm( INFINICORE_CHECK_ERROR(infiniopAddRMSNorm(
desc, workspace->data(), workspace_size, planned->descriptor->desc, planned->workspace->data(), planned->workspace->numel(),
y->data(), a->data(), b->data(), weight->data(), residual_out->data(), context::getStream())); planned->out->data(), planned->residual->data(), planned->a->data(), planned->b->data(), planned->weight->data(), context::getStream()));
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
} }
static bool registered = []() { INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(AddRMSNorm, &plan, &run, &cleanup);
AddRMSNorm::dispatcher().registerAll(&calculate, false);
return true;
}();
} // namespace infinicore::op::add_rms_norm_impl::infiniop } // namespace infinicore::op::add_rms_norm_impl::infiniop
#include "infinicore/ops/causal_softmax.hpp" #include "infinicore/ops/causal_softmax.hpp"
#include "../../utils.hpp" #include "../../utils.hpp"
#include <stdexcept>
namespace infinicore::op { namespace infinicore::op {
common::OpDispatcher<CausalSoftmax::schema> &CausalSoftmax::dispatcher() { INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(CausalSoftmax);
static common::OpDispatcher<CausalSoftmax::schema> dispatcher_;
return dispatcher_;
};
void CausalSoftmax::execute(Tensor output, Tensor input) { CausalSoftmax::CausalSoftmax(Tensor output, const Tensor &input) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(output, input); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(output, input);
infinicore::context::setDevice(output->device()); INFINICORE_GRAPH_OP_DISPATCH(output->device().getType(), output, input);
auto device_type = output->device().getType(); }
auto func = dispatcher().lookup(device_type);
if (func == nullptr) {
throw std::runtime_error("No CausalSoftmax implementation found for device type: " + std::to_string(static_cast<int>(device_type)));
}
func(output, input); void CausalSoftmax::execute(Tensor output, const Tensor &input) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(CausalSoftmax, output, input);
} }
Tensor causal_softmax(Tensor input) { Tensor causal_softmax(const Tensor &input) {
Shape shape = input->shape(); auto output = Tensor::empty(input->shape(), input->dtype(), input->device());
auto output = Tensor::empty(shape, input->dtype(), input->device());
causal_softmax_(output, input); causal_softmax_(output, input);
return output; return output;
} }
void causal_softmax_(Tensor output, Tensor input) { void causal_softmax_(Tensor output, const Tensor &input) {
CausalSoftmax::execute(output, input); CausalSoftmax::execute(output, input);
} }
} // namespace infinicore::op } // namespace infinicore::op
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/causal_softmax.hpp" #include "infinicore/ops/causal_softmax.hpp"
#include "infinicore/ops/common/cache.hpp"
#include <infiniop.h> #include "../infiniop_impl.hpp"
namespace infinicore::op::causal_softmax_impl::infiniop { namespace infinicore::op::causal_softmax_impl::infiniop {
thread_local common::OpCache<size_t, infiniopCausalSoftmaxDescriptor_t> caches( INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, CausalSoftmax, 100);
100, // capacity
[](infiniopCausalSoftmaxDescriptor_t &desc) { struct PlannedMeta {
if (desc != nullptr) { std::shared_ptr<Descriptor> descriptor;
INFINICORE_CHECK_ERROR(infiniopDestroyCausalSoftmaxDescriptor(desc)); graph::GraphTensor workspace, output, input;
desc = nullptr; };
}
});
void calculate(Tensor output, Tensor input) { void *plan(Tensor output, const Tensor &input) {
size_t seed = hash_combine(output, input); size_t seed = hash_combine(output, input);
auto device = context::getDevice(); INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
auto &cache = caches.getCache(device); Descriptor, descriptor, CausalSoftmax,
seed, output->desc(), input->desc());
auto desc_opt = cache.get(seed); INFINIOP_WORKSPACE_TENSOR(workspace, CausalSoftmax, descriptor);
infiniopCausalSoftmaxDescriptor_t desc = nullptr;
if (!desc_opt) { return new PlannedMeta{
INFINICORE_CHECK_ERROR(infiniopCreateCausalSoftmaxDescriptor( descriptor,
context::getInfiniopHandle(device), &desc, graph::GraphTensor(workspace),
output->desc(), input->desc())); graph::GraphTensor(output),
cache.put(seed, desc); graph::GraphTensor(input)};
} else { }
desc = *desc_opt;
}
size_t workspace_size = 0; void run(void *planned_meta) {
INFINICORE_CHECK_ERROR(infiniopGetCausalSoftmaxWorkspaceSize(desc, &workspace_size)); auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
INFINICORE_CHECK_ERROR(infiniopCausalSoftmax( INFINICORE_CHECK_ERROR(infiniopCausalSoftmax(
desc, workspace->data(), workspace_size, planned->descriptor->desc,
output->data(), input->data(), context::getStream())); planned->workspace->data(),
planned->workspace->numel(),
planned->output->data(),
planned->input->data(),
context::getStream()));
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
} }
static bool registered = []() { INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(CausalSoftmax, &plan, &run, &cleanup);
CausalSoftmax::dispatcher().registerAll(&calculate, false);
return true;
}();
} // namespace infinicore::op::causal_softmax_impl::infiniop } // namespace infinicore::op::causal_softmax_impl::infiniop
#include "infinicore/ops/dequantize_awq.hpp"
#include "../../utils.hpp"
namespace infinicore::op {
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(DequantizeAWQ);
DequantizeAWQ::DequantizeAWQ(Tensor x, const Tensor &x_packed, const Tensor &x_scale, const Tensor &x_zeros) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(x, x_packed, x_scale, x_zeros);
INFINICORE_GRAPH_OP_DISPATCH(x->device().getType(), x, x_packed, x_scale, x_zeros);
}
void DequantizeAWQ::execute(Tensor x, const Tensor &x_packed, const Tensor &x_scale, const Tensor &x_zeros) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(DequantizeAWQ, x, x_packed, x_scale, x_zeros);
}
void dequantize_awq_(Tensor x, const Tensor &x_packed, const Tensor &x_scale, const Tensor &x_zeros) {
DequantizeAWQ::execute(x, x_packed, x_scale, x_zeros);
}
} // namespace infinicore::op
#include "../../utils.hpp"
#include "../infiniop_impl.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/dequantize_awq.hpp"
#include <infiniop.h>
namespace infinicore::op::dequantize_awq_impl::infiniop {
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, DequantizeAWQ, 100);
struct PlannedMeta {
std::shared_ptr<Descriptor> descriptor;
graph::GraphTensor workspace, x, x_packed, x_scale, x_zeros;
};
void *plan(Tensor x, const Tensor &x_packed, const Tensor &x_scale, const Tensor &x_zeros) {
size_t seed = hash_combine(x, x_packed, x_scale, x_zeros);
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
Descriptor, descriptor, DequantizeAWQ,
seed,
x->desc(), x_packed->desc(), x_scale->desc(), x_zeros->desc());
INFINIOP_WORKSPACE_TENSOR(workspace, DequantizeAWQ, descriptor);
return new PlannedMeta{
descriptor,
graph::GraphTensor(workspace),
graph::GraphTensor(x),
graph::GraphTensor(x_packed),
graph::GraphTensor(x_scale),
graph::GraphTensor(x_zeros)};
}
void run(void *planned_meta) {
auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
INFINICORE_CHECK_ERROR(infiniopDequantizeAWQ(
planned->descriptor->desc,
planned->workspace->data(),
planned->workspace->numel(),
planned->x->data(),
planned->x_packed->data(),
planned->x_scale->data(),
planned->x_zeros->data(),
context::getStream()));
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
}
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(DequantizeAWQ, &plan, &run, &cleanup);
} // namespace infinicore::op::dequantize_awq_impl::infiniop
#include "infinicore/ops/distributed/allreduce.hpp"
#include "../../utils.hpp"
namespace infinicore::op::distributed {
struct PlannedMeta {
graph::GraphTensor output, input;
infinicclReduceOp_t op;
infinicclComm_t communicator;
};
AllReduce::AllReduce(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(output, input);
INFINICORE_ASSERT(output->is_contiguous() && input->is_contiguous());
INFINICORE_ASSERT(output->numel() == input->numel());
planned_meta_ = new PlannedMeta{graph::GraphTensor(output), graph::GraphTensor(input), op, communicator};
}
AllReduce::~AllReduce() {
if (planned_meta_) {
PlannedMeta *meta = reinterpret_cast<PlannedMeta *>(planned_meta_);
delete meta;
}
}
void AllReduce::run() const {
PlannedMeta *meta = reinterpret_cast<PlannedMeta *>(planned_meta_);
INFINICORE_CHECK_ERROR(infinicclAllReduce(meta->input->data(),
meta->output->data(),
meta->input->numel(),
static_cast<infiniDtype_t>(static_cast<int>(meta->input->dtype())),
meta->op,
meta->communicator,
infinicore::context::getStream()));
}
void AllReduce::execute(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(AllReduce, output, input, op, communicator);
}
Tensor allreduce(const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator) {
auto output = Tensor::empty(input->shape(), input->dtype(), input->device());
allreduce_(output, input, op, communicator);
return output;
}
void allreduce_(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator) {
AllReduce::execute(output, input, op, communicator);
}
} // namespace infinicore::op::distributed
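Note: RowParallelLinear shards the weight along in_features, so each rank produces a partial output that the SUM allreduce completes; wrapping the collective as a graph op is what keeps RowParallelLinear::forward capturable without the old stream synchronize. A single-process reference of the partial-matmul-plus-sum math (plain arrays stand in for per-rank shards; purely illustrative):

// Single-process reference for row-parallel linear + sum-allreduce semantics.
#include <cstdio>
#include <vector>

int main() {
    // Full problem: y = W x with in_features = 4, out_features = 2.
    std::vector<float> x = {1, 2, 3, 4};
    std::vector<std::vector<float>> W = {{1, 0, 1, 0}, {0, 1, 0, 1}};

    const int tp_size = 2, shard = 4 / tp_size;
    std::vector<float> y(2, 0.f); // plays the role of the allreduce(SUM) result

    for (int rank = 0; rank < tp_size; ++rank) {
        // Each rank owns columns [rank*shard, (rank+1)*shard) of W and the matching slice of x.
        for (int o = 0; o < 2; ++o) {
            float partial = 0.f;
            for (int i = rank * shard; i < (rank + 1) * shard; ++i) {
                partial += W[o][i] * x[i];
            }
            y[o] += partial; // allreduce with INFINICCL_SUM accumulates the partials
        }
    }
    std::printf("y = [%g, %g]\n", y[0], y[1]); // expect [4, 6]
    return 0;
}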
#include "infinicore/ops/embedding.hpp" #include "infinicore/ops/embedding.hpp"
#include "infinicore/context/context.hpp"
#include <cstring> #include "../../utils.hpp"
namespace infinicore::op { namespace infinicore::op {
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Embedding);
Embedding::Embedding(Tensor out, const Tensor &input, const Tensor &weight) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, input, weight);
INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), out, input, weight);
}
Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the indices to extract void Embedding::execute(Tensor out, const Tensor &input, const Tensor &weight) {
Tensor weight // Weight: Embedding matrix of floating point type with shape (V, embedding_dim), where V = maximum index + 1 INFINICORE_GRAPH_OP_RECORD_OR_RUN(Embedding, out, input, weight);
}
Tensor embedding(const Tensor &input, // LongTensor of arbitrary shape containing the indices to extract
const Tensor &weight // Weight: Embedding matrix of floating point type with shape (V, embedding_dim), where V = maximum index + 1
) { ) {
auto input_shape = input->shape(); auto input_shape = input->shape();
auto weight_shape = weight->shape(); auto weight_shape = weight->shape();
// auto vocab_size = weight_shape[0];
auto embedding_dim = weight_shape[1]; auto embedding_dim = weight_shape[1];
// Assign memory to out variables // Assign memory to out variables
...@@ -21,69 +30,8 @@ Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the i ...@@ -21,69 +30,8 @@ Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the i
return inputs_embeds; return inputs_embeds;
} }
void embedding_(Tensor out, Tensor input, Tensor weight) { void embedding_(Tensor out, const Tensor &input, const Tensor &weight) {
assert(infinicore::DataType::I64 == input->dtype() || (infinicore::DataType::I32 == input->dtype())); Embedding::execute(out, input, weight);
assert(infinicore::Device::Type::CPU == input->device().getType());
auto input_shape = input->shape();
auto weight_shape = weight->shape();
auto embedding_dim = weight_shape[1];
// Calculate the number of token
Size counts = 1;
for (auto &v : input_shape) {
counts *= v;
}
// the bytes of one token
const Size bytes = dsize(weight->dtype()) * embedding_dim;
auto *weight_ptr = weight->data();
auto *out_ptr = out->data();
// copies
if (weight->device().getType() == Device::Type::CPU) {
if (infinicore::DataType::I64 == input->dtype()) {
const int64_t *input_arr = reinterpret_cast<const int64_t *>(input->data());
for (Size i = 0; i < counts; ++i) {
int64_t idx = input_arr[i];
assert((idx >= 0) && (idx < weight_shape[0]));
std::memcpy(out_ptr + i * bytes,
weight_ptr + idx * bytes,
bytes);
}
} else if (infinicore::DataType::I32 == input->dtype()) {
const int32_t *input_arr = reinterpret_cast<const int32_t *>(input->data());
for (Size i = 0; i < counts; ++i) {
int32_t idx = input_arr[i];
assert((idx >= 0) && (idx < weight_shape[0]));
std::memcpy(out_ptr + i * bytes,
weight_ptr + idx * bytes,
bytes);
}
}
} else {
if (infinicore::DataType::I64 == input->dtype()) {
const int64_t *input_arr = reinterpret_cast<const int64_t *>(input->data());
for (Size i = 0; i < counts; ++i) {
int64_t idx = input_arr[i];
assert((idx >= 0) && (idx < weight_shape[0]));
context::memcpyD2D(out_ptr + i * bytes,
weight_ptr + idx * bytes,
bytes);
}
} else if (infinicore::DataType::I32 == input->dtype()) {
const int32_t *input_arr = reinterpret_cast<const int32_t *>(input->data());
for (Size i = 0; i < counts; ++i) {
int32_t idx = input_arr[i];
assert((idx >= 0) && (idx < weight_shape[0]));
context::memcpyD2D(out_ptr + i * bytes,
weight_ptr + idx * bytes,
bytes);
}
}
}
} }
} // namespace infinicore::op } // namespace infinicore::op
#include "../infiniop_impl.hpp"
#include "infinicore/ops/embedding.hpp"
namespace infinicore::op::embedding_impl::infiniop {
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Embedding, 100);
struct PlannedMeta {
std::shared_ptr<Descriptor> descriptor;
graph::GraphTensor out, input, weight;
};
void *plan(Tensor out, const Tensor &input, const Tensor &weight) {
size_t seed = hash_combine(out, input, weight);
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
Descriptor, descriptor, Embedding,
seed, out->desc(), input->desc(), weight->desc());
auto planned = new PlannedMeta{
descriptor,
graph::GraphTensor(out),
graph::GraphTensor(input),
graph::GraphTensor(weight)};
return planned;
}
void run(void *planned_meta) {
auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
INFINICORE_CHECK_ERROR(infiniopEmbedding(
planned->descriptor->desc,
planned->out->data(), planned->input->data(), planned->weight->data(), context::getStream()));
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
}
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Embedding, &plan, &run, cleanup);
} // namespace infinicore::op::embedding_impl::infiniop
#include "infinicore/ops/flash_attention.hpp"
#include "../../utils.hpp"
namespace infinicore::op {
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(FlashAttention);
FlashAttention::FlashAttention(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v);
INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(),
out, q, k, v, total_kv_len, scale, is_causal);
}
void FlashAttention::execute(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(FlashAttention, out, q, k, v, total_kv_len, scale, is_causal);
}
Tensor flash_attention(const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
Shape shape = q->shape();
int idx = shape.size() - 1;
shape[idx] = v->shape()[idx];
auto out = Tensor::empty(shape, q->dtype(), q->device());
flash_attention_(out, q, k, v, total_kv_len, scale, is_causal);
return out;
}
void flash_attention_(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
FlashAttention::execute(out, q, k, v, total_kv_len, scale, is_causal);
}
} // namespace infinicore::op
#include "../../utils.hpp"
#include "../infiniop_impl.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/flash_attention.hpp"
#include <infiniop.h>
namespace infinicore::op::flash_attention_impl::infiniop {
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, FlashAttention, 100);
struct PlannedMeta {
std::shared_ptr<Descriptor> descriptor;
graph::GraphTensor workspace, out, q, k, v, total_kv_len;
float scale;
bool is_causal;
};
void *plan(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
size_t seed = hash_combine(out, q, k, v, total_kv_len, scale, is_causal);
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
Descriptor, descriptor, FlashAttention,
seed, out->desc(), q->desc(), k->desc(), v->desc(), total_kv_len->desc(), scale, is_causal);
INFINIOP_WORKSPACE_TENSOR(workspace, FlashAttention, descriptor);
auto planned = new PlannedMeta{
descriptor,
graph::GraphTensor(workspace),
graph::GraphTensor(out),
graph::GraphTensor(q),
graph::GraphTensor(k),
graph::GraphTensor(v),
graph::GraphTensor(total_kv_len), scale, is_causal};
return planned;
}
void run(void *planned_meta) {
auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
INFINICORE_CHECK_ERROR(infiniopFlashAttention(
planned->descriptor->desc, planned->workspace->data(), planned->workspace->numel(),
planned->out->data(), planned->q->data(), planned->k->data(), planned->v->data(), planned->total_kv_len->data(), context::getStream()));
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
}
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(FlashAttention, &plan, &run, &cleanup);
} // namespace infinicore::op::flash_attention_impl::infiniop
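Note: the frontend derives the output shape from q with its last dimension replaced by v's, and passes total_kv_len, scale and is_causal straight through to the kernel. For reference, a dense single-head version of the math the op computes (no paging, no tiling; the causal offset convention used here is an assumption, not taken from the kernel):

// Dense single-head scaled-dot-product attention reference (illustrative only).
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> attention_ref(const std::vector<float> &q, const std::vector<float> &k,
                                 const std::vector<float> &v, int seq_q, int seq_kv,
                                 int head_dim, float scale, bool is_causal) {
    std::vector<float> out(seq_q * head_dim, 0.f);
    for (int i = 0; i < seq_q; ++i) {
        // Assumed decode-style causal mask: query i attends to keys up to (seq_kv - seq_q + i).
        int limit = is_causal ? (seq_kv - seq_q + i + 1) : seq_kv;
        std::vector<float> score(std::max(limit, 0));
        float max_s = -1e30f;
        for (int j = 0; j < limit; ++j) {
            float s = 0.f;
            for (int d = 0; d < head_dim; ++d) {
                s += q[i * head_dim + d] * k[j * head_dim + d];
            }
            score[j] = s * scale;
            max_s = std::max(max_s, score[j]);
        }
        float denom = 0.f;
        for (int j = 0; j < limit; ++j) {
            score[j] = std::exp(score[j] - max_s); // numerically stable softmax
            denom += score[j];
        }
        for (int j = 0; j < limit; ++j) {
            for (int d = 0; d < head_dim; ++d) {
                out[i * head_dim + d] += (score[j] / denom) * v[j * head_dim + d];
            }
        }
    }
    return out;
}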
@@ -5,16 +5,16 @@
namespace infinicore::op {

INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Gemm);

-Gemm::Gemm(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
+Gemm::Gemm(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta) {
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
    INFINICORE_GRAPH_OP_DISPATCH(c->device().getType(), c, a, b, alpha, beta);
}

-void Gemm::execute(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
+void Gemm::execute(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta) {
    INFINICORE_GRAPH_OP_RECORD_OR_RUN(Gemm, c, a, b, alpha, beta);
}

-Tensor gemm(Tensor a, Tensor b, float alpha, float beta) {
+Tensor gemm(const Tensor &a, const Tensor &b, float alpha, float beta) {
    Shape shape = a->shape();
    Size size = a->ndim();
    shape[size - 1] = b->size(size - 1);
@@ -23,7 +23,7 @@ Tensor gemm(Tensor a, Tensor b, float alpha, float beta) {
    return c;
}

-void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
+void gemm_(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta) {
    Gemm::execute(c, a, b, alpha, beta);
}
...