Commit 25d7fde8 authored by gaoqiong

lite

parent 8439d29f
@@ -6,6 +6,14 @@
#include "core/providers/rocm/miopen_common.h"
#include "core/providers/rocm/nn/max_pool_with_index.h"
#include "core/providers/rocm/math/unary_elementwise_ops_impl.h"
#include "core/providers/rocm/nn/pool_sugon.cuh"
#include "core/providers/rocm/nn/ort_sugon.cuh"
using namespace std;
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wunused-result"
#pragma GCC diagnostic ignored "-Wunused-variable"
using namespace onnxruntime::common;
namespace onnxruntime {
@@ -156,9 +164,8 @@ Status Pool<T, PoolType>::ComputeInternal(OpKernelContext* context) const {
if (x_shape.NumDimensions() < 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input dimension cannot be less than 3.");
}
auto kernel_shape = pool_attrs_.kernel_shape;
auto pads = pool_attrs_.pads;
auto kernel_shape = pool_attrs_.kernel_shape;  // TensorShapeVector data type
auto pads = pool_attrs_.pads;  // TensorShapeVector
auto strides = pool_attrs_.strides;
if (pool_attrs_.global_pooling) {
@@ -177,6 +184,57 @@ Status Pool<T, PoolType>::ComputeInternal(OpKernelContext* context) const {
auto x_data = reinterpret_cast<const HipT*>(X->Data<T>());
auto y_data = reinterpret_cast<HipT*>(Y->MutableData<T>());
auto kernel_rank = kernel_shape.size();
TensorShape output_shape = Y->Shape().Slice(2);
const int64_t output_image_size = output_shape.Size();
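// Fast path: for element types other than int8/uint8/double, 1-D and 2-D Max/Average
// pooling (including the global variants) is dispatched to the hand-written Sugon HIP
// kernels below instead of MIOpen.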
if constexpr (!std::is_same<T, int8_t>::value && !std::is_same<T, uint8_t>::value && !std::is_same<T, double>::value)
{
const int64_t N = X->Shape()[0];
const int64_t C = X->Shape()[1];
if (kernel_shape.size() == 2 && (pool_attrs_.max_pooling || pool_attrs_.average_pooling || pool_attrs_.global_max_pooling || pool_attrs_.global_average_pooling))
{
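// Judging from these call sites, the kernel arguments appear to follow the order
// (stream, in, N, C, H_in, W_in, kH, kW, stride_h, stride_w, pad_top, pad_left,
// pad_bottom, pad_right, H_out, W_out, out); this ordering is inferred from the
// calls here, not checked against pool_sugon.cuh.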
if (pool_attrs_.global_max_pooling) {
max_pool2d<HipT>(Stream(), x_data, N, C, x_shape[2], x_shape[3], x_shape[2], x_shape[3],
1, 1, 0, 0, 0, 0, output_shape[0], output_shape[1], y_data);
}
else if (pool_attrs_.max_pooling) {
max_pool2d<HipT>(Stream(), x_data, N, C, x_shape[2], x_shape[3], kernel_shape[0], kernel_shape[1], strides[0], strides[1],
pads[0], pads[1], pads[2], pads[3], output_shape[0], output_shape[1], y_data);
}
else if (pool_attrs_.average_pooling) {
avg_pool2d<HipT>(Stream(), x_data, N, C, x_shape[2], x_shape[3], kernel_shape[0], kernel_shape[1], strides[0], strides[1],
pads[0], pads[1], pads[2], pads[3], output_shape[0], output_shape[1], y_data);
}
else if (pool_attrs_.global_average_pooling) {
global_avg_pool2d(Stream(), x_data, N, C, x_shape[2], x_shape[3], y_data);
}
return Status::OK();
}
// 1-D case
else if (kernel_shape.size() == 1 && (pool_attrs_.max_pooling || pool_attrs_.average_pooling || pool_attrs_.global_max_pooling || pool_attrs_.global_average_pooling))
{
if (pool_attrs_.global_max_pooling) {
max_pool2d<HipT>(Stream(), x_data, N, C, 1, x_shape[2], 1, x_shape[2], 1, 1, 0, 0, 0, 0, 1, output_shape[0], y_data);  // in the 1-D case the H dimension is 1
}
else if (pool_attrs_.max_pooling) {
max_pool2d<HipT>(Stream(), x_data, N, C, 1, x_shape[2], 1, kernel_shape[0], 1, strides[0],
0, pads[0], 0, pads[1], 1, output_shape[0], y_data);
}
else if (pool_attrs_.average_pooling) {
// The stride along the dummy H dimension is 1, matching the max-pool path above.
avg_pool2d<HipT>(Stream(), x_data, N, C, 1, x_shape[2], 1, kernel_shape[0], 1, strides[0], 0, pads[0], 0, pads[1], 1, output_shape[0], y_data);
}
else if (pool_attrs_.global_average_pooling) {
global_avg_pool2d(Stream(), x_data, N, C, 1, x_shape[2], y_data);
}
return Status::OK();
}
}
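// Anything not handled above (e.g. 3-D kernels or the excluded element types)
// falls through to the MIOpen-based implementation that follows.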
TensorShapeVector x_dims_miopen(x_dims.begin(), x_dims.end());
TensorShapeVector y_dims_miopen(y_dims);
if (kernel_shape.size() < 2) {
@@ -260,7 +318,7 @@ Status Pool<T, MaxPool<8>>::ComputeInternal(OpKernelContext* context) const {
auto x_data = reinterpret_cast<const HipT*>(X->Data<T>());
auto y_data = reinterpret_cast<HipT*>(Y->MutableData<T>());
Tensor* I = context->Output(1, TensorShape(y_dims));
if (nullptr != I || !this->pool_attrs_.default_dilations) {
auto i_data = nullptr == I ? nullptr : I->MutableData<int64_t>();
......