"vscode:/vscode.git/clone" did not exist on "bc1bb798fc436f4ac436b8fe6cf78dabfb865581"
Commit c58d92d3 authored by myamlak's avatar myamlak
Browse files

Cleaning part I

parent c4a678fc
......@@ -154,7 +154,7 @@ struct ParallelTensorFunctor
{
std::array<std::size_t, NDIM> indices;
for(int idim = 0; idim < NDIM; ++idim)
for(std::size_t idim = 0; idim < NDIM; ++idim)
{
indices[idim] = i / mStrides[idim];
i -= indices[idim] * mStrides[idim];
......
......@@ -72,9 +72,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
if constexpr(NumDimSpatial == 1)
{
auto f_ncw = [&](auto n, auto c, auto wi) {
std::size_t K = arg.weight_.mDesc.GetLengths()[0];
std::size_t X = arg.weight_.mDesc.GetLengths()[2];
std::size_t Wo = arg.output_.mDesc.GetLengths()[2];
int K = arg.weight_.mDesc.GetLengths()[0];
int X = arg.weight_.mDesc.GetLengths()[2];
int Wo = arg.output_.mDesc.GetLengths()[2];
AccDataType v_acc = 0;
......@@ -119,12 +119,12 @@ struct ReferenceConvBwdData : public device::BaseOperator
else if constexpr(NumDimSpatial == 2)
{
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
std::size_t K = arg.weight_.mDesc.GetLengths()[0];
std::size_t Y = arg.weight_.mDesc.GetLengths()[2];
std::size_t X = arg.weight_.mDesc.GetLengths()[3];
int K = arg.weight_.mDesc.GetLengths()[0];
int Y = arg.weight_.mDesc.GetLengths()[2];
int X = arg.weight_.mDesc.GetLengths()[3];
std::size_t Ho = arg.output_.mDesc.GetLengths()[2];
std::size_t Wo = arg.output_.mDesc.GetLengths()[3];
int Ho = arg.output_.mDesc.GetLengths()[2];
int Wo = arg.output_.mDesc.GetLengths()[3];
AccDataType v_acc = 0;
......@@ -183,14 +183,14 @@ struct ReferenceConvBwdData : public device::BaseOperator
else if constexpr(NumDimSpatial == 3)
{
auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) {
std::size_t K = arg.weight_.mDesc.GetLengths()[0];
std::size_t Z = arg.weight_.mDesc.GetLengths()[2];
std::size_t Y = arg.weight_.mDesc.GetLengths()[3];
std::size_t X = arg.weight_.mDesc.GetLengths()[4];
std::size_t Do = arg.output_.mDesc.GetLengths()[2];
std::size_t Ho = arg.output_.mDesc.GetLengths()[3];
std::size_t Wo = arg.output_.mDesc.GetLengths()[4];
int K = arg.weight_.mDesc.GetLengths()[0];
int Z = arg.weight_.mDesc.GetLengths()[2];
int Y = arg.weight_.mDesc.GetLengths()[3];
int X = arg.weight_.mDesc.GetLengths()[4];
int Do = arg.output_.mDesc.GetLengths()[2];
int Ho = arg.output_.mDesc.GetLengths()[3];
int Wo = arg.output_.mDesc.GetLengths()[4];
AccDataType v_acc = 0;
......
......@@ -88,13 +88,13 @@ struct ReferenceConvFwd : public device::BaseOperator
auto f_ncw = [&](auto n, auto k, auto wo) {
float v_acc = 0;
for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
for(int c = 0; c < static_cast<int>(arg.weight_.mDesc.GetLengths()[1]); ++c)
{
for(int x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x)
for(int x = 0; x < static_cast<int>(arg.weight_.mDesc.GetLengths()[2]); ++x)
{
int wi = wo * arg.conv_strides_[0] + x * arg.conv_dilations_[0] -
arg.in_left_pads_[0];
if(wi >= 0 && wi < arg.input_.mDesc.GetLengths()[2])
if(wi >= 0 && wi < static_cast<int>(arg.input_.mDesc.GetLengths()[2]))
{
float v_in;
float v_wei;
......@@ -128,18 +128,18 @@ struct ReferenceConvFwd : public device::BaseOperator
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
float v_acc = 0;
for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
for(int c = 0; c < static_cast<int>(arg.weight_.mDesc.GetLengths()[1]); ++c)
{
for(int y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y)
for(int y = 0; y < static_cast<int>(arg.weight_.mDesc.GetLengths()[2]); ++y)
{
int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
arg.in_left_pads_[0];
for(int x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x)
for(int x = 0; x < static_cast<int>(arg.weight_.mDesc.GetLengths()[3]); ++x)
{
int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
arg.in_left_pads_[1];
if(hi >= 0 && hi < arg.input_.mDesc.GetLengths()[2] && wi >= 0 &&
wi < arg.input_.mDesc.GetLengths()[3])
if(hi >= 0 && hi < static_cast<int>(arg.input_.mDesc.GetLengths()[2]) && wi >= 0 &&
wi < static_cast<int>(arg.input_.mDesc.GetLengths()[3]))
{
float v_in;
float v_wei;
......@@ -174,23 +174,23 @@ struct ReferenceConvFwd : public device::BaseOperator
auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) {
float v_acc = 0;
for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
for(int c = 0; c < static_cast<int>(arg.weight_.mDesc.GetLengths()[1]); ++c)
{
for(int z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z)
for(int z = 0; z < static_cast<int>(arg.weight_.mDesc.GetLengths()[2]); ++z)
{
int di = d_o * arg.conv_strides_[0] + z * arg.conv_dilations_[0] -
arg.in_left_pads_[0];
for(int y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y)
for(int y = 0; y < static_cast<int>(arg.weight_.mDesc.GetLengths()[3]); ++y)
{
int hi = ho * arg.conv_strides_[1] + y * arg.conv_dilations_[1] -
arg.in_left_pads_[1];
for(int x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x)
for(int x = 0; x < static_cast<int>(arg.weight_.mDesc.GetLengths()[4]); ++x)
{
int wi = wo * arg.conv_strides_[2] +
x * arg.conv_dilations_[2] - arg.in_left_pads_[2];
if(di >= 0 && di < arg.input_.mDesc.GetLengths()[2] &&
hi >= 0 && hi < arg.input_.mDesc.GetLengths()[3] &&
wi >= 0 && wi < arg.input_.mDesc.GetLengths()[4])
if(di >= 0 && di < static_cast<int>(arg.input_.mDesc.GetLengths()[2]) &&
hi >= 0 && hi < static_cast<int>(arg.input_.mDesc.GetLengths()[3]) &&
wi >= 0 && wi < static_cast<int>(arg.input_.mDesc.GetLengths()[4]))
{
float v_in;
float v_wei;
......
......@@ -145,11 +145,12 @@ struct ConvParams
input_left_pads(left_pads),
input_right_pads(right_pads)
{
if(filter_spatial_lengths.size() != num_dim_spatial ||
input_spatial_lengths.size() != num_dim_spatial ||
conv_filter_strides.size() != num_dim_spatial ||
conv_filter_dilations.size() != num_dim_spatial ||
input_left_pads.size() != num_dim_spatial || input_right_pads.size() != num_dim_spatial)
if(static_cast<ck::index_t>(filter_spatial_lengths.size()) != num_dim_spatial ||
static_cast<ck::index_t>(input_spatial_lengths.size()) != num_dim_spatial ||
static_cast<ck::index_t>(conv_filter_strides.size()) != num_dim_spatial ||
static_cast<ck::index_t>(conv_filter_dilations.size()) != num_dim_spatial ||
static_cast<ck::index_t>(input_left_pads.size()) != num_dim_spatial ||
static_cast<ck::index_t>(input_right_pads.size()) != num_dim_spatial)
{
throw(std::runtime_error(
"ConvParams::GetOutputSpatialLengths: "
......@@ -173,11 +174,12 @@ struct ConvParams
std::vector<ck::index_t> GetOutputSpatialLengths() const
{
if(filter_spatial_lengths.size() != num_dim_spatial ||
input_spatial_lengths.size() != num_dim_spatial ||
conv_filter_strides.size() != num_dim_spatial ||
conv_filter_dilations.size() != num_dim_spatial ||
input_left_pads.size() != num_dim_spatial || input_right_pads.size() != num_dim_spatial)
if(static_cast<ck::index_t>(filter_spatial_lengths.size()) != num_dim_spatial ||
static_cast<ck::index_t>(input_spatial_lengths.size()) != num_dim_spatial ||
static_cast<ck::index_t>(conv_filter_strides.size()) != num_dim_spatial ||
static_cast<ck::index_t>(conv_filter_dilations.size()) != num_dim_spatial ||
static_cast<ck::index_t>(input_left_pads.size()) != num_dim_spatial ||
static_cast<ck::index_t>(input_right_pads.size()) != num_dim_spatial)
{
throw(std::runtime_error(
"ConvParams::GetOutputSpatialLengths: "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment