OpenDAS / Paddle · Commits

Commit dbe08e9b
Authored Jun 12, 2023 by yuguo960516yuguo
Message: 2.4.2
Parent: b5499578

Changes: 302 files in this commit. This page shows the first 20 changed files, with 215 additions and 22 deletions (+215 -22).
Files on this page:

paddle/fluid/platform/device/mlu/device_context.cc (+10, -1)
paddle/fluid/platform/device/mlu/device_context.h (+19, -0)
paddle/fluid/platform/device/mlu/enforce.h (+10, -0)
paddle/fluid/platform/device/mlu/mlu_info.cc (+7, -0)
paddle/fluid/platform/device/mlu/mlu_info.h (+7, -1)
paddle/fluid/platform/device_code.cc (+1, -1)
paddle/fluid/platform/mkldnn_reuse.h (+2, -1)
paddle/fluid/platform/profiler/profiler.cc (+5, -0)
paddle/fluid/pybind/.gitignore (+11, -0)
paddle/fluid/pybind/eager_legacy_custom_python_api.h (+2, -2)
paddle/fluid/pybind/inference_api.cc (+2, -1)
paddle/fluid/pybind/tensor.cc (+8, -15)
paddle/infrt/api/.gitignore (+1, -0)
paddle/infrt/dialect/phi/pass/proto_arg_map_context.cc (+6, -0)
paddle/infrt/dialect/phi/pass/proto_arg_map_context.h (+1, -0)
paddle/infrt/tests/.gitignore (+7, -0)
paddle/infrt/tests/dialect/tensor/.gitignore (+5, -0)
paddle/phi/api/lib/api_custom_impl.cc (+89, -0)
paddle/phi/api/lib/api_custom_impl.h (+2, -0)
paddle/phi/api/lib/api_gen_utils.cc (+20, -0)
paddle/fluid/platform/device/mlu/device_context.cc (view file @ dbe08e9b)

```diff
@@ -28,11 +28,13 @@ MLUContext::MLUContext(const MLUPlace& place, const int priority) {
   MLUDeviceGuard guard(place_.device);
   stream_.reset(new stream::MLUStream(place_, priority));
   InitCNNLContext();
+  InitMLUOPContext();
 }

 MLUContext::~MLUContext() {
   MLUDeviceGuard guard(place_.device);
   DestoryCNNLContext();
+  DestoryMLUOPContext();
 }

 MLUDeviceContext::MLUDeviceContext(MLUPlace place) : place_(place) {
@@ -41,6 +43,7 @@ MLUDeviceContext::MLUDeviceContext(MLUPlace place) : place_(place) {
   driver_version_ = GetMLUDriverVersion(place_.device);
   runtime_version_ = GetMLURuntimeVersion(place_.device);
   cnnl_version_ = GetMLUCnnlVersion(place_.device);
+  mluOp_version_ = GetMLUOpVersion(place_.device);
   LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: "
                           << static_cast<int>(place_.device)
@@ -51,7 +54,9 @@ MLUDeviceContext::MLUDeviceContext(MLUPlace place) : place_(place) {
       << ", Runtime API Version: " << runtime_version_ / 10000 << "."
       << (runtime_version_ / 100) % 100 << "." << runtime_version_ % 100
       << ", Cnnl API Version: " << cnnl_version_ / 10000 << "."
-      << (cnnl_version_ / 100) % 100 << "." << cnnl_version_ % 100;
+      << (cnnl_version_ / 100) % 100 << "." << cnnl_version_ % 100
+      << ", MluOp API Version: " << mluOp_version_ / 10000 << "."
+      << (mluOp_version_ / 100) % 100 << "." << mluOp_version_ % 100;
   default_ctx_.reset(new MLUContext(place_));
 }
@@ -70,6 +75,10 @@ mluCnnlHandle MLUDeviceContext::cnnl_handle() const {
   return context()->CnnlHandle();
 }

+mluOpHandle MLUDeviceContext::mluOp_handle() const {
+  return context()->MluOpHandle();
+}
+
 mluStream MLUDeviceContext::stream() const { return context()->RawStream(); }
 #endif
```
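This commit threads Cambricon's mluOp library through the MLU device context alongside the existing CNNL plumbing. As a sketch of how a kernel would consume the new accessor (the function shape and the commented-out op launch are illustrative assumptions, not part of this commit):

```cpp
// Hedged sketch: only mluOp_handle() and stream() are introduced by this
// commit; everything else here is assumed for illustration.
void SomeMluOpKernelSketch(const platform::MLUDeviceContext& dev_ctx) {
  mluOpHandle handle = dev_ctx.mluOp_handle();  // cnnl_handle()-style accessor
  mluStream queue = dev_ctx.stream();  // the queue the handle was bound to
  // ... build mluOp tensor descriptors for inputs/outputs, then launch:
  // PADDLE_ENFORCE_MLU_SUCCESS(mluOpSomeOp(handle, /* descs and ptrs */));
  (void)handle;
  (void)queue;
}
```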
paddle/fluid/platform/device/mlu/device_context.h (view file @ dbe08e9b)

```diff
@@ -53,12 +53,19 @@ class MLUContext {
   const mluCnnlHandle& CnnlHandle() const { return cnnl_handle_; }

+  const mluOpHandle& MluOpHandle() const { return mluOp_handle_; }
+
  private:
   void InitCNNLContext() {
     PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreate(&cnnl_handle_));
     PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetQueue(cnnl_handle_, RawStream()));
   }

+  void InitMLUOPContext() {
+    PADDLE_ENFORCE_MLU_SUCCESS(mluOpCreate(&mluOp_handle_));
+    PADDLE_ENFORCE_MLU_SUCCESS(mluOpSetQueue(mluOp_handle_, RawStream()));
+  }
+
   void DestoryCNNLContext() {
     if (cnnl_handle_) {
       PADDLE_ENFORCE_MLU_SUCCESS(cnnlDestroy(cnnl_handle_));
@@ -66,10 +73,18 @@ class MLUContext {
     cnnl_handle_ = nullptr;
   }

+  void DestoryMLUOPContext() {
+    if (mluOp_handle_) {
+      PADDLE_ENFORCE_MLU_SUCCESS(mluOpDestroy(mluOp_handle_));
+    }
+    mluOp_handle_ = nullptr;
+  }
+
   MLUPlace place_;
   std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
   std::unique_ptr<stream::MLUStream> stream_;
   mluCnnlHandle cnnl_handle_;
+  mluOpHandle mluOp_handle_;
   DISABLE_COPY_AND_ASSIGN(MLUContext);
 };
@@ -89,6 +104,9 @@ class MLUDeviceContext : public DeviceContext {
   /*! \brief Return cnnl handle in the device context. */
   mluCnnlHandle cnnl_handle() const;

+  /*! \brief Return mluOp handle in the device context. */
+  mluOpHandle mluOp_handle() const;
+
   /*! \brief Return mlu stream in the device context. */
   mluStream stream() const;
@@ -135,6 +153,7 @@ class MLUDeviceContext : public DeviceContext {
   int driver_version_;
   int runtime_version_;
   int cnnl_version_;
+  int mluOp_version_;
   MLUPlace place_;
   std::shared_ptr<MLUContext> default_ctx_;
```
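The new members mirror the CNNL pattern line for line: the handle is created in InitMLUOPContext() when the MLUContext is constructed, bound to the context's queue via mluOpSetQueue(mluOp_handle_, RawStream()), and released in DestoryMLUOPContext() from the destructor (the "Destory" spelling matches the pre-existing CNNL counterpart). Each MLUContext therefore owns exactly one mluOp handle tied to its own stream.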
paddle/fluid/platform/device/mlu/enforce.h (view file @ dbe08e9b)

```diff
@@ -41,6 +41,7 @@ struct MLUStatusType {};
 DEFINE_MLU_STATUS_TYPE(cnrtStatus, cnrtSuccess, CNRT);
 DEFINE_MLU_STATUS_TYPE(cnnlStatus, CNNL_STATUS_SUCCESS, CNNL);
+DEFINE_MLU_STATUS_TYPE(mluOpStatus, MLUOP_STATUS_SUCCESS, MLUOP);
 DEFINE_MLU_STATUS_TYPE(cnStatus, CN_SUCCESS, CN);
 #ifdef PADDLE_WITH_CNCL
 DEFINE_MLU_STATUS_TYPE(cnclStatus, CNCL_RET_SUCCESS, CNCL);
@@ -68,6 +69,15 @@ inline std::string build_mlu_error_msg(cnnlStatus stat) {
   return sout.str();
 }

+/*************** MLU OP ERROR ***************/
+inline bool is_error(mluOpStatus stat) { return stat != MLUOP_STATUS_SUCCESS; }
+
+inline std::string build_mlu_error_msg(mluOpStatus stat) {
+  std::ostringstream sout;
+  sout << "MLU OP error(" << stat << "), " << mluOpGetErrorString(stat) << ". ";
+  return sout.str();
+}
+
 /*************** CN API ERROR ***************/
 inline bool is_error(cnStatus stat) { return stat != CN_SUCCESS; }
```
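These overloads are what let the existing PADDLE_ENFORCE_MLU_SUCCESS macro accept mluOp calls: the macro resolves is_error and build_mlu_error_msg through ordinary C++ overload resolution on the status type. A simplified sketch of that dispatch pattern (an assumption about the macro's shape, not a copy of the real one, which also records file/line and throws a Paddle exception type):

```cpp
#include <stdexcept>

// Simplified stand-in for PADDLE_ENFORCE_MLU_SUCCESS: the status type of
// COND picks the matching is_error/build_mlu_error_msg overload.
#define SKETCH_ENFORCE_MLU_SUCCESS(COND)                       \
  do {                                                         \
    auto __status = (COND);                                    \
    if (is_error(__status)) {                                  \
      throw std::runtime_error(build_mlu_error_msg(__status)); \
    }                                                          \
  } while (0)

// With the mluOpStatus overloads above in scope, this now type-checks:
//   SKETCH_ENFORCE_MLU_SUCCESS(mluOpCreate(&mluOp_handle_));
```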
paddle/fluid/platform/device/mlu/mlu_info.cc (view file @ dbe08e9b)

```diff
@@ -126,6 +126,13 @@ int GetMLUCnnlVersion(int id) {
   return x * 10000 + y * 100 + z;
 }

+int GetMLUOpVersion(int id) {
+  CheckDeviceId(id);
+  int x, y, z;
+  mluOpGetLibVersion(&x, &y, &z);
+  return x * 10000 + y * 100 + z;
+}
+
 int GetMLUCurrentDeviceId() {
   int device_id;
   PADDLE_ENFORCE_MLU_SUCCESS(cnrtGetDevice(&device_id));
```
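GetMLUOpVersion packs the major.minor.patch triple as x * 10000 + y * 100 + z, the same scheme the logging code in device_context.cc unpacks with v / 10000, (v / 100) % 100, and v % 100. For example, an mluOp library reporting 0.4.2 packs to 0 * 10000 + 4 * 100 + 2 = 402, and 402 / 10000 = 0, (402 / 100) % 100 = 4, 402 % 100 = 2 recovers "0.4.2". The scheme assumes minor and patch numbers stay below 100.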
paddle/fluid/platform/device/mlu/mlu_info.h (view file @ dbe08e9b)

```diff
@@ -16,10 +16,11 @@ limitations under the License. */
 #ifdef PADDLE_WITH_MLU
 #include <cn_api.h>
 #include <cndrv_id.h>
 #include <cnnl.h>
 #include <cnpapi.h>
 #include <cnpapi_cndrv_id.h>
 #include <cnrt.h>
+#include <mlu_op.h>
 #ifdef PADDLE_WITH_CNCL
 #include <cncl.h>
 #endif
@@ -30,11 +31,13 @@ namespace paddle {
 using cnStatus = CNresult;
 using cnrtStatus = cnrtRet_t;
 using cnnlStatus = cnnlStatus_t;
+using mluOpStatus = mluOpStatus_t;
 #ifdef PADDLE_WITH_CNCL
 using cnclStatus = cnclResult_t;
 #endif
 using mluStream = cnrtQueue_t;
 using mluCnnlHandle = cnnlHandle_t;
+using mluOpHandle = mluOpHandle_t;
 using mluEventHandle = cnrtNotifier_t;
 using mluDeviceHandle = CNdev;
@@ -49,6 +52,9 @@ int GetMLURuntimeVersion(int id);
 //! Get the cnnl version of the ith MLU.
 int GetMLUCnnlVersion(int id);

+//! Get the mluOp version of the ith MLU.
+int GetMLUOpVersion(int id);
+
 //! Get the total number of MLU devices in system.
 int GetMLUDeviceCount();
```
paddle/fluid/platform/device_code.cc (view file @ dbe08e9b)

```diff
@@ -255,7 +255,7 @@ bool CUDADeviceCode::Compile(bool include_path) {
   auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(
       DeviceContextPool::Instance().Get(place_));
   int compute_capability = dev_ctx->GetComputeCapability();
-  std::vector<const char*> options = {"-std=c++11", "--amdgpu-target=gfx906"};
+  std::vector<const char*> options = {"-std=c++11", "--amdgpu-target=gfx906", "--amdgpu-target=gfx926"};
   std::string include_option;
   if (include_path) {
     std::string cuda_include_path = FindCUDAIncludePath();
```
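Despite the CUDADeviceCode name, the options vector here belongs to the HIP/ROCm compile path (--amdgpu-target is an AMD flag); the change appears to simply add gfx926 as a second target architecture next to gfx906 so that runtime-compiled kernels cover both GPUs.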
paddle/fluid/platform/mkldnn_reuse.h (view file @ dbe08e9b)

```diff
@@ -301,7 +301,8 @@ class MatMulV2MKLDNNHandler
       out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
     }

-    if (!IsInt8<OT>() && !IsBfloat16<OT>() && is_output_fused) {
+    // TODO(jczaja): Why not for int8??
+    if (!IsInt8<OT>() && is_output_fused) {
       out_strides = FakeTransposeStrides(out_ddims);
     }
```
paddle/fluid/platform/profiler/profiler.cc (view file @ dbe08e9b)

```diff
@@ -29,7 +29,10 @@
 #include "paddle/fluid/platform/profiler/custom_device/custom_tracer.h"
 #include "paddle/fluid/platform/profiler/extra_info.h"
 #include "paddle/fluid/platform/profiler/host_tracer.h"
+#ifdef PADDLE_WITH_MLU
+#include "paddle/fluid/platform/device/mlu/enforce.h"
 #include "paddle/fluid/platform/profiler/mlu/mlu_tracer.h"
+#endif
 #include "paddle/fluid/platform/profiler/trace_event_collector.h"
 #include "paddle/fluid/platform/profiler/utils.h"
@@ -80,9 +83,11 @@ Profiler::Profiler(const ProfilerOptions& options,
   if (trace_switch.test(kProfileGPUOptionBit)) {
     tracers_.emplace_back(&CudaTracer::GetInstance(), false);
   }
+#ifdef PADDLE_WITH_MLU
   if (trace_switch.test(kProfileMLUOptionBit)) {
     tracers_.emplace_back(&MluTracer::GetInstance(), false);
   }
+#endif
   if (trace_switch.test(kProfileCustomDeviceOptionBit)) {
     for (const auto& dev_type : custom_device_types) {
       tracers_.emplace_back(&CustomTracer::GetInstance(dev_type), false);
```
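Wrapping both the includes and the tracer registration in PADDLE_WITH_MLU keeps MluTracer, and its new dependency on the MLU enforce header, out of non-MLU builds entirely; before this change the registration block was compiled unconditionally.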
paddle/fluid/pybind/.gitignore (new file, mode 100644; view file @ dbe08e9b)

```
pybind.h
op_function1.cc
op_function2.cc
op_function3.cc
op_function4.cc
op_function5.cc
op_function6.cc
op_function7.cc
op_function8.cc
eager_op_function.cc
eager_legacy_op_function.cc
```
paddle/fluid/pybind/eager_legacy_custom_python_api.h (view file @ dbe08e9b)

```diff
@@ -26,9 +26,9 @@ static PyObject *eager_api_run_program(PyObject *self,
                                        PyObject *kwargs) {
   PyThreadState *tstate = nullptr;
   try {
-    auto X = GetTensorListFromArgs("run_program", "X", args, 0, false);
+    auto X = GetTensorListFromArgs("run_program", "X", args, 0, true);
     auto Params = GetTensorListFromArgs("run_program", "Params", args, 1, true);
-    auto Out = GetTensorPtrListFromArgs("run_program", "Out", args, 2, false);
+    auto Out = GetTensorPtrListFromArgs("run_program", "Out", args, 2, true);
     auto OutScope =
         GetScopePtrListFromArgs("run_program", "OutScope", args, 3, false);
     auto DOut = GetTensorPtrListFromArgs("run_program", "DOut", args, 4, true);
```
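The flipped booleans are the trailing argument of GetTensorListFromArgs / GetTensorPtrListFromArgs; by its position this reads as a "dispensable" flag (an inference from the call shape, not stated in the hunk), in which case run_program now tolerates an empty or missing X and Out rather than raising, matching how Params and DOut were already fetched with true.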
paddle/fluid/pybind/inference_api.cc (view file @ dbe08e9b)

```diff
@@ -642,7 +642,8 @@ void BindAnalysisConfig(py::module *m) {
       .def("enable_use_gpu",
            &AnalysisConfig::EnableUseGpu,
            py::arg("memory_pool_init_size_mb"),
-           py::arg("device_id") = 0)
+           py::arg("device_id") = 0,
+           py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32)
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
       .def("set_exec_stream",
            [](AnalysisConfig &self, phi::CUDAStream &stream) {
```
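On the Python side this makes a third, optional precision_mode parameter available on config.enable_use_gpu(...), defaulting to AnalysisConfig::Precision::kFloat32, so existing two-argument calls keep their behavior while new code can request a different precision (the Python enum spelling is not visible in this hunk).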
paddle/fluid/pybind/tensor.cc (view file @ dbe08e9b)

```diff
@@ -472,23 +472,16 @@ void BindTensor(pybind11::module &m) { // NOLINT
             print(t.shape())  # [5, 30]
       )DOC")
       .def("_to_dlpack",
-           [](framework::Tensor &self) {
-             DLPackTensor dlpack_tensor(self, 1);
-             DLManagedTensor *dmt = dlpack_tensor.ToDLManagedTensor();
-             auto capsule = py::capsule(
+           [](phi::DenseTensor &self) {
+             DLManagedTensor *dmt = framework::toDLPack(self);
+             auto capsule = pybind11::capsule(
                  static_cast<void *>(dmt), "dltensor", [](PyObject *ptr) {
-                   if (ptr) {
-                     auto dltensor = new DLManagedTensor;
-                     try {
-                       dltensor = reinterpret_cast<DLManagedTensor *>(
-                           PyCapsule_GetPointer(ptr, "used_dltensor"));
-                       return;
-                     } catch (...) {
-                       dltensor = reinterpret_cast<DLManagedTensor *>(
-                           PyCapsule_GetPointer(ptr, "dltensor"));
-                     }
-                     dltensor->deleter(dltensor);
-                   }
+                   if (!PyCapsule_IsValid(ptr, "dltensor")) {
+                     return;
+                   }
+                   DLManagedTensor *dmt = static_cast<DLManagedTensor *>(
+                       PyCapsule_GetPointer(ptr, "dltensor"));
+                   dmt->deleter(dmt);
                  });
              return capsule;
            })
```
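The rewritten deleter follows the standard DLPack capsule protocol: a consumer that takes ownership renames the capsule from "dltensor" to "used_dltensor", so PyCapsule_IsValid(ptr, "dltensor") is true only for capsules nobody consumed, and only those get their DLManagedTensor::deleter called. The old code tried to detect consumption with a C++ try/catch around PyCapsule_GetPointer, which never throws a C++ exception (it sets a Python error and returns NULL), and it leaked a freshly allocated DLManagedTensor on every call; the rewrite also swaps the hand-rolled DLPackTensor wrapper for framework::toDLPack.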
paddle/infrt/api/.gitignore (new file, mode 100644; view file @ dbe08e9b)

```
infrt_api_test.cc
```
paddle/infrt/dialect/phi/pass/proto_arg_map_context.cc (view file @ dbe08e9b)

```diff
@@ -69,10 +69,16 @@ bool ProtoArgumentMappingContext::IsDenseTensorInputs(
   return true;
 }

+bool ProtoArgumentMappingContext::IsSelectedRowsInputs(
+    const std::string& name) const {
+  return false;
+}
+
 bool ProtoArgumentMappingContext::IsSelectedRowsInput(
     const std::string& name) const {
   return false;
 }

 bool ProtoArgumentMappingContext::IsDenseTensorVectorInput(
     const std::string& name) const {
   return false;
```
paddle/infrt/dialect/phi/pass/proto_arg_map_context.h (view file @ dbe08e9b)

```diff
@@ -45,6 +45,7 @@ class ProtoArgumentMappingContext : public ::phi::ArgumentMappingContext {
   bool IsDenseTensorInput(const std::string& name) const override;
   bool IsDenseTensorInputs(const std::string& name) const override;
   bool IsSelectedRowsInput(const std::string& name) const override;
+  bool IsSelectedRowsInputs(const std::string& name) const override;
   bool IsDenseTensorVectorInput(const std::string& name) const override;
   bool IsDenseTensorOutput(const std::string& name) const override;
```
paddle/infrt/tests/.gitignore (new file, mode 100644; view file @ dbe08e9b)

```
.DS_Store
.idea
*.log
tmp/
Output
```
paddle/infrt/tests/dialect/tensor/.gitignore (new file, mode 100644; view file @ dbe08e9b)

```
.DS_Store
.idea
*.log
tmp/
tensor_map.mlir
```
paddle/phi/api/lib/api_custom_impl.cc (view file @ dbe08e9b)

```diff
@@ -34,6 +34,95 @@ namespace experimental {
 ////////////////// Forward api impls //////////////////////

+Tensor add_n_impl(const std::vector<Tensor>& x) {
+  Backend kernel_backend = Backend::UNDEFINED;
+  DataLayout kernel_layout = DataLayout::UNDEFINED;
+  DataType kernel_data_type = DataType::UNDEFINED;
+
+  if (kernel_backend == Backend::UNDEFINED ||
+      kernel_layout == DataLayout::UNDEFINED ||
+      kernel_data_type == DataType::UNDEFINED) {
+    auto kernel_key_set = ParseKernelKeyByInputArgs(x);
+    auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
+    if (kernel_backend == Backend::UNDEFINED) {
+      kernel_backend = kernel_key.backend();
+    }
+    if (kernel_layout == DataLayout::UNDEFINED) {
+      kernel_layout = kernel_key.layout();
+    }
+    if (kernel_data_type == DataType::UNDEFINED) {
+      kernel_data_type = kernel_key.dtype();
+    }
+  }
+
+  bool is_sr_kernel = true;
+  for (auto& input : x) {
+    if (phi::DenseTensor::classof(input.impl().get())) {
+      is_sr_kernel = false;
+      break;
+    }
+  }
+
+  const std::string kernel_name = (is_sr_kernel ? "add_n_sr" : "add_n");
+  VLOG(6) << "add_n API kernel key: [" << kernel_backend << ", "
+          << kernel_layout << ", " << kernel_data_type << "]";
+  auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
+      kernel_name, {kernel_backend, kernel_layout, kernel_data_type});
+  const auto& kernel = kernel_result.kernel;
+  VLOG(6) << kernel_name << " kernel: " << kernel;
+  auto* dev_ctx = GetDeviceContextByBackend(
+      kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend);
+
+  Tensor api_output;
+
+  if (is_sr_kernel) {
+    std::vector<const phi::SelectedRows*> input_x(x.size());
+    for (size_t i = 0; i < input_x.size(); ++i) {
+      input_x[i] = static_cast<phi::SelectedRows*>(x[i].impl().get());
+    }
+    auto x_meta_vec = MakeMetaTensor(input_x);
+    std::vector<const phi::MetaTensor*> x_metas(x_meta_vec.size());
+    for (size_t i = 0; i < x_meta_vec.size(); ++i) {
+      x_metas[i] = &x_meta_vec[i];
+    }
+    auto kernel_out = SetSelectedRowsKernelOutput(&api_output);
+    phi::MetaTensor meta_out(kernel_out);
+    phi::AddNInferMeta(x_metas, &meta_out);
+
+    using kernel_signature =
+        void (*)(const platform::DeviceContext&,
+                 const std::vector<const phi::SelectedRows*>&,
+                 phi::SelectedRows*);
+    auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
+    (*kernel_fn)(*dev_ctx, input_x, kernel_out);
+  } else {
+    std::vector<const phi::TensorBase*> input_x(x.size());
+    for (size_t i = 0; i < input_x.size(); ++i) {
+      input_x[i] = x[i].impl().get();
+    }
+    auto x_meta_vec = MakeMetaTensor(input_x);
+    std::vector<const phi::MetaTensor*> x_metas(x_meta_vec.size());
+    for (size_t i = 0; i < x_meta_vec.size(); ++i) {
+      x_metas[i] = &x_meta_vec[i];
+    }
+    auto kernel_out = SetKernelOutput(&api_output);
+    phi::MetaTensor meta_out(kernel_out);
+    phi::AddNInferMeta(x_metas, &meta_out);
+
+    using kernel_signature =
+        void (*)(const platform::DeviceContext&,
+                 const std::vector<const phi::TensorBase*>&,
+                 phi::DenseTensor*);
+    auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
+    (*kernel_fn)(*dev_ctx, input_x, kernel_out);
+  }
+
+  return api_output;
+}
+
 Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) {
   Tensor out;
   copy(x, place, blocking, &out);
```
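add_n_impl is a hand-written counterpart of what the phi API code generator usually emits: derive a kernel key (backend, layout, dtype) from the inputs, pick "add_n_sr" only when every input is a SelectedRows (a single DenseTensor forces the dense "add_n" kernel), run phi::AddNInferMeta to fix the output's shape and dtype before launch, then fetch the variadic kernel function and invoke it. A hedged caller-side sketch (tensor setup omitted; assumes the usual phi C++ API headers):

```cpp
// Illustrative only: a, b, c are assumed same-shape DenseTensors on one
// device, so the dense "add_n" kernel path is taken.
std::vector<paddle::experimental::Tensor> xs = {a, b, c};
paddle::experimental::Tensor sum = paddle::experimental::add_n_impl(xs);
// sum's shape/dtype were set by AddNInferMeta before the kernel ran.
```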
paddle/phi/api/lib/api_custom_impl.h (view file @ dbe08e9b)

```diff
@@ -31,6 +31,8 @@ namespace experimental {
 ////////////////// Forward api impls //////////////////////

+Tensor add_n_impl(const std::vector<Tensor>& x);
+
 std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl(
     const Tensor& x,
     const Tensor& scale,
```
paddle/phi/api/lib/api_gen_utils.cc (view file @ dbe08e9b)

```diff
@@ -98,6 +98,16 @@ phi::MetaTensor MakeMetaTensor(const phi::TensorBase& tensor) {
   return phi::MetaTensor(tensor);
 }

+std::vector<phi::MetaTensor> MakeMetaTensor(
+    const std::vector<const phi::TensorBase*>& tensors) {
+  std::vector<phi::MetaTensor> meta_tensors;
+  meta_tensors.reserve(tensors.size());
+  for (const auto* t : tensors) {
+    meta_tensors.emplace_back(*t);
+  }
+  return meta_tensors;
+}
+
 phi::MetaTensor MakeMetaTensor(
     const paddle::optional<phi::DenseTensor>& tensor) {
   if (tensor) {
@@ -116,6 +126,16 @@ std::vector<phi::MetaTensor> MakeMetaTensor(
   return meta_tensors;
 }

+std::vector<phi::MetaTensor> MakeMetaTensor(
+    const std::vector<const phi::SelectedRows*>& tensors) {
+  std::vector<phi::MetaTensor> meta_tensors;
+  meta_tensors.reserve(tensors.size());
+  for (const auto* t : tensors) {
+    meta_tensors.emplace_back(*t);
+  }
+  return meta_tensors;
+}
+
 std::vector<phi::MetaTensor> MakeMetaTensor(
     const std::vector<phi::DenseTensor*>& tensors) {
   std::vector<phi::MetaTensor> meta_tensors;
```
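These two overloads exist to serve add_n_impl above: they wrap lists of raw phi::TensorBase / phi::SelectedRows pointers in lightweight MetaTensor views so shape/dtype inference can run without touching tensor data. A minimal sketch of the pattern (GatherInputs() is hypothetical; out is an existing phi::DenseTensor):

```cpp
// MakeMetaTensor -> InferMeta pattern, as used by add_n_impl.
std::vector<const phi::TensorBase*> inputs = GatherInputs();  // hypothetical
std::vector<phi::MetaTensor> metas = MakeMetaTensor(inputs);  // new overload
std::vector<const phi::MetaTensor*> meta_ptrs(metas.size());
for (size_t i = 0; i < metas.size(); ++i) {
  meta_ptrs[i] = &metas[i];  // InferMeta consumes pointers to the views
}
phi::MetaTensor out_meta(&out);
phi::AddNInferMeta(meta_ptrs, &out_meta);  // writes shape/dtype, no data
```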