Commit 25d7fde8 authored by gaoqiong's avatar gaoqiong
Browse files

lite

parent 8439d29f
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "batch_norm.h"
#include "core/providers/rocm/nn/batch_norm.h"
#include "core/providers/common.h"
#include "core/providers/rocm/miopen_common.h"
#include "core/providers/cpu/nn/batch_norm_helper.h"
#include "core/providers/rocm/math/unary_elementwise_ops_impl.h"
#include "core/providers/rocm/nn/bn_sugon.cuh"
#include "core/providers/rocm/nn/ort_sugon.cuh"
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/rocm_kernel.h"
//#include <iostream>
using namespace std;
namespace onnxruntime {
......@@ -81,6 +88,20 @@ Status BatchNorm<T>::ComputeInternal(OpKernelContext* p_op_kernel_context) const
auto y_data = reinterpret_cast<HipT*>(Y->MutableData<T>());
//add 2D
if(x_shape.NumDimensions()>=2 && x_shape.NumDimensions()<7 )
{
int x_shape_[6]={1,1,1,1,1,1};
for(int i=0;i<x_shape.NumDimensions();i++)
{
x_shape_[i]=x_shape[i];
}
batch_normal<HipT>(Stream(),x_data, scale_data,b_data, mean_data, var_data, y_data,x_shape_[0],x_shape_[1],x_shape_[2],x_shape_[3],x_shape_[4],x_shape_[5]);
return Status::OK();
}
const auto alpha = Consts<HipT>::One;
const auto beta = Consts<HipT>::Zero;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment