Commit dbe08e9b authored by yuguo960516yuguo's avatar yuguo960516yuguo
Browse files

2.4.2

parent b5499578
...@@ -28,11 +28,13 @@ MLUContext::MLUContext(const MLUPlace& place, const int priority) { ...@@ -28,11 +28,13 @@ MLUContext::MLUContext(const MLUPlace& place, const int priority) {
MLUDeviceGuard guard(place_.device); MLUDeviceGuard guard(place_.device);
stream_.reset(new stream::MLUStream(place_, priority)); stream_.reset(new stream::MLUStream(place_, priority));
InitCNNLContext(); InitCNNLContext();
InitMLUOPContext();
} }
MLUContext::~MLUContext() { MLUContext::~MLUContext() {
MLUDeviceGuard guard(place_.device); MLUDeviceGuard guard(place_.device);
DestoryCNNLContext(); DestoryCNNLContext();
DestoryMLUOPContext();
} }
MLUDeviceContext::MLUDeviceContext(MLUPlace place) : place_(place) { MLUDeviceContext::MLUDeviceContext(MLUPlace place) : place_(place) {
...@@ -41,6 +43,7 @@ MLUDeviceContext::MLUDeviceContext(MLUPlace place) : place_(place) { ...@@ -41,6 +43,7 @@ MLUDeviceContext::MLUDeviceContext(MLUPlace place) : place_(place) {
driver_version_ = GetMLUDriverVersion(place_.device); driver_version_ = GetMLUDriverVersion(place_.device);
runtime_version_ = GetMLURuntimeVersion(place_.device); runtime_version_ = GetMLURuntimeVersion(place_.device);
cnnl_version_ = GetMLUCnnlVersion(place_.device); cnnl_version_ = GetMLUCnnlVersion(place_.device);
mluOp_version_ = GetMLUOpVersion(place_.device);
LOG_FIRST_N(WARNING, 1) LOG_FIRST_N(WARNING, 1)
<< "Please NOTE: device: " << static_cast<int>(place_.device) << "Please NOTE: device: " << static_cast<int>(place_.device)
...@@ -51,7 +54,9 @@ MLUDeviceContext::MLUDeviceContext(MLUPlace place) : place_(place) { ...@@ -51,7 +54,9 @@ MLUDeviceContext::MLUDeviceContext(MLUPlace place) : place_(place) {
<< ", Runtime API Version: " << runtime_version_ / 10000 << "." << ", Runtime API Version: " << runtime_version_ / 10000 << "."
<< (runtime_version_ / 100) % 100 << "." << runtime_version_ % 100 << (runtime_version_ / 100) % 100 << "." << runtime_version_ % 100
<< ", Cnnl API Version: " << cnnl_version_ / 10000 << "." << ", Cnnl API Version: " << cnnl_version_ / 10000 << "."
<< (cnnl_version_ / 100) % 100 << "." << cnnl_version_ % 100; << (cnnl_version_ / 100) % 100 << "." << cnnl_version_ % 100
<< ", MluOp API Version: " << mluOp_version_ / 10000 << "."
<< (mluOp_version_ / 100) % 100 << "." << mluOp_version_ % 100;
default_ctx_.reset(new MLUContext(place_)); default_ctx_.reset(new MLUContext(place_));
} }
...@@ -70,6 +75,10 @@ mluCnnlHandle MLUDeviceContext::cnnl_handle() const { ...@@ -70,6 +75,10 @@ mluCnnlHandle MLUDeviceContext::cnnl_handle() const {
return context()->CnnlHandle(); return context()->CnnlHandle();
} }
// Returns the mluOp handle of the context obtained from context().
mluOpHandle MLUDeviceContext::mluOp_handle() const {
  mluOpHandle handle = context()->MluOpHandle();
  return handle;
}
mluStream MLUDeviceContext::stream() const { return context()->RawStream(); } mluStream MLUDeviceContext::stream() const { return context()->RawStream(); }
#endif #endif
......
...@@ -53,12 +53,19 @@ class MLUContext { ...@@ -53,12 +53,19 @@ class MLUContext {
const mluCnnlHandle& CnnlHandle() const { return cnnl_handle_; } const mluCnnlHandle& CnnlHandle() const { return cnnl_handle_; }
/*! \brief Return the mluOp handle stored in this context. */
const mluOpHandle& MluOpHandle() const {
  return mluOp_handle_;
}
private: private:
void InitCNNLContext() { void InitCNNLContext() {
PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreate(&cnnl_handle_)); PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreate(&cnnl_handle_));
PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetQueue(cnnl_handle_, RawStream())); PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetQueue(cnnl_handle_, RawStream()));
} }
// Creates the mluOp handle and attaches this context's raw stream to it.
// Either step raises through PADDLE_ENFORCE_MLU_SUCCESS on failure.
void InitMLUOPContext() {
  PADDLE_ENFORCE_MLU_SUCCESS(mluOpCreate(&mluOp_handle_));
  // Looked up after creation, matching the original call order.
  auto queue = RawStream();
  PADDLE_ENFORCE_MLU_SUCCESS(mluOpSetQueue(mluOp_handle_, queue));
}
void DestoryCNNLContext() { void DestoryCNNLContext() {
if (cnnl_handle_) { if (cnnl_handle_) {
PADDLE_ENFORCE_MLU_SUCCESS(cnnlDestroy(cnnl_handle_)); PADDLE_ENFORCE_MLU_SUCCESS(cnnlDestroy(cnnl_handle_));
...@@ -66,10 +73,18 @@ class MLUContext { ...@@ -66,10 +73,18 @@ class MLUContext {
cnnl_handle_ = nullptr; cnnl_handle_ = nullptr;
} }
// Destroys the mluOp handle if one is held, then clears the member.
void DestoryMLUOPContext() {
  if (mluOp_handle_ == nullptr) {
    // Nothing to release; the handle is already cleared.
    return;
  }
  PADDLE_ENFORCE_MLU_SUCCESS(mluOpDestroy(mluOp_handle_));
  mluOp_handle_ = nullptr;
}
MLUPlace place_; MLUPlace place_;
std::unique_ptr<Eigen::DefaultDevice> eigen_device_; std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
std::unique_ptr<stream::MLUStream> stream_; std::unique_ptr<stream::MLUStream> stream_;
mluCnnlHandle cnnl_handle_; mluCnnlHandle cnnl_handle_;
mluOpHandle mluOp_handle_;
DISABLE_COPY_AND_ASSIGN(MLUContext); DISABLE_COPY_AND_ASSIGN(MLUContext);
}; };
...@@ -89,6 +104,9 @@ class MLUDeviceContext : public DeviceContext { ...@@ -89,6 +104,9 @@ class MLUDeviceContext : public DeviceContext {
/*! \brief Return cnnl handle in the device context. */ /*! \brief Return cnnl handle in the device context. */
mluCnnlHandle cnnl_handle() const; mluCnnlHandle cnnl_handle() const;
/*! \brief Return mluOp handle in the device context. */
mluOpHandle mluOp_handle() const;
/*! \brief Return mlu stream in the device context. */ /*! \brief Return mlu stream in the device context. */
mluStream stream() const; mluStream stream() const;
...@@ -135,6 +153,7 @@ class MLUDeviceContext : public DeviceContext { ...@@ -135,6 +153,7 @@ class MLUDeviceContext : public DeviceContext {
int driver_version_; int driver_version_;
int runtime_version_; int runtime_version_;
int cnnl_version_; int cnnl_version_;
int mluOp_version_;
MLUPlace place_; MLUPlace place_;
std::shared_ptr<MLUContext> default_ctx_; std::shared_ptr<MLUContext> default_ctx_;
......
...@@ -41,6 +41,7 @@ struct MLUStatusType {}; ...@@ -41,6 +41,7 @@ struct MLUStatusType {};
DEFINE_MLU_STATUS_TYPE(cnrtStatus, cnrtSuccess, CNRT); DEFINE_MLU_STATUS_TYPE(cnrtStatus, cnrtSuccess, CNRT);
DEFINE_MLU_STATUS_TYPE(cnnlStatus, CNNL_STATUS_SUCCESS, CNNL); DEFINE_MLU_STATUS_TYPE(cnnlStatus, CNNL_STATUS_SUCCESS, CNNL);
DEFINE_MLU_STATUS_TYPE(mluOpStatus, MLUOP_STATUS_SUCCESS, MLUOP);
DEFINE_MLU_STATUS_TYPE(cnStatus, CN_SUCCESS, CN); DEFINE_MLU_STATUS_TYPE(cnStatus, CN_SUCCESS, CN);
#ifdef PADDLE_WITH_CNCL #ifdef PADDLE_WITH_CNCL
DEFINE_MLU_STATUS_TYPE(cnclStatus, CNCL_RET_SUCCESS, CNCL); DEFINE_MLU_STATUS_TYPE(cnclStatus, CNCL_RET_SUCCESS, CNCL);
...@@ -68,6 +69,15 @@ inline std::string build_mlu_error_msg(cnnlStatus stat) { ...@@ -68,6 +69,15 @@ inline std::string build_mlu_error_msg(cnnlStatus stat) {
return sout.str(); return sout.str();
} }
/*************** MLU OP ERROR ***************/
// True when an mluOp call returned anything other than MLUOP_STATUS_SUCCESS.
inline bool is_error(mluOpStatus stat) {
  return !(stat == MLUOP_STATUS_SUCCESS);
}
// Formats an mluOp status as "MLU OP error(<status>), <error string>. ".
inline std::string build_mlu_error_msg(mluOpStatus stat) {
  std::ostringstream message;
  message << "MLU OP error(" << stat << "), ";
  message << mluOpGetErrorString(stat) << ". ";
  return message.str();
}
/*************** CN API ERROR ***************/ /*************** CN API ERROR ***************/
inline bool is_error(cnStatus stat) { return stat != CN_SUCCESS; } inline bool is_error(cnStatus stat) { return stat != CN_SUCCESS; }
......
...@@ -126,6 +126,13 @@ int GetMLUCnnlVersion(int id) { ...@@ -126,6 +126,13 @@ int GetMLUCnnlVersion(int id) {
return x * 10000 + y * 100 + z; return x * 10000 + y * 100 + z;
} }
// Returns the mluOp library version packed into one integer as
// first * 10000 + second * 100 + third (presumably major/minor/patch).
// The device id is validated first via CheckDeviceId.
int GetMLUOpVersion(int id) {
  CheckDeviceId(id);
  int ver_first = 0;
  int ver_second = 0;
  int ver_third = 0;
  mluOpGetLibVersion(&ver_first, &ver_second, &ver_third);
  return ver_first * 10000 + ver_second * 100 + ver_third;
}
int GetMLUCurrentDeviceId() { int GetMLUCurrentDeviceId() {
int device_id; int device_id;
PADDLE_ENFORCE_MLU_SUCCESS(cnrtGetDevice(&device_id)); PADDLE_ENFORCE_MLU_SUCCESS(cnrtGetDevice(&device_id));
......
...@@ -16,10 +16,11 @@ limitations under the License. */ ...@@ -16,10 +16,11 @@ limitations under the License. */
#ifdef PADDLE_WITH_MLU #ifdef PADDLE_WITH_MLU
#include <cn_api.h> #include <cn_api.h>
#include <cndrv_id.h>
#include <cnnl.h> #include <cnnl.h>
#include <cnpapi.h> #include <cnpapi.h>
#include <cnpapi_cndrv_id.h>
#include <cnrt.h> #include <cnrt.h>
#include <mlu_op.h>
#ifdef PADDLE_WITH_CNCL #ifdef PADDLE_WITH_CNCL
#include <cncl.h> #include <cncl.h>
#endif #endif
...@@ -30,11 +31,13 @@ namespace paddle { ...@@ -30,11 +31,13 @@ namespace paddle {
using cnStatus = CNresult; using cnStatus = CNresult;
using cnrtStatus = cnrtRet_t; using cnrtStatus = cnrtRet_t;
using cnnlStatus = cnnlStatus_t; using cnnlStatus = cnnlStatus_t;
using mluOpStatus = mluOpStatus_t;
#ifdef PADDLE_WITH_CNCL #ifdef PADDLE_WITH_CNCL
using cnclStatus = cnclResult_t; using cnclStatus = cnclResult_t;
#endif #endif
using mluStream = cnrtQueue_t; using mluStream = cnrtQueue_t;
using mluCnnlHandle = cnnlHandle_t; using mluCnnlHandle = cnnlHandle_t;
using mluOpHandle = mluOpHandle_t;
using mluEventHandle = cnrtNotifier_t; using mluEventHandle = cnrtNotifier_t;
using mluDeviceHandle = CNdev; using mluDeviceHandle = CNdev;
...@@ -49,6 +52,9 @@ int GetMLURuntimeVersion(int id); ...@@ -49,6 +52,9 @@ int GetMLURuntimeVersion(int id);
//! Get the cnnl version of the ith MLU. //! Get the cnnl version of the ith MLU.
int GetMLUCnnlVersion(int id); int GetMLUCnnlVersion(int id);
//! Get the mluOp version of the ith MLU.
int GetMLUOpVersion(int id);
//! Get the total number of MLU devices in system. //! Get the total number of MLU devices in system.
int GetMLUDeviceCount(); int GetMLUDeviceCount();
......
...@@ -255,7 +255,7 @@ bool CUDADeviceCode::Compile(bool include_path) { ...@@ -255,7 +255,7 @@ bool CUDADeviceCode::Compile(bool include_path) {
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>( auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(
DeviceContextPool::Instance().Get(place_)); DeviceContextPool::Instance().Get(place_));
int compute_capability = dev_ctx->GetComputeCapability(); int compute_capability = dev_ctx->GetComputeCapability();
std::vector<const char*> options = {"-std=c++11", "--amdgpu-target=gfx906"}; std::vector<const char*> options = {"-std=c++11", "--amdgpu-target=gfx906", "--amdgpu-target=gfx926"};
std::string include_option; std::string include_option;
if (include_path) { if (include_path) {
std::string cuda_include_path = FindCUDAIncludePath(); std::string cuda_include_path = FindCUDAIncludePath();
......
...@@ -301,7 +301,8 @@ class MatMulV2MKLDNNHandler ...@@ -301,7 +301,8 @@ class MatMulV2MKLDNNHandler
out_strides[i] = out_ddims[i + 1] * out_strides[i + 1]; out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
} }
if (!IsInt8<OT>() && !IsBfloat16<OT>() && is_output_fused) { // TODO(jczaja): Why not for int8??
if (!IsInt8<OT>() && is_output_fused) {
out_strides = FakeTransposeStrides(out_ddims); out_strides = FakeTransposeStrides(out_ddims);
} }
......
...@@ -29,7 +29,10 @@ ...@@ -29,7 +29,10 @@
#include "paddle/fluid/platform/profiler/custom_device/custom_tracer.h" #include "paddle/fluid/platform/profiler/custom_device/custom_tracer.h"
#include "paddle/fluid/platform/profiler/extra_info.h" #include "paddle/fluid/platform/profiler/extra_info.h"
#include "paddle/fluid/platform/profiler/host_tracer.h" #include "paddle/fluid/platform/profiler/host_tracer.h"
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/enforce.h"
#include "paddle/fluid/platform/profiler/mlu/mlu_tracer.h" #include "paddle/fluid/platform/profiler/mlu/mlu_tracer.h"
#endif
#include "paddle/fluid/platform/profiler/trace_event_collector.h" #include "paddle/fluid/platform/profiler/trace_event_collector.h"
#include "paddle/fluid/platform/profiler/utils.h" #include "paddle/fluid/platform/profiler/utils.h"
...@@ -80,9 +83,11 @@ Profiler::Profiler(const ProfilerOptions& options, ...@@ -80,9 +83,11 @@ Profiler::Profiler(const ProfilerOptions& options,
if (trace_switch.test(kProfileGPUOptionBit)) { if (trace_switch.test(kProfileGPUOptionBit)) {
tracers_.emplace_back(&CudaTracer::GetInstance(), false); tracers_.emplace_back(&CudaTracer::GetInstance(), false);
} }
#ifdef PADDLE_WITH_MLU
if (trace_switch.test(kProfileMLUOptionBit)) { if (trace_switch.test(kProfileMLUOptionBit)) {
tracers_.emplace_back(&MluTracer::GetInstance(), false); tracers_.emplace_back(&MluTracer::GetInstance(), false);
} }
#endif
if (trace_switch.test(kProfileCustomDeviceOptionBit)) { if (trace_switch.test(kProfileCustomDeviceOptionBit)) {
for (const auto& dev_type : custom_device_types) { for (const auto& dev_type : custom_device_types) {
tracers_.emplace_back(&CustomTracer::GetInstance(dev_type), false); tracers_.emplace_back(&CustomTracer::GetInstance(dev_type), false);
......
pybind.h
op_function1.cc
op_function2.cc
op_function3.cc
op_function4.cc
op_function5.cc
op_function6.cc
op_function7.cc
op_function8.cc
eager_op_function.cc
eager_legacy_op_function.cc
...@@ -26,9 +26,9 @@ static PyObject *eager_api_run_program(PyObject *self, ...@@ -26,9 +26,9 @@ static PyObject *eager_api_run_program(PyObject *self,
PyObject *kwargs) { PyObject *kwargs) {
PyThreadState *tstate = nullptr; PyThreadState *tstate = nullptr;
try { try {
auto X = GetTensorListFromArgs("run_program", "X", args, 0, false); auto X = GetTensorListFromArgs("run_program", "X", args, 0, true);
auto Params = GetTensorListFromArgs("run_program", "Params", args, 1, true); auto Params = GetTensorListFromArgs("run_program", "Params", args, 1, true);
auto Out = GetTensorPtrListFromArgs("run_program", "Out", args, 2, false); auto Out = GetTensorPtrListFromArgs("run_program", "Out", args, 2, true);
auto OutScope = auto OutScope =
GetScopePtrListFromArgs("run_program", "OutScope", args, 3, false); GetScopePtrListFromArgs("run_program", "OutScope", args, 3, false);
auto DOut = GetTensorPtrListFromArgs("run_program", "DOut", args, 4, true); auto DOut = GetTensorPtrListFromArgs("run_program", "DOut", args, 4, true);
......
...@@ -642,7 +642,8 @@ void BindAnalysisConfig(py::module *m) { ...@@ -642,7 +642,8 @@ void BindAnalysisConfig(py::module *m) {
.def("enable_use_gpu", .def("enable_use_gpu",
&AnalysisConfig::EnableUseGpu, &AnalysisConfig::EnableUseGpu,
py::arg("memory_pool_init_size_mb"), py::arg("memory_pool_init_size_mb"),
py::arg("device_id") = 0) py::arg("device_id") = 0,
py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
.def("set_exec_stream", .def("set_exec_stream",
[](AnalysisConfig &self, phi::CUDAStream &stream) { [](AnalysisConfig &self, phi::CUDAStream &stream) {
......
...@@ -472,23 +472,16 @@ void BindTensor(pybind11::module &m) { // NOLINT ...@@ -472,23 +472,16 @@ void BindTensor(pybind11::module &m) { // NOLINT
print(t.shape()) # [5, 30] print(t.shape()) # [5, 30]
)DOC") )DOC")
.def("_to_dlpack", .def("_to_dlpack",
[](framework::Tensor &self) { [](phi::DenseTensor &self) {
DLPackTensor dlpack_tensor(self, 1); DLManagedTensor *dmt = framework::toDLPack(self);
DLManagedTensor *dmt = dlpack_tensor.ToDLManagedTensor(); auto capsule = pybind11::capsule(
auto capsule = py::capsule(
static_cast<void *>(dmt), "dltensor", [](PyObject *ptr) { static_cast<void *>(dmt), "dltensor", [](PyObject *ptr) {
if (ptr) { if (!PyCapsule_IsValid(ptr, "dltensor")) {
auto dltensor = new DLManagedTensor; return;
try {
dltensor = reinterpret_cast<DLManagedTensor *>(
PyCapsule_GetPointer(ptr, "used_dltensor"));
return;
} catch (...) {
dltensor = reinterpret_cast<DLManagedTensor *>(
PyCapsule_GetPointer(ptr, "dltensor"));
}
dltensor->deleter(dltensor);
} }
DLManagedTensor *dmt = static_cast<DLManagedTensor *>(
PyCapsule_GetPointer(ptr, "dltensor"));
dmt->deleter(dmt);
}); });
return capsule; return capsule;
}) })
......
...@@ -69,10 +69,16 @@ bool ProtoArgumentMappingContext::IsDenseTensorInputs( ...@@ -69,10 +69,16 @@ bool ProtoArgumentMappingContext::IsDenseTensorInputs(
return true; return true;
} }
// Always reports false, regardless of the input name queried.
bool ProtoArgumentMappingContext::IsSelectedRowsInputs(
    const std::string& /*name*/) const {
  return false;
}
bool ProtoArgumentMappingContext::IsSelectedRowsInput( bool ProtoArgumentMappingContext::IsSelectedRowsInput(
const std::string& name) const { const std::string& name) const {
return false; return false;
} }
bool ProtoArgumentMappingContext::IsDenseTensorVectorInput( bool ProtoArgumentMappingContext::IsDenseTensorVectorInput(
const std::string& name) const { const std::string& name) const {
return false; return false;
......
...@@ -45,6 +45,7 @@ class ProtoArgumentMappingContext : public ::phi::ArgumentMappingContext { ...@@ -45,6 +45,7 @@ class ProtoArgumentMappingContext : public ::phi::ArgumentMappingContext {
bool IsDenseTensorInput(const std::string& name) const override; bool IsDenseTensorInput(const std::string& name) const override;
bool IsDenseTensorInputs(const std::string& name) const override; bool IsDenseTensorInputs(const std::string& name) const override;
bool IsSelectedRowsInput(const std::string& name) const override; bool IsSelectedRowsInput(const std::string& name) const override;
bool IsSelectedRowsInputs(const std::string& name) const override;
bool IsDenseTensorVectorInput(const std::string& name) const override; bool IsDenseTensorVectorInput(const std::string& name) const override;
bool IsDenseTensorOutput(const std::string& name) const override; bool IsDenseTensorOutput(const std::string& name) const override;
......
.DS_Store
.idea
*.log
tmp/
Output
.DS_Store
.idea
*.log
tmp/
tensor_map.mlir
...@@ -34,6 +34,95 @@ namespace experimental { ...@@ -34,6 +34,95 @@ namespace experimental {
////////////////// Forward api impls ////////////////////// ////////////////// Forward api impls //////////////////////
// Hand-written forward API for add_n.
//
// Picks the kernel key (backend, layout, dtype) from the inputs, then
// dispatches to the "add_n" kernel when at least one input is a
// phi::DenseTensor, or to "add_n_sr" when none of them is. The result is
// written into a freshly created Tensor, which is returned.
Tensor add_n_impl(const std::vector<Tensor>& x) {
  // The kernel key is always taken from the inputs; there is no caller
  // override for backend, layout or dtype in this API.
  auto kernel_key_set = ParseKernelKeyByInputArgs(x);
  auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
  const Backend kernel_backend = kernel_key.backend();
  const DataLayout kernel_layout = kernel_key.layout();
  const DataType kernel_data_type = kernel_key.dtype();

  // The selected-rows kernel is used only when no input is a DenseTensor.
  bool has_dense_input = false;
  for (const auto& tensor : x) {
    if (phi::DenseTensor::classof(tensor.impl().get())) {
      has_dense_input = true;
      break;
    }
  }
  const bool is_sr_kernel = !has_dense_input;
  const std::string kernel_name = is_sr_kernel ? "add_n_sr" : "add_n";

  VLOG(6) << "add_n API kernel key: [" << kernel_backend << ", "
          << kernel_layout << ", " << kernel_data_type << "]";
  auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      kernel_name, {kernel_backend, kernel_layout, kernel_data_type});
  const auto& kernel = kernel_result.kernel;
  VLOG(6) << kernel_name << " kernel: " << kernel;

  // Run on the CPU context when kernel selection fell back to CPU.
  const Backend ctx_backend =
      kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend;
  auto* dev_ctx = GetDeviceContextByBackend(ctx_backend);

  // InferMeta takes pointers; these point into the caller-owned vector.
  const auto to_meta_ptrs = [](const std::vector<phi::MetaTensor>& metas) {
    std::vector<const phi::MetaTensor*> ptrs;
    ptrs.reserve(metas.size());
    for (const auto& meta : metas) {
      ptrs.push_back(&meta);
    }
    return ptrs;
  };

  Tensor api_output;

  if (!is_sr_kernel) {
    // Dense path: inputs are passed as TensorBase pointers, output is dense.
    std::vector<const phi::TensorBase*> input_x;
    input_x.reserve(x.size());
    for (const auto& tensor : x) {
      input_x.push_back(tensor.impl().get());
    }

    auto x_meta_vec = MakeMetaTensor(input_x);
    const auto x_metas = to_meta_ptrs(x_meta_vec);

    auto kernel_out = SetKernelOutput(&api_output);
    phi::MetaTensor meta_out(kernel_out);
    phi::AddNInferMeta(x_metas, &meta_out);

    using dense_kernel_signature =
        void (*)(const platform::DeviceContext&,
                 const std::vector<const phi::TensorBase*>&,
                 phi::DenseTensor*);
    auto* kernel_fn = kernel.GetVariadicKernelFn<dense_kernel_signature>();
    (*kernel_fn)(*dev_ctx, input_x, kernel_out);
    return api_output;
  }

  // Selected-rows path: every input impl is treated as phi::SelectedRows.
  std::vector<const phi::SelectedRows*> input_x;
  input_x.reserve(x.size());
  for (const auto& tensor : x) {
    input_x.push_back(static_cast<phi::SelectedRows*>(tensor.impl().get()));
  }

  auto x_meta_vec = MakeMetaTensor(input_x);
  const auto x_metas = to_meta_ptrs(x_meta_vec);

  auto kernel_out = SetSelectedRowsKernelOutput(&api_output);
  phi::MetaTensor meta_out(kernel_out);
  phi::AddNInferMeta(x_metas, &meta_out);

  using sr_kernel_signature =
      void (*)(const platform::DeviceContext&,
               const std::vector<const phi::SelectedRows*>&,
               phi::SelectedRows*);
  auto* kernel_fn = kernel.GetVariadicKernelFn<sr_kernel_signature>();
  (*kernel_fn)(*dev_ctx, input_x, kernel_out);

  return api_output;
}
Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) { Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) {
Tensor out; Tensor out;
copy(x, place, blocking, &out); copy(x, place, blocking, &out);
......
...@@ -31,6 +31,8 @@ namespace experimental { ...@@ -31,6 +31,8 @@ namespace experimental {
////////////////// Forward api impls ////////////////////// ////////////////// Forward api impls //////////////////////
Tensor add_n_impl(const std::vector<Tensor>& x);
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl( std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl(
const Tensor& x, const Tensor& x,
const Tensor& scale, const Tensor& scale,
......
...@@ -98,6 +98,16 @@ phi::MetaTensor MakeMetaTensor(const phi::TensorBase& tensor) { ...@@ -98,6 +98,16 @@ phi::MetaTensor MakeMetaTensor(const phi::TensorBase& tensor) {
return phi::MetaTensor(tensor); return phi::MetaTensor(tensor);
} }
// Builds one phi::MetaTensor per entry of `tensors`, in the same order.
// Each pointer is dereferenced, so every entry must be non-null.
std::vector<phi::MetaTensor> MakeMetaTensor(
    const std::vector<const phi::TensorBase*>& tensors) {
  const size_t count = tensors.size();
  std::vector<phi::MetaTensor> result;
  result.reserve(count);
  for (size_t idx = 0; idx < count; ++idx) {
    result.emplace_back(*tensors[idx]);
  }
  return result;
}
phi::MetaTensor MakeMetaTensor( phi::MetaTensor MakeMetaTensor(
const paddle::optional<phi::DenseTensor>& tensor) { const paddle::optional<phi::DenseTensor>& tensor) {
if (tensor) { if (tensor) {
...@@ -116,6 +126,16 @@ std::vector<phi::MetaTensor> MakeMetaTensor( ...@@ -116,6 +126,16 @@ std::vector<phi::MetaTensor> MakeMetaTensor(
return meta_tensors; return meta_tensors;
} }
// Builds one phi::MetaTensor per phi::SelectedRows in `tensors`, in the same
// order. Each pointer is dereferenced, so every entry must be non-null.
std::vector<phi::MetaTensor> MakeMetaTensor(
    const std::vector<const phi::SelectedRows*>& tensors) {
  const size_t count = tensors.size();
  std::vector<phi::MetaTensor> result;
  result.reserve(count);
  for (size_t idx = 0; idx < count; ++idx) {
    result.emplace_back(*tensors[idx]);
  }
  return result;
}
std::vector<phi::MetaTensor> MakeMetaTensor( std::vector<phi::MetaTensor> MakeMetaTensor(
const std::vector<phi::DenseTensor*>& tensors) { const std::vector<phi::DenseTensor*>& tensors) {
std::vector<phi::MetaTensor> meta_tensors; std::vector<phi::MetaTensor> meta_tensors;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment